Refa: make RAGFlow more asynchronous 2 (#11689)
### What problem does this PR solve?

Make RAGFlow more asynchronous 2. #11551, #11579, #11619.

### Type of change

- [x] Refactoring
- [x] Performance Improvement
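The recurring pattern across the route handlers in this diff is to wrap the blocking body in a local closure and hand it to `asyncio.to_thread`, keeping the event loop free while the work runs on a worker thread. A minimal, self-contained sketch of the idea (the names below are illustrative, not RAGFlow APIs):

```python
import asyncio
import time

def _expensive_sync_work(x: int) -> int:
    # Stand-in for a blocking DB or object-storage call (illustrative only).
    time.sleep(0.1)
    return x * 2

async def handler(x: int) -> int:
    # Wrap the blocking body in a closure and run it off-loop,
    # exactly as the routes below do with _set_sync, _rm_sync, etc.
    def _body():
        return _expensive_sync_work(x)
    return await asyncio.to_thread(_body)

async def main():
    # Two handlers now overlap instead of serializing on the event loop.
    results = await asyncio.gather(handler(1), handler(2))
    print(results)  # [2, 4]

if __name__ == "__main__":
    asyncio.run(main())
```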
```diff
@@ -416,13 +416,19 @@ class Canvas(Graph):
 
         loop = asyncio.get_running_loop()
         tasks = []
+
+        def _run_async_in_thread(coro_func, **call_kwargs):
+            return asyncio.run(coro_func(**call_kwargs))
+
         i = f
         while i < t:
             cpn = self.get_component_obj(self.path[i])
             task_fn = None
+            call_kwargs = None
 
             if cpn.component_name.lower() in ["begin", "userfillup"]:
-                task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
+                call_kwargs = {"inputs": kwargs.get("inputs", {})}
+                task_fn = cpn.invoke
                 i += 1
             else:
                 for _, ele in cpn.get_input_elements().items():
@@ -431,13 +437,18 @@ class Canvas(Graph):
                         t -= 1
                         break
                 else:
-                    task_fn = partial(cpn.invoke, **cpn.get_input())
+                    call_kwargs = cpn.get_input()
+                    task_fn = cpn.invoke
                     i += 1
 
             if task_fn is None:
                 continue
 
-            tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
+            invoke_async = getattr(cpn, "invoke_async", None)
+            if invoke_async and asyncio.iscoroutinefunction(invoke_async):
+                tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
+            else:
+                tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
 
         if tasks:
             await asyncio.gather(*tasks)
```
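For context on the canvas change above: components that expose a coroutine `invoke_async` are still scheduled on the thread pool, so each worker thread drives the coroutine with its own event loop via `asyncio.run`. A reduced sketch of that dispatch, assuming a hypothetical component class (the real canvas reuses a shared `self._thread_pool` rather than building one per call):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

class DemoComponent:
    # Hypothetical stand-in for a canvas component.
    async def invoke_async(self, **kwargs):
        await asyncio.sleep(0.05)
        return f"async:{kwargs}"

    def invoke(self, **kwargs):
        return f"sync:{kwargs}"

def _run_async_in_thread(coro_func, **call_kwargs):
    # The worker thread has no running loop, so asyncio.run is safe here.
    return asyncio.run(coro_func(**call_kwargs))

async def dispatch(cpn, call_kwargs):
    loop = asyncio.get_running_loop()
    pool = ThreadPoolExecutor(max_workers=4)  # illustrative; a real app shares one pool
    invoke_async = getattr(cpn, "invoke_async", None)
    if invoke_async and asyncio.iscoroutinefunction(invoke_async):
        fn = partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))
    else:
        fn = partial(cpn.invoke, **(call_kwargs or {}))
    return await loop.run_in_executor(pool, fn)

print(asyncio.run(dispatch(DemoComponent(), {"x": 1})))
```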
```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import os
@@ -239,6 +240,86 @@ class Agent(LLM, ToolBase):
         self.set_output("use_tools", use_tools)
         return ans
 
+    async def _invoke_async(self, **kwargs):
+        """
+        Async entry: reuse existing logic but offload heavy sync parts via async wrappers to reduce blocking.
+        """
+        if self.check_if_canceled("Agent processing"):
+            return
+
+        if kwargs.get("user_prompt"):
+            usr_pmt = ""
+            if kwargs.get("reasoning"):
+                usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
+            if kwargs.get("context"):
+                usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
+            if usr_pmt:
+                usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
+            else:
+                usr_pmt = str(kwargs["user_prompt"])
+            self._param.prompts = [{"role": "user", "content": usr_pmt}]
+
+        if not self.tools:
+            if self.check_if_canceled("Agent processing"):
+                return
+            return await asyncio.to_thread(LLM._invoke, self, **kwargs)
+
+        prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
+        output_schema = self._get_output_schema()
+        schema_prompt = ""
+        if output_schema:
+            schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
+            schema_prompt = structured_output_prompt(schema)
+
+        downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
+        ex = self.exception_handler()
+        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
+            self.set_output("content", partial(self.stream_output_with_tools_async, prompt, msg, user_defined_prompt))
+            return
+
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+        use_tools = []
+        ans = ""
+        async for delta_ans, tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
+            if self.check_if_canceled("Agent processing"):
+                return
+            ans += delta_ans
+
+        if ans.find("**ERROR**") >= 0:
+            logging.error(f"Agent._chat got error. response: {ans}")
+            if self.get_exception_default_value():
+                self.set_output("content", self.get_exception_default_value())
+            else:
+                self.set_output("_ERROR", ans)
+            return
+
+        if output_schema:
+            error = ""
+            for _ in range(self._param.max_retries + 1):
+                try:
+                    def clean_formated_answer(ans: str) -> str:
+                        ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
+                        ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
+                        return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
+                    obj = json_repair.loads(clean_formated_answer(ans))
+                    self.set_output("structured", obj)
+                    if use_tools:
+                        self.set_output("use_tools", use_tools)
+                    return obj
+                except Exception:
+                    error = "The answer cannot be parsed as JSON"
+                    ans = self._force_format_to_schema(ans, schema_prompt)
+                    if ans.find("**ERROR**") >= 0:
+                        continue
+
+            self.set_output("_ERROR", error)
+            return
+
+        self.set_output("content", ans)
+        if use_tools:
+            self.set_output("use_tools", use_tools)
+        return ans
+
     def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer_without_toolcall = ""
```
```diff
@@ -261,6 +342,54 @@ class Agent(LLM, ToolBase):
         if use_tools:
             self.set_output("use_tools", use_tools)
 
+    async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+        answer_without_toolcall = ""
+        use_tools = []
+        async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
+            if self.check_if_canceled("Agent streaming"):
+                return
+
+            if delta_ans.find("**ERROR**") >= 0:
+                if self.get_exception_default_value():
+                    self.set_output("content", self.get_exception_default_value())
+                    yield self.get_exception_default_value()
+                else:
+                    self.set_output("_ERROR", delta_ans)
+                return
+            answer_without_toolcall += delta_ans
+            yield delta_ans
+
+        self.set_output("content", answer_without_toolcall)
+        if use_tools:
+            self.set_output("use_tools", use_tools)
+
+    async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
+        """
+        Async wrapper that offloads synchronous flow to a thread, yielding results without blocking the event loop.
+        """
+        loop = asyncio.get_running_loop()
+        queue: asyncio.Queue = asyncio.Queue()
+
+        def worker():
+            try:
+                for delta_ans, tk in self._react_with_tools_streamly(prompt, history, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
+                    asyncio.run_coroutine_threadsafe(queue.put((delta_ans, tk)), loop)
+            except Exception as e:
+                asyncio.run_coroutine_threadsafe(queue.put(e), loop)
+            finally:
+                asyncio.run_coroutine_threadsafe(queue.put(StopAsyncIteration), loop)
+
+        await asyncio.to_thread(worker)
+
+        while True:
+            item = await queue.get()
+            if item is StopAsyncIteration:
+                break
+            if isinstance(item, Exception):
+                raise item
+            yield item
+
     def _gen_citations(self, text):
         retrievals = self._canvas.get_reference()
         retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
```
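`_react_with_tools_streamly_async` above bridges the synchronous generator `_react_with_tools_streamly` into an async generator through an `asyncio.Queue`, with `run_coroutine_threadsafe` handing items back to the loop. Note that because the worker is awaited to completion before the drain loop starts, items are buffered and then replayed rather than streamed one by one. A stripped-down sketch of the same bridge with generic names:

```python
import asyncio

def slow_sync_gen():
    # Stand-in for a blocking, synchronous generator.
    for i in range(3):
        yield i

async def as_async_gen():
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def worker():
        try:
            for item in slow_sync_gen():
                asyncio.run_coroutine_threadsafe(queue.put(item), loop)
        except Exception as e:
            asyncio.run_coroutine_threadsafe(queue.put(e), loop)
        finally:
            asyncio.run_coroutine_threadsafe(queue.put(StopAsyncIteration), loop)

    # Awaiting to_thread keeps the loop alive so the puts can complete;
    # the queue is drained only after the worker has finished.
    await asyncio.to_thread(worker)

    while True:
        item = await queue.get()
        if item is StopAsyncIteration:
            break
        if isinstance(item, Exception):
            raise item
        yield item

async def main():
    async for x in as_async_gen():
        print(x)

asyncio.run(main())
```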
```diff
@@ -433,4 +562,3 @@ Respond immediately with your final comprehensive answer.
         for k in self._param.inputs.keys():
             self._param.inputs[k]["value"] = None
         self._param.debug_inputs = {}
-
```
```diff
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 
+import asyncio
 import re
 import time
 from abc import ABC
```
```diff
@@ -445,6 +446,34 @@ class ComponentBase(ABC):
         self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
         return self.output()
 
+    async def invoke_async(self, **kwargs) -> dict[str, Any]:
+        """
+        Async wrapper for component invocation.
+        Prefers coroutine `_invoke_async` if present; otherwise falls back to `_invoke`.
+        Handles timing and error recording consistently with `invoke`.
+        """
+        self.set_output("_created_time", time.perf_counter())
+        try:
+            if self.check_if_canceled("Component processing"):
+                return
+
+            fn_async = getattr(self, "_invoke_async", None)
+            if fn_async and asyncio.iscoroutinefunction(fn_async):
+                await fn_async(**kwargs)
+            elif asyncio.iscoroutinefunction(self._invoke):
+                await self._invoke(**kwargs)
+            else:
+                await asyncio.to_thread(self._invoke, **kwargs)
+        except Exception as e:
+            if self.get_exception_default_value():
+                self.set_exception_default_value()
+            else:
+                self.set_output("_ERROR", str(e))
+            logging.exception(e)
+        self._param.debug_inputs = {}
+        self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
+        return self.output()
+
     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         raise NotImplementedError()
```
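The `invoke_async` wrapper above resolves the entry point by duck typing: a coroutine `_invoke_async` is awaited directly, a coroutine `_invoke` is awaited, and a plain synchronous `_invoke` is offloaded with `asyncio.to_thread`. A compact illustration of that resolution order, using toy classes rather than real components:

```python
import asyncio

class Base:
    async def run(self, **kwargs):
        # Prefer a coroutine _invoke_async, then a coroutine _invoke,
        # and finally push a plain sync _invoke onto a worker thread.
        fn_async = getattr(self, "_invoke_async", None)
        if fn_async and asyncio.iscoroutinefunction(fn_async):
            return await fn_async(**kwargs)
        if asyncio.iscoroutinefunction(self._invoke):
            return await self._invoke(**kwargs)
        return await asyncio.to_thread(self._invoke, **kwargs)

    def _invoke(self, **kwargs):
        raise NotImplementedError()

class SyncComp(Base):
    def _invoke(self, **kwargs):
        return "ran in a thread"

class AsyncComp(Base):
    async def _invoke_async(self, **kwargs):
        return "awaited natively"

async def main():
    print(await SyncComp().run())   # ran in a thread
    print(await AsyncComp().run())  # awaited natively

asyncio.run(main())
```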
```diff
@@ -13,12 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import os
 import re
+import threading
 from copy import deepcopy
-from typing import Any, Generator
+from typing import Any, Generator, AsyncGenerator
 import json_repair
 from functools import partial
 from common.constants import LLMType
@@ -171,6 +173,13 @@ class LLM(ComponentBase):
             return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
         return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
 
+    async def _generate_async(self, msg: list[dict], **kwargs) -> str:
+        if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
+            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
+        if self.imgs and hasattr(self.chat_mdl, "async_chat"):
+            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
+        return await asyncio.to_thread(self._generate, msg, **kwargs)
+
     def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
         ans = ""
         last_idx = 0
```
```diff
@@ -205,6 +214,69 @@ class LLM(ComponentBase):
         for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
             yield delta(txt)
 
+    async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
+        async def delta_wrapper(txt_iter):
+            ans = ""
+            last_idx = 0
+            endswith_think = False
+
+            def delta(txt):
+                nonlocal ans, last_idx, endswith_think
+                delta_ans = txt[last_idx:]
+                ans = txt
+
+                if delta_ans.find("<think>") == 0:
+                    last_idx += len("<think>")
+                    return "<think>"
+                elif delta_ans.find("<think>") > 0:
+                    delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
+                    last_idx += delta_ans.find("<think>")
+                    return delta_ans
+                elif delta_ans.endswith("</think>"):
+                    endswith_think = True
+                elif endswith_think:
+                    endswith_think = False
+                    return "</think>"
+
+                last_idx = len(ans)
+                if ans.endswith("</think>"):
+                    last_idx -= len("</think>")
+                return re.sub(r"(<think>|</think>)", "", delta_ans)
+
+            async for t in txt_iter:
+                yield delta(t)
+
+        if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
+            async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
+                yield t
+            return
+        if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
+            async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
+                yield t
+            return
+
+        # fallback
+        loop = asyncio.get_running_loop()
+        queue: asyncio.Queue = asyncio.Queue()
+
+        def worker():
+            try:
+                for item in self._generate_streamly(msg, **kwargs):
+                    loop.call_soon_threadsafe(queue.put_nowait, item)
+            except Exception as e:
+                loop.call_soon_threadsafe(queue.put_nowait, e)
+            finally:
+                loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
+
+        threading.Thread(target=worker, daemon=True).start()
+        while True:
+            item = await queue.get()
+            if item is StopAsyncIteration:
+                break
+            if isinstance(item, Exception):
+                raise item
+            yield item
+
     async def _stream_output_async(self, prompt, msg):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer = ""
```
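Unlike the awaited-worker bridge in the Agent, the fallback path of `_generate_streamly_async` starts a daemon thread and feeds the queue with `loop.call_soon_threadsafe(queue.put_nowait, ...)`, so the consumer yields tokens as they arrive. A minimal sketch of this incremental variant, with a stand-in generator in place of the chat model:

```python
import asyncio
import threading
import time

def token_stream():
    # Stand-in for a blocking streaming LLM client.
    for tok in ["hel", "lo ", "world"]:
        time.sleep(0.05)
        yield tok

async def stream_async():
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def worker():
        try:
            for item in token_stream():
                loop.call_soon_threadsafe(queue.put_nowait, item)
        except Exception as e:
            loop.call_soon_threadsafe(queue.put_nowait, e)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)

    # The consumer starts draining immediately while the thread produces.
    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = await queue.get()
        if item is StopAsyncIteration:
            break
        if isinstance(item, Exception):
            raise item
        yield item

async def main():
    async for tok in stream_async():
        print(tok, end="", flush=True)

asyncio.run(main())
```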
```diff
@@ -17,6 +17,7 @@ import logging
 import re
 import time
 from copy import deepcopy
+import asyncio
 from functools import partial
 from typing import TypedDict, List, Any
 from agent.component.base import ComponentParamBase, ComponentBase
@@ -50,10 +51,14 @@ class LLMToolPluginCallSession(ToolCallSession):
     def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
         assert name in self.tools_map, f"LLM tool {name} does not exist"
         st = timer()
-        if isinstance(self.tools_map[name], MCPToolCallSession):
-            resp = self.tools_map[name].tool_call(name, arguments, 60)
+        tool_obj = self.tools_map[name]
+        if isinstance(tool_obj, MCPToolCallSession):
+            resp = tool_obj.tool_call(name, arguments, 60)
         else:
-            resp = self.tools_map[name].invoke(**arguments)
+            if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
+                resp = asyncio.run(tool_obj.invoke_async(**arguments))
+            else:
+                resp = asyncio.run(asyncio.to_thread(tool_obj.invoke, **arguments))
+
         self.callback(name, arguments, resp, elapsed_time=timer()-st)
         return resp
```
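`tool_call` above runs on a worker thread where no event loop is running, which is why it can drive coroutine tools with `asyncio.run`; the `asyncio.run(asyncio.to_thread(...))` branch spins up a loop only to push the sync call onto yet another thread. A reduced sketch of the same dispatch with a hypothetical tool object:

```python
import asyncio

class EchoTool:
    # Hypothetical tool exposing both sync and async entry points.
    def invoke(self, **kwargs):
        return {"echo": kwargs}

    async def invoke_async(self, **kwargs):
        await asyncio.sleep(0)
        return {"echo": kwargs}

def call_tool_from_sync(tool, **arguments):
    # No loop is running here (plain thread context), so asyncio.run is safe.
    if hasattr(tool, "invoke_async") and asyncio.iscoroutinefunction(tool.invoke_async):
        return asyncio.run(tool.invoke_async(**arguments))
    return asyncio.run(asyncio.to_thread(tool.invoke, **arguments))

print(call_tool_from_sync(EchoTool(), q="hi"))
```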
```diff
@@ -139,6 +144,33 @@ class ToolBase(ComponentBase):
         self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
         return res
 
+    async def invoke_async(self, **kwargs):
+        """
+        Async wrapper for tool invocation.
+        If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
+        Mirrors the exception handling of `invoke`.
+        """
+        if self.check_if_canceled("Tool processing"):
+            return
+
+        self.set_output("_created_time", time.perf_counter())
+        try:
+            fn_async = getattr(self, "_invoke_async", None)
+            if fn_async and asyncio.iscoroutinefunction(fn_async):
+                res = await fn_async(**kwargs)
+            elif asyncio.iscoroutinefunction(self._invoke):
+                res = await self._invoke(**kwargs)
+            else:
+                res = await asyncio.to_thread(self._invoke, **kwargs)
+        except Exception as e:
+            self._param.outputs["_ERROR"] = {"value": str(e)}
+            logging.exception(e)
+            res = str(e)
+        self._param.debug_inputs = []
+
+        self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
+        return res
+
     def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
         chunks = []
         aggs = []
```
```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import datetime
 import json
 import re
@@ -147,31 +148,35 @@ async def set():
         d["available_int"] = req["available_int"]
 
     try:
-        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
-        if not tenant_id:
-            return get_data_error_result(message="Tenant not found!")
+        def _set_sync():
+            tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+            if not tenant_id:
+                return get_data_error_result(message="Tenant not found!")
 
-        embd_id = DocumentService.get_embd_id(req["doc_id"])
-        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
+            embd_id = DocumentService.get_embd_id(req["doc_id"])
+            embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
 
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
 
-        if doc.parser_id == ParserType.QA:
-            arr = [
-                t for t in re.split(
-                    r"[\n\t]",
-                    req["content_with_weight"]) if len(t) > 1]
-            q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
-            d = beAdoc(d, q, a, not any(
-                [rag_tokenizer.is_chinese(t) for t in q + a]))
+            _d = d
+            if doc.parser_id == ParserType.QA:
+                arr = [
+                    t for t in re.split(
+                        r"[\n\t]",
+                        req["content_with_weight"]) if len(t) > 1]
+                q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
+                _d = beAdoc(d, q, a, not any(
+                    [rag_tokenizer.is_chinese(t) for t in q + a]))
 
-        v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
-        v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
-        d["q_%d_vec" % len(v)] = v.tolist()
-        settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
-        return get_json_result(data=True)
+            v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
+            v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
+            _d["q_%d_vec" % len(v)] = v.tolist()
+            settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_set_sync)
     except Exception as e:
         return server_error_response(e)
 
```
```diff
@@ -182,16 +187,19 @@ async def set():
 async def switch():
     req = await get_request_json()
     try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
-        for cid in req["chunk_ids"]:
-            if not settings.docStoreConn.update({"id": cid},
-                                                {"available_int": int(req["available_int"])},
-                                                search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
-                                                doc.kb_id):
-                return get_data_error_result(message="Index updating failure")
-        return get_json_result(data=True)
+        def _switch_sync():
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
+            for cid in req["chunk_ids"]:
+                if not settings.docStoreConn.update({"id": cid},
+                                                    {"available_int": int(req["available_int"])},
+                                                    search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                    doc.kb_id):
+                    return get_data_error_result(message="Index updating failure")
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_switch_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -202,20 +210,23 @@ async def switch():
 async def rm():
     req = await get_request_json()
    try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
-        if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
-                                            search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
-                                            doc.kb_id):
-            return get_data_error_result(message="Chunk deleting failure")
-        deleted_chunk_ids = req["chunk_ids"]
-        chunk_number = len(deleted_chunk_ids)
-        DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
-        for cid in deleted_chunk_ids:
-            if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
-                settings.STORAGE_IMPL.rm(doc.kb_id, cid)
-        return get_json_result(data=True)
+        def _rm_sync():
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
+            if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
+                                                search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                doc.kb_id):
+                return get_data_error_result(message="Chunk deleting failure")
+            deleted_chunk_ids = req["chunk_ids"]
+            chunk_number = len(deleted_chunk_ids)
+            DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
+            for cid in deleted_chunk_ids:
+                if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
+                    settings.STORAGE_IMPL.rm(doc.kb_id, cid)
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_rm_sync)
     except Exception as e:
         return server_error_response(e)
 
```
```diff
@@ -245,35 +256,38 @@ async def create():
         d["tag_feas"] = req["tag_feas"]
 
     try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
-        d["kb_id"] = [doc.kb_id]
-        d["docnm_kwd"] = doc.name
-        d["title_tks"] = rag_tokenizer.tokenize(doc.name)
-        d["doc_id"] = doc.id
+        def _create_sync():
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
+            d["kb_id"] = [doc.kb_id]
+            d["docnm_kwd"] = doc.name
+            d["title_tks"] = rag_tokenizer.tokenize(doc.name)
+            d["doc_id"] = doc.id
 
-        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
-        if not tenant_id:
-            return get_data_error_result(message="Tenant not found!")
+            tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+            if not tenant_id:
+                return get_data_error_result(message="Tenant not found!")
 
-        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
-        if not e:
-            return get_data_error_result(message="Knowledgebase not found!")
-        if kb.pagerank:
-            d[PAGERANK_FLD] = kb.pagerank
+            e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
+            if not e:
+                return get_data_error_result(message="Knowledgebase not found!")
+            if kb.pagerank:
+                d[PAGERANK_FLD] = kb.pagerank
 
-        embd_id = DocumentService.get_embd_id(req["doc_id"])
-        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
+            embd_id = DocumentService.get_embd_id(req["doc_id"])
+            embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
 
-        v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
-        v = 0.1 * v[0] + 0.9 * v[1]
-        d["q_%d_vec" % len(v)] = v.tolist()
-        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+            v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
+            v = 0.1 * v[0] + 0.9 * v[1]
+            d["q_%d_vec" % len(v)] = v.tolist()
+            settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
 
-        DocumentService.increment_chunk_num(
-            doc.id, doc.kb_id, c, 1, 0)
-        return get_json_result(data={"chunk_id": chunck_id})
+            DocumentService.increment_chunk_num(
+                doc.id, doc.kb_id, c, 1, 0)
+            return get_json_result(data={"chunk_id": chunck_id})
+
+        return await asyncio.to_thread(_create_sync)
     except Exception as e:
         return server_error_response(e)
 
```
```diff
@@ -297,25 +311,28 @@ async def retrieval_test():
     use_kg = req.get("use_kg", False)
     top = int(req.get("top_k", 1024))
     langs = req.get("cross_languages", [])
-    tenant_ids = []
+    user_id = current_user.id
 
-    if req.get("search_id", ""):
-        search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
-        meta_data_filter = search_config.get("meta_data_filter", {})
-        metas = DocumentService.get_meta_by_kbs(kb_ids)
-        if meta_data_filter.get("method") == "auto":
-            chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
-            filters: dict = gen_meta_filter(chat_mdl, metas, question)
-            doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
-            if not doc_ids:
-                doc_ids = None
-        elif meta_data_filter.get("method") == "manual":
-            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
-            if meta_data_filter["manual"] and not doc_ids:
-                doc_ids = ["-999"]
+    def _retrieval_sync():
+        local_doc_ids = list(doc_ids) if doc_ids else []
+        tenant_ids = []
 
-    try:
-        tenants = UserTenantService.query(user_id=current_user.id)
+        if req.get("search_id", ""):
+            search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
+            meta_data_filter = search_config.get("meta_data_filter", {})
+            metas = DocumentService.get_meta_by_kbs(kb_ids)
+            if meta_data_filter.get("method") == "auto":
+                chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
+                filters: dict = gen_meta_filter(chat_mdl, metas, question)
+                local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
+                if not local_doc_ids:
+                    local_doc_ids = None
+            elif meta_data_filter.get("method") == "manual":
+                local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
+                if meta_data_filter["manual"] and not local_doc_ids:
+                    local_doc_ids = ["-999"]
+
+        tenants = UserTenantService.query(user_id=user_id)
         for kb_id in kb_ids:
             for tenant in tenants:
                 if KnowledgebaseService.query(
@@ -331,8 +348,9 @@ async def retrieval_test():
         if not e:
             return get_data_error_result(message="Knowledgebase not found!")
 
+        _question = question
         if langs:
-            question = cross_languages(kb.tenant_id, None, question, langs)
+            _question = cross_languages(kb.tenant_id, None, _question, langs)
 
         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
 
@@ -342,19 +360,19 @@ async def retrieval_test():
 
         if req.get("keyword", False):
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
-            question += keyword_extraction(chat_mdl, question)
+            _question += keyword_extraction(chat_mdl, _question)
 
-        labels = label_question(question, [kb])
-        ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
+        labels = label_question(_question, [kb])
+        ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
                                              float(req.get("similarity_threshold", 0.0)),
                                              float(req.get("vector_similarity_weight", 0.3)),
                                              top,
-                                             doc_ids, rerank_mdl=rerank_mdl,
+                                             local_doc_ids, rerank_mdl=rerank_mdl,
                                              highlight=req.get("highlight", False),
                                              rank_feature=labels
                                              )
         if use_kg:
-            ck = settings.kg_retriever.retrieval(question,
+            ck = settings.kg_retriever.retrieval(_question,
                                                  tenant_ids,
                                                  kb_ids,
                                                  embd_mdl,
@@ -367,6 +385,9 @@ async def retrieval_test():
         ranks["labels"] = labels
 
         return get_json_result(data=ranks)
+
+    try:
+        return await asyncio.to_thread(_retrieval_sync)
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
```
```diff
@@ -168,10 +168,12 @@ async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, sou
     status = "success" if success else "error"
     auto_close = "window.close();" if success else ""
     escaped_message = escape(message)
+    # Drive: ragflow-google-drive-oauth
+    # Gmail: ragflow-gmail-oauth
+    payload_type = f"ragflow-{source}-oauth"
     payload_json = json.dumps(
         {
-            # TODO(google-oauth): include connector type (drive/gmail) in payload type if needed
-            "type": f"ragflow-google-{source}-oauth",
+            "type": payload_type,
             "status": status,
             "flowId": flow_id or "",
             "message": message,
```
```diff
@@ -462,7 +462,7 @@ async def related_questions():
     if "parameter" in gen_conf:
         del gen_conf["parameter"]
     prompt = load_prompt("related_question")
-    ans = chat_mdl.chat(
+    ans = await chat_mdl.async_chat(
         prompt,
         [
             {
```
```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 #
+import asyncio
 import json
 import os.path
 import pathlib
@@ -72,7 +73,7 @@ async def upload():
     if not check_kb_team_permission(kb, current_user.id):
         return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-    err, files = FileService.upload_document(kb, file_objs, current_user.id)
+    err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
     if err:
         return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
 
@@ -390,7 +391,7 @@ async def rm():
     if not DocumentService.accessible4deletion(doc_id, current_user.id):
         return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-    errors = FileService.delete_docs(doc_ids, current_user.id)
+    errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
 
     if errors:
         return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
```
```diff
@@ -403,44 +404,48 @@ async def rm():
 @validate_request("doc_ids", "run")
 async def run():
     req = await get_request_json()
-    for doc_id in req["doc_ids"]:
-        if not DocumentService.accessible(doc_id, current_user.id):
-            return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
     try:
-        kb_table_num_map = {}
-        for id in req["doc_ids"]:
-            info = {"run": str(req["run"]), "progress": 0}
-            if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
-                info["progress_msg"] = ""
-                info["chunk_num"] = 0
-                info["token_num"] = 0
+        def _run_sync():
+            for doc_id in req["doc_ids"]:
+                if not DocumentService.accessible(doc_id, current_user.id):
+                    return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-            tenant_id = DocumentService.get_tenant_id(id)
-            if not tenant_id:
-                return get_data_error_result(message="Tenant not found!")
-            e, doc = DocumentService.get_by_id(id)
-            if not e:
-                return get_data_error_result(message="Document not found!")
+            kb_table_num_map = {}
+            for id in req["doc_ids"]:
+                info = {"run": str(req["run"]), "progress": 0}
+                if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
+                    info["progress_msg"] = ""
+                    info["chunk_num"] = 0
+                    info["token_num"] = 0
 
-            if str(req["run"]) == TaskStatus.CANCEL.value:
-                if str(doc.run) == TaskStatus.RUNNING.value:
-                    cancel_all_task_of(id)
-                else:
-                    return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
-            if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
-                DocumentService.clear_chunk_num_when_rerun(doc.id)
+                tenant_id = DocumentService.get_tenant_id(id)
+                if not tenant_id:
+                    return get_data_error_result(message="Tenant not found!")
+                e, doc = DocumentService.get_by_id(id)
+                if not e:
+                    return get_data_error_result(message="Document not found!")
 
-            DocumentService.update_by_id(id, info)
-            if req.get("delete", False):
-                TaskService.filter_delete([Task.doc_id == id])
-                if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                    settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
+                if str(req["run"]) == TaskStatus.CANCEL.value:
+                    if str(doc.run) == TaskStatus.RUNNING.value:
+                        cancel_all_task_of(id)
+                    else:
+                        return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
+                if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
+                    DocumentService.clear_chunk_num_when_rerun(doc.id)
 
-            if str(req["run"]) == TaskStatus.RUNNING.value:
-                doc = doc.to_dict()
-                DocumentService.run(tenant_id, doc, kb_table_num_map)
+                DocumentService.update_by_id(id, info)
+                if req.get("delete", False):
+                    TaskService.filter_delete([Task.doc_id == id])
+                    if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                        settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
 
-        return get_json_result(data=True)
+                if str(req["run"]) == TaskStatus.RUNNING.value:
+                    doc_dict = doc.to_dict()
+                    DocumentService.run(tenant_id, doc_dict, kb_table_num_map)
+
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_run_sync)
     except Exception as e:
         return server_error_response(e)
 
```
```diff
@@ -450,45 +455,49 @@ async def run():
 @validate_request("doc_id", "name")
 async def rename():
     req = await get_request_json()
-    if not DocumentService.accessible(req["doc_id"], current_user.id):
-        return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
     try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
-        if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
-            return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
-        if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
-            return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
+        def _rename_sync():
+            if not DocumentService.accessible(req["doc_id"], current_user.id):
+                return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-        for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
-            if d.name == req["name"]:
-                return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
+            if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
+                return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
+            if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
+                return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
 
-        if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
-            return get_data_error_result(message="Database error (Document rename)!")
+            for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
+                if d.name == req["name"]:
+                    return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
 
-        informs = File2DocumentService.get_by_document_id(req["doc_id"])
-        if informs:
-            e, file = FileService.get_by_id(informs[0].file_id)
-            FileService.update_by_id(file.id, {"name": req["name"]})
+            if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
+                return get_data_error_result(message="Database error (Document rename)!")
 
-        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
-        title_tks = rag_tokenizer.tokenize(req["name"])
-        es_body = {
-            "docnm_kwd": req["name"],
-            "title_tks": title_tks,
-            "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
-        }
-        if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-            settings.docStoreConn.update(
-                {"doc_id": req["doc_id"]},
-                es_body,
-                search.index_name(tenant_id),
-                doc.kb_id,
-            )
+            informs = File2DocumentService.get_by_document_id(req["doc_id"])
+            if informs:
+                e, file = FileService.get_by_id(informs[0].file_id)
+                FileService.update_by_id(file.id, {"name": req["name"]})
 
-        return get_json_result(data=True)
+            tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+            title_tks = rag_tokenizer.tokenize(req["name"])
+            es_body = {
+                "docnm_kwd": req["name"],
+                "title_tks": title_tks,
+                "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
+            }
+            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                settings.docStoreConn.update(
+                    {"doc_id": req["doc_id"]},
+                    es_body,
+                    search.index_name(tenant_id),
+                    doc.kb_id,
+                )
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_rename_sync)
     except Exception as e:
         return server_error_response(e)
 
```
```diff
@@ -502,7 +511,8 @@ async def get(doc_id):
             return get_data_error_result(message="Document not found!")
 
         b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
-        response = await make_response(settings.STORAGE_IMPL.get(b, n))
+        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
+        response = await make_response(data)
 
         ext = re.search(r"\.([^.]+)$", doc.name.lower())
         ext = ext.group(1) if ext else None
@@ -523,8 +533,7 @@
 async def download_attachment(attachment_id):
     try:
         ext = request.args.get("ext", "markdown")
-        data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
-        # data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
+        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
         response = await make_response(data)
         response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
 
@@ -596,7 +605,8 @@ async def get_image(image_id):
         if len(arr) != 2:
             return get_data_error_result(message="Image not found.")
         bkt, nm = image_id.split("-")
-        response = await make_response(settings.STORAGE_IMPL.get(bkt, nm))
+        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
+        response = await make_response(data)
         response.headers.set("Content-Type", "image/JPEG")
         return response
     except Exception as e:
```
```diff
@@ -14,6 +14,7 @@
 # limitations under the License
 #
 import logging
+import asyncio
 import os
 import pathlib
 import re
@@ -61,9 +62,10 @@ async def upload():
         e, pf_folder = FileService.get_by_id(pf_id)
         if not e:
             return get_data_error_result( message="Can't find this folder!")
-        for file_obj in file_objs:
+
+        async def _handle_single_file(file_obj):
             MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
-            if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id):
+            if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
                 return get_data_error_result( message="Exceed the maximum file number of a free user!")
 
             # split file name path
@@ -75,35 +77,36 @@ async def upload():
             file_len = len(file_obj_names)
 
             # get folder
-            file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
+            file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
             len_id_list = len(file_id_list)
 
             # create folder
             if file_len != len_id_list:
-                e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
+                e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
                 if not e:
                     return get_data_error_result(message="Folder not found!")
-                last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
+                last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
                                                         len_id_list)
             else:
-                e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
+                e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
                 if not e:
                     return get_data_error_result(message="Folder not found!")
-                last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
+                last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
                                                         len_id_list)
 
             # file type
             filetype = filename_type(file_obj_names[file_len - 1])
             location = file_obj_names[file_len - 1]
-            while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
+            while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
                 location += "_"
-            blob = file_obj.read()
-            filename = duplicate_name(
+            blob = await asyncio.to_thread(file_obj.read)
+            filename = await asyncio.to_thread(
+                duplicate_name,
                 FileService.query,
                 name=file_obj_names[file_len - 1],
                 parent_id=last_folder.id)
-            settings.STORAGE_IMPL.put(last_folder.id, location, blob)
-            file = {
+            await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
+            file_data = {
                 "id": get_uuid(),
                 "parent_id": last_folder.id,
                 "tenant_id": current_user.id,
@@ -113,8 +116,13 @@ async def upload():
                 "location": location,
                 "size": len(blob),
             }
-            file = FileService.insert(file)
-            file_res.append(file.to_json())
+            inserted = await asyncio.to_thread(FileService.insert, file_data)
+            return inserted.to_json()
+
+        for file_obj in file_objs:
+            res = await _handle_single_file(file_obj)
+            file_res.append(res)
+
         return get_json_result(data=file_res)
     except Exception as e:
         return server_error_response(e)
```
@@ -242,55 +250,58 @@ async def rm():
     req = await get_request_json()
     file_ids = req["file_ids"]

-    def _delete_single_file(file):
-        try:
-            if file.location:
-                settings.STORAGE_IMPL.rm(file.parent_id, file.location)
-        except Exception as e:
-            logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
-
-        informs = File2DocumentService.get_by_file_id(file.id)
-        for inform in informs:
-            doc_id = inform.document_id
-            e, doc = DocumentService.get_by_id(doc_id)
-            if e and doc:
-                tenant_id = DocumentService.get_tenant_id(doc_id)
-                if tenant_id:
-                    DocumentService.remove_document(doc, tenant_id)
-        File2DocumentService.delete_by_file_id(file.id)
-
-        FileService.delete(file)
-
-    def _delete_folder_recursive(folder, tenant_id):
-        sub_files = FileService.list_all_files_by_parent_id(folder.id)
-        for sub_file in sub_files:
-            if sub_file.type == FileType.FOLDER.value:
-                _delete_folder_recursive(sub_file, tenant_id)
-            else:
-                _delete_single_file(sub_file)
-
-        FileService.delete(folder)
-
     try:
-        for file_id in file_ids:
-            e, file = FileService.get_by_id(file_id)
-            if not e or not file:
-                return get_data_error_result(message="File or Folder not found!")
-            if not file.tenant_id:
-                return get_data_error_result(message="Tenant not found!")
-            if not check_file_team_permission(file, current_user.id):
-                return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
-
-            if file.source_type == FileSource.KNOWLEDGEBASE:
-                continue
-
-            if file.type == FileType.FOLDER.value:
-                _delete_folder_recursive(file, current_user.id)
-                continue
-
-            _delete_single_file(file)
-
-        return get_json_result(data=True)
+        def _delete_single_file(file):
+            try:
+                if file.location:
+                    settings.STORAGE_IMPL.rm(file.parent_id, file.location)
+            except Exception as e:
+                logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
+
+            informs = File2DocumentService.get_by_file_id(file.id)
+            for inform in informs:
+                doc_id = inform.document_id
+                e, doc = DocumentService.get_by_id(doc_id)
+                if e and doc:
+                    tenant_id = DocumentService.get_tenant_id(doc_id)
+                    if tenant_id:
+                        DocumentService.remove_document(doc, tenant_id)
+            File2DocumentService.delete_by_file_id(file.id)
+
+            FileService.delete(file)
+
+        def _delete_folder_recursive(folder, tenant_id):
+            sub_files = FileService.list_all_files_by_parent_id(folder.id)
+            for sub_file in sub_files:
+                if sub_file.type == FileType.FOLDER.value:
+                    _delete_folder_recursive(sub_file, tenant_id)
+                else:
+                    _delete_single_file(sub_file)
+
+            FileService.delete(folder)
+
+        def _rm_sync():
+            for file_id in file_ids:
+                e, file = FileService.get_by_id(file_id)
+                if not e or not file:
+                    return get_data_error_result(message="File or Folder not found!")
+                if not file.tenant_id:
+                    return get_data_error_result(message="Tenant not found!")
+                if not check_file_team_permission(file, current_user.id):
+                    return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
+
+                if file.source_type == FileSource.KNOWLEDGEBASE:
+                    continue
+
+                if file.type == FileType.FOLDER.value:
+                    _delete_folder_recursive(file, current_user.id)
+                    continue
+
+                _delete_single_file(file)
+
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_rm_sync)
+
     except Exception as e:
         return server_error_response(e)
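The rm() rewrite deliberately crosses the thread boundary once: the whole delete loop is wrapped in a local `_rm_sync()` and handed to `asyncio.to_thread` as a unit, rather than awaiting each file's deletion separately. A hedged sketch of the trade-off (helper names and the sleep are illustrative):

```python
import asyncio
import time


def _delete_one(i: int) -> None:
    time.sleep(0.01)  # stand-in for the storage + DB deletes per file


async def per_item(ids) -> None:
    # One executor hop per item: simple, but pays scheduling overhead n times.
    for i in ids:
        await asyncio.to_thread(_delete_one, i)


async def batched(ids) -> bool:
    # The shape used in this PR: wrap the loop, hop to a worker thread once.
    def _rm_sync() -> bool:
        for i in ids:
            _delete_one(i)
        return True

    return await asyncio.to_thread(_rm_sync)


print(asyncio.run(batched(range(5))))
```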
@@ -346,10 +357,10 @@ async def get(file_id):
         if not check_file_team_permission(file, current_user.id):
             return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)

-        blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
+        blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
         if not blob:
             b, n = File2DocumentService.get_storage_address(file_id=file_id)
-            blob = settings.STORAGE_IMPL.get(b, n)
+            blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)

         response = await make_response(blob)
         ext = re.search(r"\.([^.]+)$", file.name.lower())
@@ -444,10 +455,12 @@ async def move():
             },
         )

-        for file in files:
-            _move_entry_recursive(file, dest_folder)
+        def _move_sync():
+            for file in files:
+                _move_entry_recursive(file, dest_folder)
+            return get_json_result(data=True)

-        return get_json_result(data=True)
+        return await asyncio.to_thread(_move_sync)

     except Exception as e:
         return server_error_response(e)

@@ -17,6 +17,7 @@ import json
 import logging
 import random
 import re
+import asyncio

 from quart import request
 import numpy as np
@@ -116,12 +117,22 @@ async def update():

         if kb.pagerank != req.get("pagerank", 0):
             if req.get("pagerank", 0) > 0:
-                settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
-                                             search.index_name(kb.tenant_id), kb.id)
+                await asyncio.to_thread(
+                    settings.docStoreConn.update,
+                    {"kb_id": kb.id},
+                    {PAGERANK_FLD: req["pagerank"]},
+                    search.index_name(kb.tenant_id),
+                    kb.id,
+                )
             else:
                 # Elasticsearch requires PAGERANK_FLD be non-zero!
-                settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
-                                             search.index_name(kb.tenant_id), kb.id)
+                await asyncio.to_thread(
+                    settings.docStoreConn.update,
+                    {"exists": PAGERANK_FLD},
+                    {"remove": PAGERANK_FLD},
+                    search.index_name(kb.tenant_id),
+                    kb.id,
+                )

         e, kb = KnowledgebaseService.get_by_id(kb.id)
         if not e:
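As the update() hunk shows, `asyncio.to_thread` forwards positional arguments unchanged, so a multi-argument store call translates mechanically into the offloaded form. A sketch with a dummy `update` standing in for `settings.docStoreConn.update` (field, index, and id values are placeholders):

```python
import asyncio


def update(condition: dict, new_value: dict, index_name: str, kb_id: str) -> bool:
    # Stand-in for the synchronous doc-store update call.
    return True


async def main() -> None:
    # Synchronous form:  update({"kb_id": "kb1"}, {"pagerank_fea": 3}, "idx", "kb1")
    ok = await asyncio.to_thread(
        update,
        {"kb_id": "kb1"},
        {"pagerank_fea": 3},
        "idx",
        "kb1",
    )
    assert ok


asyncio.run(main())
```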
@@ -224,25 +235,28 @@ async def rm():
                 data=False, message='Only owner of knowledgebase authorized for this operation.',
                 code=RetCode.OPERATING_ERROR)

-        for doc in DocumentService.query(kb_id=req["kb_id"]):
-            if not DocumentService.remove_document(doc, kbs[0].tenant_id):
-                return get_data_error_result(
-                    message="Database error (Document removal)!")
-            f2d = File2DocumentService.get_by_document_id(doc.id)
-            if f2d:
-                FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
-            File2DocumentService.delete_by_document_id(doc.id)
-        FileService.filter_delete(
-            [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
-        if not KnowledgebaseService.delete_by_id(req["kb_id"]):
-            return get_data_error_result(
-                message="Database error (Knowledgebase removal)!")
-        for kb in kbs:
-            settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
-            settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
-            if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
-                settings.STORAGE_IMPL.remove_bucket(kb.id)
-        return get_json_result(data=True)
+        def _rm_sync():
+            for doc in DocumentService.query(kb_id=req["kb_id"]):
+                if not DocumentService.remove_document(doc, kbs[0].tenant_id):
+                    return get_data_error_result(
+                        message="Database error (Document removal)!")
+                f2d = File2DocumentService.get_by_document_id(doc.id)
+                if f2d:
+                    FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
+                File2DocumentService.delete_by_document_id(doc.id)
+            FileService.filter_delete(
+                [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
+            if not KnowledgebaseService.delete_by_id(req["kb_id"]):
+                return get_data_error_result(
+                    message="Database error (Knowledgebase removal)!")
+            for kb in kbs:
+                settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
+                settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
+                if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
+                    settings.STORAGE_IMPL.remove_bucket(kb.id)
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_rm_sync)
     except Exception as e:
         return server_error_response(e)

@@ -922,5 +936,3 @@ async def check_embedding():
     if summary["avg_cos_sim"] > 0.9:
         return get_json_result(data={"summary": summary, "results": results})
     return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})
-
-

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import re
 import time
||||||
@@ -787,7 +788,7 @@ Reason:
 - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.

 """
-    ans = chat_mdl.chat(
+    ans = await chat_mdl.async_chat(
         prompt,
         [
             {
||||||
@@ -963,28 +964,30 @@ async def retrieval_test_embedded():
     use_kg = req.get("use_kg", False)
     top = int(req.get("top_k", 1024))
     langs = req.get("cross_languages", [])
-    tenant_ids = []

     tenant_id = objs[0].tenant_id
     if not tenant_id:
         return get_error_data_result(message="permission denined.")

-    if req.get("search_id", ""):
-        search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
-        meta_data_filter = search_config.get("meta_data_filter", {})
-        metas = DocumentService.get_meta_by_kbs(kb_ids)
-        if meta_data_filter.get("method") == "auto":
-            chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
-            filters: dict = gen_meta_filter(chat_mdl, metas, question)
-            doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
-            if not doc_ids:
-                doc_ids = None
-        elif meta_data_filter.get("method") == "manual":
-            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
-            if meta_data_filter["manual"] and not doc_ids:
-                doc_ids = ["-999"]
+    def _retrieval_sync():
+        local_doc_ids = list(doc_ids) if doc_ids else []
+        tenant_ids = []
+        _question = question
+
+        if req.get("search_id", ""):
+            search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
+            meta_data_filter = search_config.get("meta_data_filter", {})
+            metas = DocumentService.get_meta_by_kbs(kb_ids)
+            if meta_data_filter.get("method") == "auto":
+                chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
+                filters: dict = gen_meta_filter(chat_mdl, metas, _question)
+                local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
+                if not local_doc_ids:
+                    local_doc_ids = None
+            elif meta_data_filter.get("method") == "manual":
+                local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
+                if meta_data_filter["manual"] and not local_doc_ids:
+                    local_doc_ids = ["-999"]

-    try:
         tenants = UserTenantService.query(user_id=tenant_id)
         for kb_id in kb_ids:
             for tenant in tenants:
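The renames inside `_retrieval_sync` (`question` becomes a local `_question`, `doc_ids` is copied into `local_doc_ids`) are forced by Python scoping: any name assigned inside a nested function becomes local to it, so augmenting the closed-over variables directly would raise UnboundLocalError. A minimal illustration (names are illustrative):

```python
def outer() -> str:
    question = "q"

    def bad() -> str:
        # Would raise UnboundLocalError if called: the assignment below makes
        # 'question' local to bad(), so the right-hand read happens before any
        # local assignment exists.
        question = question + "?"
        return question

    def good() -> str:
        _question = question  # read the closed-over value into a local copy
        _question += "?"
        return _question

    return good()


print(outer())  # q?
```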
@@ -1000,7 +1003,7 @@ async def retrieval_test_embedded():
                 return get_error_data_result(message="Knowledgebase not found!")

         if langs:
-            question = cross_languages(kb.tenant_id, None, question, langs)
+            _question = cross_languages(kb.tenant_id, None, _question, langs)

         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

@@ -1010,15 +1013,15 @@ async def retrieval_test_embedded():

         if req.get("keyword", False):
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
-            question += keyword_extraction(chat_mdl, question)
+            _question += keyword_extraction(chat_mdl, _question)

-        labels = label_question(question, [kb])
+        labels = label_question(_question, [kb])
         ranks = settings.retriever.retrieval(
-            question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
-            doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
+            _question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
+            local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
         )
         if use_kg:
-            ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl,
+            ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
                                                  LLMBundle(kb.tenant_id, LLMType.CHAT))
             if ck["content_with_weight"]:
                 ranks["chunks"].insert(0, ck)
@@ -1028,6 +1031,9 @@ async def retrieval_test_embedded():
         ranks["labels"] = labels

         return get_json_result(data=ranks)
+
+    try:
+        return await asyncio.to_thread(_retrieval_sync)
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
@@ -1064,7 +1070,7 @@ async def related_questions_embedded():

     gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
     prompt = load_prompt("related_question")
-    ans = chat_mdl.chat(
+    ans = await chat_mdl.async_chat(
         prompt,
         [
             {

@@ -719,10 +719,14 @@ class DocumentService(CommonService):
             # only for special task and parsed docs and unfinished
             freeze_progress = special_task_running and doc_progress >= 1 and not finished
             msg = "\n".join(sorted(msg))
+            begin_at = d.get("process_begin_at")
+            if not begin_at:
+                begin_at = datetime.now()
+                # fallback
+                cls.update_by_id(d["id"], {"process_begin_at": begin_at})
+
             info = {
-                "process_duration": datetime.timestamp(
-                    datetime.now()) -
-                d["process_begin_at"].timestamp(),
+                "process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0),
                 "run": status}
             if prg != 0 and not freeze_progress:
                 info["progress"] = prg
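The progress fix above both backfills a missing `process_begin_at` and clamps the duration at zero, so clock skew or the fallback itself can never report a negative value. The core computation, pulled out as a standalone function for clarity (the function name is ours, not the PR's):

```python
from datetime import datetime
from typing import Optional


def process_duration(begin_at: Optional[datetime]) -> float:
    # Fall back to "now" when no begin time was ever recorded, then clamp at
    # zero so the reported duration is never negative.
    if not begin_at:
        begin_at = datetime.now()
    return max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0)


print(process_duration(None))                   # ~0.0 (fallback path)
print(process_duration(datetime(2024, 1, 1)))   # large positive number
```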

@@ -1685,12 +1685,17 @@ class LiteLLMBase(ABC):

         yield ans, tol

-    async def async_chat(self, history, gen_conf, **kwargs):
-        logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
+    async def async_chat(self, system, history, gen_conf, **kwargs):
+        hist = list(history) if history else []
+        if system:
+            if not hist or hist[0].get("role") != "system":
+                hist.insert(0, {"role": "system", "content": system})
+
+        logging.info("[HISTORY]" + json.dumps(hist, ensure_ascii=False, indent=2))
         if self.model_name.lower().find("qwen3") >= 0:
             kwargs["extra_body"] = {"enable_thinking": False}

-        completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
+        completion_args = self._construct_completion_args(history=hist, stream=False, tools=False, **gen_conf)

         for attempt in range(self.max_retries + 1):
             try:
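The new `async_chat(self, system, history, gen_conf, **kwargs)` signature folds the system prompt into the history itself, copying the list first so the caller's history is never mutated and skipping the insert when a system message is already present. That preamble in isolation, as a plain function for illustration:

```python
def with_system(system, history):
    # Copy so the caller's list is untouched; prepend the system message only
    # when one isn't already at the head of the conversation.
    hist = list(history) if history else []
    if system:
        if not hist or hist[0].get("role") != "system":
            hist.insert(0, {"role": "system", "content": system})
    return hist


msgs = [{"role": "user", "content": "hi"}]
out = with_system("You are terse.", msgs)
assert out[0]["role"] == "system" and msgs[0]["role"] == "user"
assert with_system("You are terse.", out) == out  # idempotent on repeat calls
```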

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import datetime
 import json
 import logging
@@ -360,6 +361,10 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
     return kwd


+async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
+    return await asyncio.to_thread(analyze_task, chat_mdl, prompt, task_name, tools_description, user_defined_prompts)
+
+
 def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
     if not tools_description:
         return ""
@@ -378,6 +383,10 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc,
     return json_str, tk_cnt


+async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
+    return await asyncio.to_thread(next_step, chat_mdl, history, tools_description, task_desc, user_defined_prompts)
+
+
 def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
     tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
     goal = history[1]["content"]
@@ -429,6 +438,14 @@ def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[st
     return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)


+async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
+    return await asyncio.to_thread(reflect, chat_mdl, history, tool_call_res, user_defined_prompts)
+
+
+async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
+    return await asyncio.to_thread(rank_memories, chat_mdl, goal, sub_goal, tool_call_summaries, user_defined_prompts)
+
+
 def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
     meta_data_structure = {}
     for key, values in meta_data.items():
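These `*_async` additions are deliberately thin: each keeps the synchronous helper as the single source of truth and exposes an awaitable twin via `asyncio.to_thread`, so async callers never block the event loop on an LLM round-trip. The pattern in generic form (the helper below is a stand-in, not code from this PR):

```python
import asyncio


def analyze_task(prompt: str) -> str:
    # Stand-in for a synchronous, LLM-bound helper.
    return prompt.upper()


async def analyze_task_async(prompt: str) -> str:
    # Same inputs, same result, but awaitable: the sync body runs in a worker
    # thread while the event loop keeps serving other coroutines.
    return await asyncio.to_thread(analyze_task, prompt)


print(asyncio.run(analyze_task_async("plan the task")))
```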