Refa: make RAGFlow more asynchronous 2 (#11689)

### What problem does this PR solve?

Make RAGFlow more asynchronous, part 2. Follow-up to #11551, #11579, and #11619.

### Type of change

- [x] Refactoring
- [x] Performance Improvement
Author: Yongteng Lei
Date: 2025-12-03 14:19:53 +08:00
Committed by: GitHub
Parent: b5ad7b7062
Commit: e3f40db963
15 changed files with 654 additions and 292 deletions


@@ -416,13 +416,19 @@ class Canvas(Graph):
         loop = asyncio.get_running_loop()
         tasks = []
+
+        def _run_async_in_thread(coro_func, **call_kwargs):
+            return asyncio.run(coro_func(**call_kwargs))
+
         i = f
         while i < t:
             cpn = self.get_component_obj(self.path[i])
             task_fn = None
+            call_kwargs = None
             if cpn.component_name.lower() in ["begin", "userfillup"]:
-                task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
+                call_kwargs = {"inputs": kwargs.get("inputs", {})}
+                task_fn = cpn.invoke
                 i += 1
             else:
                 for _, ele in cpn.get_input_elements().items():
@@ -431,13 +437,18 @@ class Canvas(Graph):
                         t -= 1
                         break
                 else:
-                    task_fn = partial(cpn.invoke, **cpn.get_input())
+                    call_kwargs = cpn.get_input()
+                    task_fn = cpn.invoke
                     i += 1
             if task_fn is None:
                 continue
-            tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
+            invoke_async = getattr(cpn, "invoke_async", None)
+            if invoke_async and asyncio.iscoroutinefunction(invoke_async):
+                tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
+            else:
+                tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
         if tasks:
             await asyncio.gather(*tasks)
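The scheduling change above keeps one code path for both kinds of components: coroutine-based `invoke_async` implementations are driven by `asyncio.run` inside a pool thread, while plain `invoke` callables run in the pool directly. A minimal, self-contained sketch of that dispatch under those assumptions (the toy component classes and `run_components` helper are illustrative, not RAGFlow's):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

class SyncComponent:
    def invoke(self, **kwargs):
        return f"sync:{kwargs}"

class AsyncComponent:
    async def invoke_async(self, **kwargs):
        await asyncio.sleep(0.01)  # stand-in for awaitable work
        return f"async:{kwargs}"

def _run_async_in_thread(coro_func, **call_kwargs):
    # Each pool thread drives the coroutine on its own short-lived event loop.
    return asyncio.run(coro_func(**call_kwargs))

async def run_components(components, pool):
    loop = asyncio.get_running_loop()
    tasks = []
    for cpn, call_kwargs in components:
        invoke_async = getattr(cpn, "invoke_async", None)
        if invoke_async and asyncio.iscoroutinefunction(invoke_async):
            fn = partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))
        else:
            fn = partial(cpn.invoke, **(call_kwargs or {}))
        tasks.append(loop.run_in_executor(pool, fn))
    return await asyncio.gather(*tasks)

if __name__ == "__main__":
    pool = ThreadPoolExecutor(max_workers=4)
    comps = [(SyncComponent(), {"x": 1}), (AsyncComponent(), {"x": 2})]
    print(asyncio.run(run_components(comps, pool)))
```

Running `asyncio.run` per call pays a loop-startup cost per invocation, but keeps the main event loop entirely free of component work.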


@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import os
@@ -239,6 +240,86 @@ class Agent(LLM, ToolBase):
         self.set_output("use_tools", use_tools)
         return ans
 
+    async def _invoke_async(self, **kwargs):
+        """
+        Async entry: reuse existing logic but offload heavy sync parts via async wrappers to reduce blocking.
+        """
+        if self.check_if_canceled("Agent processing"):
+            return
+
+        if kwargs.get("user_prompt"):
+            usr_pmt = ""
+            if kwargs.get("reasoning"):
+                usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
+            if kwargs.get("context"):
+                usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
+            if usr_pmt:
+                usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
+            else:
+                usr_pmt = str(kwargs["user_prompt"])
+            self._param.prompts = [{"role": "user", "content": usr_pmt}]
+
+        if not self.tools:
+            if self.check_if_canceled("Agent processing"):
+                return
+            return await asyncio.to_thread(LLM._invoke, self, **kwargs)
+
+        prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
+        output_schema = self._get_output_schema()
+        schema_prompt = ""
+        if output_schema:
+            schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
+            schema_prompt = structured_output_prompt(schema)
+
+        downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
+        ex = self.exception_handler()
+        if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
+            self.set_output("content", partial(self.stream_output_with_tools_async, prompt, msg, user_defined_prompt))
+            return
+
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+        use_tools = []
+        ans = ""
+        async for delta_ans, tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
+            if self.check_if_canceled("Agent processing"):
+                return
+            ans += delta_ans
+
+        if ans.find("**ERROR**") >= 0:
+            logging.error(f"Agent._chat got error. response: {ans}")
+            if self.get_exception_default_value():
+                self.set_output("content", self.get_exception_default_value())
+            else:
+                self.set_output("_ERROR", ans)
+            return
+
+        if output_schema:
+            error = ""
+            for _ in range(self._param.max_retries + 1):
+                try:
+                    def clean_formated_answer(ans: str) -> str:
+                        ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
+                        ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
+                        return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
+                    obj = json_repair.loads(clean_formated_answer(ans))
+                    self.set_output("structured", obj)
+                    if use_tools:
+                        self.set_output("use_tools", use_tools)
+                    return obj
+                except Exception:
+                    error = "The answer cannot be parsed as JSON"
+                    ans = self._force_format_to_schema(ans, schema_prompt)
+                    if ans.find("**ERROR**") >= 0:
+                        continue
+            self.set_output("_ERROR", error)
+            return
+
+        self.set_output("content", ans)
+        if use_tools:
+            self.set_output("use_tools", use_tools)
+        return ans
+
     def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer_without_toolcall = ""
@@ -261,6 +342,54 @@ class Agent(LLM, ToolBase):
         if use_tools:
             self.set_output("use_tools", use_tools)
 
+    async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+        answer_without_toolcall = ""
+        use_tools = []
+        async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
+            if self.check_if_canceled("Agent streaming"):
+                return
+            if delta_ans.find("**ERROR**") >= 0:
+                if self.get_exception_default_value():
+                    self.set_output("content", self.get_exception_default_value())
+                    yield self.get_exception_default_value()
+                else:
+                    self.set_output("_ERROR", delta_ans)
+                return
+            answer_without_toolcall += delta_ans
+            yield delta_ans
+
+        self.set_output("content", answer_without_toolcall)
+        if use_tools:
+            self.set_output("use_tools", use_tools)
+
+    async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
+        """
+        Async wrapper that offloads synchronous flow to a thread, yielding results without blocking the event loop.
+        """
+        loop = asyncio.get_running_loop()
+        queue: asyncio.Queue = asyncio.Queue()
+
+        def worker():
+            try:
+                for delta_ans, tk in self._react_with_tools_streamly(prompt, history, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
+                    asyncio.run_coroutine_threadsafe(queue.put((delta_ans, tk)), loop)
+            except Exception as e:
+                asyncio.run_coroutine_threadsafe(queue.put(e), loop)
+            finally:
+                asyncio.run_coroutine_threadsafe(queue.put(StopAsyncIteration), loop)
+
+        await asyncio.to_thread(worker)
+
+        while True:
+            item = await queue.get()
+            if item is StopAsyncIteration:
+                break
+            if isinstance(item, Exception):
+                raise item
+            yield item
+
     def _gen_citations(self, text):
         retrievals = self._canvas.get_reference()
         retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
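The wrapper above is a general recipe for exposing a blocking generator as an async iterator: a worker thread feeds an `asyncio.Queue` via `run_coroutine_threadsafe`, with exceptions and a `StopAsyncIteration` sentinel passed through the same channel. A minimal self-contained version of the pattern, with the slow generator as a stand-in:

```python
import asyncio
import time

def slow_sync_gen(n):
    # Stand-in for a blocking token stream.
    for i in range(n):
        time.sleep(0.05)
        yield f"tok{i}"

async def as_async_iter(sync_gen):
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def worker():
        try:
            for item in sync_gen:
                asyncio.run_coroutine_threadsafe(queue.put(item), loop).result()
        except Exception as e:
            asyncio.run_coroutine_threadsafe(queue.put(e), loop).result()
        finally:
            asyncio.run_coroutine_threadsafe(queue.put(StopAsyncIteration), loop).result()

    # Note: the commit awaits asyncio.to_thread(worker) before draining, which
    # buffers all items until the worker finishes; scheduling it as a background
    # task, as here, lets items flow while the generator is still producing.
    worker_task = asyncio.create_task(asyncio.to_thread(worker))
    while True:
        item = await queue.get()
        if item is StopAsyncIteration:
            break
        if isinstance(item, Exception):
            raise item
        yield item
    await worker_task

async def main():
    async for tok in as_async_iter(slow_sync_gen(3)):
        print(tok)

asyncio.run(main())
```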
@ -433,4 +562,3 @@ Respond immediately with your final comprehensive answer.
for k in self._param.inputs.keys(): for k in self._param.inputs.keys():
self._param.inputs[k]["value"] = None self._param.inputs[k]["value"] = None
self._param.debug_inputs = {} self._param.debug_inputs = {}


@@ -14,6 +14,7 @@
 # limitations under the License.
 #
+import asyncio
 import re
 import time
 from abc import ABC
@@ -445,6 +446,34 @@ class ComponentBase(ABC):
         self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
         return self.output()
 
+    async def invoke_async(self, **kwargs) -> dict[str, Any]:
+        """
+        Async wrapper for component invocation.
+        Prefers coroutine `_invoke_async` if present; otherwise falls back to `_invoke`.
+        Handles timing and error recording consistently with `invoke`.
+        """
+        self.set_output("_created_time", time.perf_counter())
+        try:
+            if self.check_if_canceled("Component processing"):
+                return
+            fn_async = getattr(self, "_invoke_async", None)
+            if fn_async and asyncio.iscoroutinefunction(fn_async):
+                await fn_async(**kwargs)
+            elif asyncio.iscoroutinefunction(self._invoke):
+                await self._invoke(**kwargs)
+            else:
+                await asyncio.to_thread(self._invoke, **kwargs)
+        except Exception as e:
+            if self.get_exception_default_value():
+                self.set_exception_default_value()
+            else:
+                self.set_output("_ERROR", str(e))
+            logging.exception(e)
+        self._param.debug_inputs = {}
+        self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
+        return self.output()
+
     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         raise NotImplementedError()


@@ -13,12 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import os
 import re
+import threading
 from copy import deepcopy
-from typing import Any, Generator
+from typing import Any, Generator, AsyncGenerator
 import json_repair
 from functools import partial
 from common.constants import LLMType
@@ -171,6 +173,13 @@ class LLM(ComponentBase):
             return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
         return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
 
+    async def _generate_async(self, msg: list[dict], **kwargs) -> str:
+        if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
+            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
+        if self.imgs and hasattr(self.chat_mdl, "async_chat"):
+            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
+        return await asyncio.to_thread(self._generate, msg, **kwargs)
+
     def _generate_streamly(self, msg: list[dict], **kwargs) -> Generator[str, None, None]:
         ans = ""
         last_idx = 0
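`_generate_async` probes the model wrapper with `hasattr`, so backends that already expose `async_chat` are awaited natively while everything else degrades to `asyncio.to_thread`. A hedged sketch of that capability probe; the two model classes here are stand-ins, not RAGFlow's:

```python
import asyncio

class NativeAsyncModel:
    async def async_chat(self, system, history, gen_conf):
        await asyncio.sleep(0)
        return f"native: {history[-1]['content']}"

class SyncOnlyModel:
    def chat(self, system, history, gen_conf):
        return f"sync: {history[-1]['content']}"

async def generate(mdl, msg, gen_conf=None):
    # Await the async API when the backend has one; otherwise keep the
    # event loop responsive by pushing the blocking call to a thread.
    if hasattr(mdl, "async_chat"):
        return await mdl.async_chat(msg[0]["content"], msg[1:], gen_conf)
    return await asyncio.to_thread(mdl.chat, msg[0]["content"], msg[1:], gen_conf)

async def main():
    msg = [{"role": "system", "content": "sys"}, {"role": "user", "content": "hi"}]
    print(await generate(NativeAsyncModel(), msg))
    print(await generate(SyncOnlyModel(), msg))

asyncio.run(main())
```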
@@ -205,6 +214,69 @@ class LLM(ComponentBase):
         for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
             yield delta(txt)
 
+    async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
+        async def delta_wrapper(txt_iter):
+            ans = ""
+            last_idx = 0
+            endswith_think = False
+
+            def delta(txt):
+                nonlocal ans, last_idx, endswith_think
+                delta_ans = txt[last_idx:]
+                ans = txt
+
+                if delta_ans.find("<think>") == 0:
+                    last_idx += len("<think>")
+                    return "<think>"
+                elif delta_ans.find("<think>") > 0:
+                    delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
+                    last_idx += delta_ans.find("<think>")
+                    return delta_ans
+                elif delta_ans.endswith("</think>"):
+                    endswith_think = True
+                elif endswith_think:
+                    endswith_think = False
+                    return "</think>"
+
+                last_idx = len(ans)
+                if ans.endswith("</think>"):
+                    last_idx -= len("</think>")
+                return re.sub(r"(<think>|</think>)", "", delta_ans)
+
+            async for t in txt_iter:
+                yield delta(t)
+
+        if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
+            async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
+                yield t
+            return
+        if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
+            async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
+                yield t
+            return
+
+        # fallback
+        loop = asyncio.get_running_loop()
+        queue: asyncio.Queue = asyncio.Queue()
+
+        def worker():
+            try:
+                for item in self._generate_streamly(msg, **kwargs):
+                    loop.call_soon_threadsafe(queue.put_nowait, item)
+            except Exception as e:
+                loop.call_soon_threadsafe(queue.put_nowait, e)
+            finally:
+                loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
+
+        threading.Thread(target=worker, daemon=True).start()
+
+        while True:
+            item = await queue.get()
+            if item is StopAsyncIteration:
+                break
+            if isinstance(item, Exception):
+                raise item
+            yield item
+
     async def _stream_output_async(self, prompt, msg):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer = ""
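Unlike the Agent-side wrapper earlier in this diff, the fallback here starts a daemon thread and forwards items with `loop.call_soon_threadsafe(queue.put_nowait, ...)`, so chunks reach the consumer while the sync generator is still producing. A small sketch of this fire-and-forget handoff (the producer is a stand-in):

```python
import asyncio
import threading
import time

def producer():
    for i in range(3):
        time.sleep(0.05)  # stand-in for a blocking LLM stream
        yield f"chunk{i}"

async def stream():
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def worker():
        try:
            for item in producer():
                # Schedules put_nowait on the loop without waiting for it,
                # so the producer never blocks on the consumer.
                loop.call_soon_threadsafe(queue.put_nowait, item)
        except Exception as e:
            loop.call_soon_threadsafe(queue.put_nowait, e)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = await queue.get()
        if item is StopAsyncIteration:
            break
        if isinstance(item, Exception):
            raise item
        yield item

async def main():
    async for chunk in stream():
        print(chunk)  # arrives incrementally, not after the producer ends

asyncio.run(main())
```

`put_nowait` on an unbounded queue cannot raise `QueueFull`, which is what makes the fire-and-forget handoff safe; the trade-off is that there is no backpressure on the producer.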


@@ -17,6 +17,7 @@ import logging
 import re
 import time
 from copy import deepcopy
+import asyncio
 from functools import partial
 from typing import TypedDict, List, Any
 from agent.component.base import ComponentParamBase, ComponentBase
@@ -50,10 +51,14 @@ class LLMToolPluginCallSession(ToolCallSession):
     def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
         assert name in self.tools_map, f"LLM tool {name} does not exist"
         st = timer()
-        if isinstance(self.tools_map[name], MCPToolCallSession):
-            resp = self.tools_map[name].tool_call(name, arguments, 60)
+        tool_obj = self.tools_map[name]
+        if isinstance(tool_obj, MCPToolCallSession):
+            resp = tool_obj.tool_call(name, arguments, 60)
         else:
-            resp = self.tools_map[name].invoke(**arguments)
+            if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
+                resp = asyncio.run(tool_obj.invoke_async(**arguments))
+            else:
+                resp = asyncio.run(asyncio.to_thread(tool_obj.invoke, **arguments))
         self.callback(name, arguments, resp, elapsed_time=timer()-st)
         return resp
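Because `tool_call` stays synchronous (it runs in worker threads, not on the event loop), the commit uses `asyncio.run` to drive coroutine tools to completion. That only works in a thread with no running loop; a sketch of the constraint, with toy tool classes:

```python
import asyncio

class AsyncTool:
    async def invoke_async(self, **arguments):
        await asyncio.sleep(0)
        return ("async tool", arguments)

def tool_call_sync(tool_obj, **arguments):
    # Legal only where no event loop is running in this thread
    # (e.g. inside a ThreadPoolExecutor worker).
    if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
        return asyncio.run(tool_obj.invoke_async(**arguments))
    return tool_obj.invoke(**arguments)

print(tool_call_sync(AsyncTool(), q="x"))

async def wrong():
    # asyncio.run() raises RuntimeError when a loop is already running here.
    return tool_call_sync(AsyncTool(), q="x")

try:
    asyncio.run(wrong())
except RuntimeError as e:
    print("as expected:", e)
```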
@@ -139,6 +144,33 @@ class ToolBase(ComponentBase):
         self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
         return res
 
+    async def invoke_async(self, **kwargs):
+        """
+        Async wrapper for tool invocation.
+        If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
+        Mirrors the exception handling of `invoke`.
+        """
+        if self.check_if_canceled("Tool processing"):
+            return
+
+        self.set_output("_created_time", time.perf_counter())
+        try:
+            fn_async = getattr(self, "_invoke_async", None)
+            if fn_async and asyncio.iscoroutinefunction(fn_async):
+                res = await fn_async(**kwargs)
+            elif asyncio.iscoroutinefunction(self._invoke):
+                res = await self._invoke(**kwargs)
+            else:
+                res = await asyncio.to_thread(self._invoke, **kwargs)
+        except Exception as e:
+            self._param.outputs["_ERROR"] = {"value": str(e)}
+            logging.exception(e)
+            res = str(e)
+        self._param.debug_inputs = []
+        self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
+        return res
+
     def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
         chunks = []
         aggs = []


@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import datetime
 import json
 import re
@@ -147,31 +148,35 @@ async def set():
         d["available_int"] = req["available_int"]
     try:
-        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
-        if not tenant_id:
-            return get_data_error_result(message="Tenant not found!")
+        def _set_sync():
+            tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+            if not tenant_id:
+                return get_data_error_result(message="Tenant not found!")
 
-        embd_id = DocumentService.get_embd_id(req["doc_id"])
-        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
+            embd_id = DocumentService.get_embd_id(req["doc_id"])
+            embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
 
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
 
-        if doc.parser_id == ParserType.QA:
-            arr = [
-                t for t in re.split(
-                    r"[\n\t]",
-                    req["content_with_weight"]) if len(t) > 1]
-            q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
-            d = beAdoc(d, q, a, not any(
-                [rag_tokenizer.is_chinese(t) for t in q + a]))
+            _d = d
+            if doc.parser_id == ParserType.QA:
+                arr = [
+                    t for t in re.split(
+                        r"[\n\t]",
+                        req["content_with_weight"]) if len(t) > 1]
+                q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
+                _d = beAdoc(d, q, a, not any(
+                    [rag_tokenizer.is_chinese(t) for t in q + a]))
 
-        v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
-        v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
-        d["q_%d_vec" % len(v)] = v.tolist()
-        settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
-        return get_json_result(data=True)
+            v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
+            v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
+            _d["q_%d_vec" % len(v)] = v.tolist()
+            settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_set_sync)
     except Exception as e:
         return server_error_response(e)
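This hunk establishes the pattern the rest of the API diffs repeat: the blocking handler body moves into a local `_xxx_sync` closure, and the async route awaits it through `asyncio.to_thread`, so one slow DB or embedding call no longer stalls every other request. A minimal sketch of the shape in a Quart app (the route and the `slow_lookup` service are illustrative):

```python
import asyncio
import time
from quart import Quart, jsonify

app = Quart(__name__)

def slow_lookup(doc_id: str) -> dict:
    time.sleep(0.2)  # stand-in for ORM / embedding / search calls
    return {"doc_id": doc_id, "ok": True}

@app.route("/docs/<doc_id>")
async def get_doc(doc_id):
    try:
        def _get_sync():
            # Everything blocking stays inside the closure; it can still
            # read request-scoped variables captured from the handler.
            return slow_lookup(doc_id)

        data = await asyncio.to_thread(_get_sync)
        return jsonify(data)
    except Exception as e:
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    app.run()
```

`to_thread` runs the closure in the default executor, leaving the event loop free to serve other requests while the thread blocks.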
@@ -182,16 +187,19 @@ async def set():
 async def switch():
     req = await get_request_json()
     try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
-        for cid in req["chunk_ids"]:
-            if not settings.docStoreConn.update({"id": cid},
-                                                {"available_int": int(req["available_int"])},
-                                                search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
-                                                doc.kb_id):
-                return get_data_error_result(message="Index updating failure")
-        return get_json_result(data=True)
+        def _switch_sync():
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
+            for cid in req["chunk_ids"]:
+                if not settings.docStoreConn.update({"id": cid},
+                                                    {"available_int": int(req["available_int"])},
+                                                    search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                    doc.kb_id):
+                    return get_data_error_result(message="Index updating failure")
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_switch_sync)
     except Exception as e:
         return server_error_response(e)
@@ -202,20 +210,23 @@ async def switch():
 async def rm():
     req = await get_request_json()
     try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
-        if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
-                                            search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
-                                            doc.kb_id):
-            return get_data_error_result(message="Chunk deleting failure")
-        deleted_chunk_ids = req["chunk_ids"]
-        chunk_number = len(deleted_chunk_ids)
-        DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
-        for cid in deleted_chunk_ids:
-            if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
-                settings.STORAGE_IMPL.rm(doc.kb_id, cid)
-        return get_json_result(data=True)
+        def _rm_sync():
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
+            if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
+                                                search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                doc.kb_id):
+                return get_data_error_result(message="Chunk deleting failure")
+            deleted_chunk_ids = req["chunk_ids"]
+            chunk_number = len(deleted_chunk_ids)
+            DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
+            for cid in deleted_chunk_ids:
+                if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
+                    settings.STORAGE_IMPL.rm(doc.kb_id, cid)
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_rm_sync)
     except Exception as e:
         return server_error_response(e)
@@ -245,35 +256,38 @@ async def create():
         d["tag_feas"] = req["tag_feas"]
 
     try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
-        d["kb_id"] = [doc.kb_id]
-        d["docnm_kwd"] = doc.name
-        d["title_tks"] = rag_tokenizer.tokenize(doc.name)
-        d["doc_id"] = doc.id
+        def _create_sync():
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
+            d["kb_id"] = [doc.kb_id]
+            d["docnm_kwd"] = doc.name
+            d["title_tks"] = rag_tokenizer.tokenize(doc.name)
+            d["doc_id"] = doc.id
 
-        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
-        if not tenant_id:
-            return get_data_error_result(message="Tenant not found!")
+            tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+            if not tenant_id:
+                return get_data_error_result(message="Tenant not found!")
 
-        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
-        if not e:
-            return get_data_error_result(message="Knowledgebase not found!")
-        if kb.pagerank:
-            d[PAGERANK_FLD] = kb.pagerank
+            e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
+            if not e:
+                return get_data_error_result(message="Knowledgebase not found!")
+            if kb.pagerank:
+                d[PAGERANK_FLD] = kb.pagerank
 
-        embd_id = DocumentService.get_embd_id(req["doc_id"])
-        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
+            embd_id = DocumentService.get_embd_id(req["doc_id"])
+            embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
 
-        v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
-        v = 0.1 * v[0] + 0.9 * v[1]
-        d["q_%d_vec" % len(v)] = v.tolist()
-        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+            v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
+            v = 0.1 * v[0] + 0.9 * v[1]
+            d["q_%d_vec" % len(v)] = v.tolist()
+            settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
 
-        DocumentService.increment_chunk_num(
-            doc.id, doc.kb_id, c, 1, 0)
-        return get_json_result(data={"chunk_id": chunck_id})
+            DocumentService.increment_chunk_num(
+                doc.id, doc.kb_id, c, 1, 0)
+            return get_json_result(data={"chunk_id": chunck_id})
+
+        return await asyncio.to_thread(_create_sync)
     except Exception as e:
         return server_error_response(e)
@@ -297,25 +311,28 @@ async def retrieval_test():
     use_kg = req.get("use_kg", False)
     top = int(req.get("top_k", 1024))
     langs = req.get("cross_languages", [])
-    tenant_ids = []
-    if req.get("search_id", ""):
-        search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
-        meta_data_filter = search_config.get("meta_data_filter", {})
-        metas = DocumentService.get_meta_by_kbs(kb_ids)
-        if meta_data_filter.get("method") == "auto":
-            chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
-            filters: dict = gen_meta_filter(chat_mdl, metas, question)
-            doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
-            if not doc_ids:
-                doc_ids = None
-        elif meta_data_filter.get("method") == "manual":
-            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
-            if meta_data_filter["manual"] and not doc_ids:
-                doc_ids = ["-999"]
-    try:
-        tenants = UserTenantService.query(user_id=current_user.id)
+    user_id = current_user.id
+
+    def _retrieval_sync():
+        local_doc_ids = list(doc_ids) if doc_ids else []
+        tenant_ids = []
+
+        if req.get("search_id", ""):
+            search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
+            meta_data_filter = search_config.get("meta_data_filter", {})
+            metas = DocumentService.get_meta_by_kbs(kb_ids)
+            if meta_data_filter.get("method") == "auto":
+                chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
+                filters: dict = gen_meta_filter(chat_mdl, metas, question)
+                local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
+                if not local_doc_ids:
+                    local_doc_ids = None
+            elif meta_data_filter.get("method") == "manual":
+                local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
+                if meta_data_filter["manual"] and not local_doc_ids:
+                    local_doc_ids = ["-999"]
+
+        tenants = UserTenantService.query(user_id=user_id)
         for kb_id in kb_ids:
             for tenant in tenants:
                 if KnowledgebaseService.query(
@@ -331,8 +348,9 @@ async def retrieval_test():
             if not e:
                 return get_data_error_result(message="Knowledgebase not found!")
 
+        _question = question
         if langs:
-            question = cross_languages(kb.tenant_id, None, question, langs)
+            _question = cross_languages(kb.tenant_id, None, _question, langs)
 
         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
@@ -342,19 +360,19 @@ async def retrieval_test():
         if req.get("keyword", False):
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
-            question += keyword_extraction(chat_mdl, question)
+            _question += keyword_extraction(chat_mdl, _question)
 
-        labels = label_question(question, [kb])
-        ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
+        labels = label_question(_question, [kb])
+        ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
                                              float(req.get("similarity_threshold", 0.0)),
                                              float(req.get("vector_similarity_weight", 0.3)),
                                              top,
-                                             doc_ids, rerank_mdl=rerank_mdl,
+                                             local_doc_ids, rerank_mdl=rerank_mdl,
                                              highlight=req.get("highlight", False),
                                              rank_feature=labels
                                              )
         if use_kg:
-            ck = settings.kg_retriever.retrieval(question,
+            ck = settings.kg_retriever.retrieval(_question,
                                                  tenant_ids,
                                                  kb_ids,
                                                  embd_mdl,
@@ -367,6 +385,9 @@ async def retrieval_test():
         ranks["labels"] = labels
         return get_json_result(data=ranks)
+
+    try:
+        return await asyncio.to_thread(_retrieval_sync)
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
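Note how the refactor copies `doc_ids` into `local_doc_ids` and `question` into `_question` before the closure touches them: rebinding a captured name inside a nested function would otherwise need `nonlocal`, and mutating the shared list in place would leak filter state back to the caller. A tiny sketch of the pitfall and the fix:

```python
def handler(doc_ids):
    def _mutating_sync():
        doc_ids.append("from-filter")   # mutates the caller's list in place
        return doc_ids

    def _safe_sync():
        local_doc_ids = list(doc_ids) if doc_ids else []
        local_doc_ids.append("from-filter")  # caller's list untouched
        if not local_doc_ids:
            local_doc_ids = None  # rebinding a local name needs no nonlocal
        return local_doc_ids

    return _mutating_sync, _safe_sync

ids = ["a"]
mutating, safe = handler(ids)
print(safe(), ids)      # ['a', 'from-filter'] ['a']
print(mutating(), ids)  # ['a', 'from-filter'] ['a', 'from-filter']
```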


@@ -168,10 +168,12 @@ async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, sou
     status = "success" if success else "error"
     auto_close = "window.close();" if success else ""
     escaped_message = escape(message)
+    # Drive: ragflow-google-drive-oauth
+    # Gmail: ragflow-gmail-oauth
+    payload_type = f"ragflow-{source}-oauth"
     payload_json = json.dumps(
         {
-            # TODO(google-oauth): include connector type (drive/gmail) in payload type if needed
-            "type": f"ragflow-google-{source}-oauth",
+            "type": payload_type,
             "status": status,
             "flowId": flow_id or "",
             "message": message,


@@ -462,7 +462,7 @@ async def related_questions():
     if "parameter" in gen_conf:
         del gen_conf["parameter"]
     prompt = load_prompt("related_question")
-    ans = chat_mdl.chat(
+    ans = await chat_mdl.async_chat(
         prompt,
         [
             {


@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 #
+import asyncio
 import json
 import os.path
 import pathlib
@@ -72,7 +73,7 @@ async def upload():
         if not check_kb_team_permission(kb, current_user.id):
             return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-        err, files = FileService.upload_document(kb, file_objs, current_user.id)
+        err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
         if err:
             return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
@@ -390,7 +391,7 @@ async def rm():
         if not DocumentService.accessible4deletion(doc_id, current_user.id):
             return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-    errors = FileService.delete_docs(doc_ids, current_user.id)
+    errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
     if errors:
         return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
@@ -403,44 +404,48 @@ async def rm():
 @validate_request("doc_ids", "run")
 async def run():
     req = await get_request_json()
-    for doc_id in req["doc_ids"]:
-        if not DocumentService.accessible(doc_id, current_user.id):
-            return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
     try:
-        kb_table_num_map = {}
-        for id in req["doc_ids"]:
-            info = {"run": str(req["run"]), "progress": 0}
-            if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
-                info["progress_msg"] = ""
-                info["chunk_num"] = 0
-                info["token_num"] = 0
-            tenant_id = DocumentService.get_tenant_id(id)
-            if not tenant_id:
-                return get_data_error_result(message="Tenant not found!")
-            e, doc = DocumentService.get_by_id(id)
-            if not e:
-                return get_data_error_result(message="Document not found!")
-
-            if str(req["run"]) == TaskStatus.CANCEL.value:
-                if str(doc.run) == TaskStatus.RUNNING.value:
-                    cancel_all_task_of(id)
-                else:
-                    return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
-            if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
-                DocumentService.clear_chunk_num_when_rerun(doc.id)
-            DocumentService.update_by_id(id, info)
-            if req.get("delete", False):
-                TaskService.filter_delete([Task.doc_id == id])
-                if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                    settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
-
-            if str(req["run"]) == TaskStatus.RUNNING.value:
-                doc = doc.to_dict()
-                DocumentService.run(tenant_id, doc, kb_table_num_map)
-
-        return get_json_result(data=True)
+        def _run_sync():
+            for doc_id in req["doc_ids"]:
+                if not DocumentService.accessible(doc_id, current_user.id):
+                    return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
+
+            kb_table_num_map = {}
+            for id in req["doc_ids"]:
+                info = {"run": str(req["run"]), "progress": 0}
+                if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
+                    info["progress_msg"] = ""
+                    info["chunk_num"] = 0
+                    info["token_num"] = 0
+                tenant_id = DocumentService.get_tenant_id(id)
+                if not tenant_id:
+                    return get_data_error_result(message="Tenant not found!")
+                e, doc = DocumentService.get_by_id(id)
+                if not e:
+                    return get_data_error_result(message="Document not found!")
+
+                if str(req["run"]) == TaskStatus.CANCEL.value:
+                    if str(doc.run) == TaskStatus.RUNNING.value:
+                        cancel_all_task_of(id)
+                    else:
+                        return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
+                if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
+                    DocumentService.clear_chunk_num_when_rerun(doc.id)
+                DocumentService.update_by_id(id, info)
+                if req.get("delete", False):
+                    TaskService.filter_delete([Task.doc_id == id])
+                    if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                        settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
+
+                if str(req["run"]) == TaskStatus.RUNNING.value:
+                    doc_dict = doc.to_dict()
+                    DocumentService.run(tenant_id, doc_dict, kb_table_num_map)
+
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_run_sync)
     except Exception as e:
         return server_error_response(e)
@@ -450,45 +455,49 @@ async def run():
 @validate_request("doc_id", "name")
 async def rename():
     req = await get_request_json()
-    if not DocumentService.accessible(req["doc_id"], current_user.id):
-        return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
     try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
-        if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
-            return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
-        if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
-            return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
-
-        for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
-            if d.name == req["name"]:
-                return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
-
-        if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
-            return get_data_error_result(message="Database error (Document rename)!")
-
-        informs = File2DocumentService.get_by_document_id(req["doc_id"])
-        if informs:
-            e, file = FileService.get_by_id(informs[0].file_id)
-            FileService.update_by_id(file.id, {"name": req["name"]})
-
-        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
-        title_tks = rag_tokenizer.tokenize(req["name"])
-        es_body = {
-            "docnm_kwd": req["name"],
-            "title_tks": title_tks,
-            "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
-        }
-        if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-            settings.docStoreConn.update(
-                {"doc_id": req["doc_id"]},
-                es_body,
-                search.index_name(tenant_id),
-                doc.kb_id,
-            )
-
-        return get_json_result(data=True)
+        def _rename_sync():
+            if not DocumentService.accessible(req["doc_id"], current_user.id):
+                return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
+
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
+            if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
+                return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
+            if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
+                return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
+
+            for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
+                if d.name == req["name"]:
+                    return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
+
+            if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
+                return get_data_error_result(message="Database error (Document rename)!")
+
+            informs = File2DocumentService.get_by_document_id(req["doc_id"])
+            if informs:
+                e, file = FileService.get_by_id(informs[0].file_id)
+                FileService.update_by_id(file.id, {"name": req["name"]})
+
+            tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+            title_tks = rag_tokenizer.tokenize(req["name"])
+            es_body = {
+                "docnm_kwd": req["name"],
+                "title_tks": title_tks,
+                "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
+            }
+            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                settings.docStoreConn.update(
+                    {"doc_id": req["doc_id"]},
+                    es_body,
+                    search.index_name(tenant_id),
+                    doc.kb_id,
+                )
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_rename_sync)
     except Exception as e:
         return server_error_response(e)
@@ -502,7 +511,8 @@ async def get(doc_id):
             return get_data_error_result(message="Document not found!")
 
         b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
-        response = await make_response(settings.STORAGE_IMPL.get(b, n))
+        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
+        response = await make_response(data)
 
         ext = re.search(r"\.([^.]+)$", doc.name.lower())
         ext = ext.group(1) if ext else None
@@ -523,8 +533,7 @@ async def get(doc_id):
 async def download_attachment(attachment_id):
     try:
         ext = request.args.get("ext", "markdown")
-        data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
-        # data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
+        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
 
         response = await make_response(data)
         response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
@@ -596,7 +605,8 @@ async def get_image(image_id):
         if len(arr) != 2:
             return get_data_error_result(message="Image not found.")
         bkt, nm = image_id.split("-")
-        response = await make_response(settings.STORAGE_IMPL.get(bkt, nm))
+        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
+        response = await make_response(data)
         response.headers.set("Content-Type", "image/JPEG")
         return response
     except Exception as e:
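These three handlers share one micro-pattern: the blocking object-store read is handed to `asyncio.to_thread` with the callable and its positional arguments, and only the already-fetched bytes go into `make_response`. A sketch of the same shape with a fake storage client:

```python
import asyncio
import time
from quart import Quart, make_response

app = Quart(__name__)

class FakeStorage:
    def get(self, bucket: str, name: str) -> bytes:
        time.sleep(0.1)  # stand-in for S3/MinIO latency
        return f"{bucket}/{name}".encode()

STORAGE = FakeStorage()

@app.route("/blob/<bucket>/<name>")
async def blob(bucket, name):
    # Positional args pass straight through to_thread; no lambda needed.
    data = await asyncio.to_thread(STORAGE.get, bucket, name)
    response = await make_response(data)
    response.headers.set("Content-Type", "application/octet-stream")
    return response

if __name__ == "__main__":
    app.run()
```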


@@ -14,6 +14,7 @@
 # limitations under the License
 #
 import logging
+import asyncio
 import os
 import pathlib
 import re
@@ -61,9 +62,10 @@ async def upload():
         e, pf_folder = FileService.get_by_id(pf_id)
         if not e:
             return get_data_error_result(message="Can't find this folder!")
-        for file_obj in file_objs:
+
+        async def _handle_single_file(file_obj):
             MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
-            if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id):
+            if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
                 return get_data_error_result(message="Exceed the maximum file number of a free user!")
 
             # split file name path
@@ -75,35 +77,36 @@ async def upload():
             file_len = len(file_obj_names)
 
             # get folder
-            file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
+            file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
             len_id_list = len(file_id_list)
 
             # create folder
             if file_len != len_id_list:
-                e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
+                e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
                 if not e:
                     return get_data_error_result(message="Folder not found!")
-                last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
-                                                        len_id_list)
+                last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
+                                                      len_id_list)
             else:
-                e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
+                e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
                 if not e:
                     return get_data_error_result(message="Folder not found!")
-                last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
-                                                        len_id_list)
+                last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
+                                                      len_id_list)
 
             # file type
             filetype = filename_type(file_obj_names[file_len - 1])
             location = file_obj_names[file_len - 1]
-            while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
+            while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
                 location += "_"
-            blob = file_obj.read()
-            filename = duplicate_name(
+            blob = await asyncio.to_thread(file_obj.read)
+            filename = await asyncio.to_thread(
+                duplicate_name,
                 FileService.query,
                 name=file_obj_names[file_len - 1],
                 parent_id=last_folder.id)
-            settings.STORAGE_IMPL.put(last_folder.id, location, blob)
-            file = {
+            await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
+            file_data = {
                 "id": get_uuid(),
                 "parent_id": last_folder.id,
                 "tenant_id": current_user.id,
@@ -113,8 +116,13 @@ async def upload():
                 "location": location,
                 "size": len(blob),
             }
-            file = FileService.insert(file)
-            file_res.append(file.to_json())
+            inserted = await asyncio.to_thread(FileService.insert, file_data)
+            return inserted.to_json()
+
+        for file_obj in file_objs:
+            res = await _handle_single_file(file_obj)
+            file_res.append(res)
+
         return get_json_result(data=file_res)
     except Exception as e:
         return server_error_response(e)
@@ -242,55 +250,58 @@ async def rm():
     req = await get_request_json()
     file_ids = req["file_ids"]
 
-    def _delete_single_file(file):
-        try:
-            if file.location:
-                settings.STORAGE_IMPL.rm(file.parent_id, file.location)
-        except Exception as e:
-            logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
-
-        informs = File2DocumentService.get_by_file_id(file.id)
-        for inform in informs:
-            doc_id = inform.document_id
-            e, doc = DocumentService.get_by_id(doc_id)
-            if e and doc:
-                tenant_id = DocumentService.get_tenant_id(doc_id)
-                if tenant_id:
-                    DocumentService.remove_document(doc, tenant_id)
-        File2DocumentService.delete_by_file_id(file.id)
-        FileService.delete(file)
-
-    def _delete_folder_recursive(folder, tenant_id):
-        sub_files = FileService.list_all_files_by_parent_id(folder.id)
-        for sub_file in sub_files:
-            if sub_file.type == FileType.FOLDER.value:
-                _delete_folder_recursive(sub_file, tenant_id)
-            else:
-                _delete_single_file(sub_file)
-        FileService.delete(folder)
-
     try:
-        for file_id in file_ids:
-            e, file = FileService.get_by_id(file_id)
-            if not e or not file:
-                return get_data_error_result(message="File or Folder not found!")
-            if not file.tenant_id:
-                return get_data_error_result(message="Tenant not found!")
-            if not check_file_team_permission(file, current_user.id):
-                return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
-            if file.source_type == FileSource.KNOWLEDGEBASE:
-                continue
-            if file.type == FileType.FOLDER.value:
-                _delete_folder_recursive(file, current_user.id)
-                continue
-            _delete_single_file(file)
-        return get_json_result(data=True)
+        def _delete_single_file(file):
+            try:
+                if file.location:
+                    settings.STORAGE_IMPL.rm(file.parent_id, file.location)
+            except Exception as e:
+                logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
+
+            informs = File2DocumentService.get_by_file_id(file.id)
+            for inform in informs:
+                doc_id = inform.document_id
+                e, doc = DocumentService.get_by_id(doc_id)
+                if e and doc:
+                    tenant_id = DocumentService.get_tenant_id(doc_id)
+                    if tenant_id:
+                        DocumentService.remove_document(doc, tenant_id)
+            File2DocumentService.delete_by_file_id(file.id)
+            FileService.delete(file)
+
+        def _delete_folder_recursive(folder, tenant_id):
+            sub_files = FileService.list_all_files_by_parent_id(folder.id)
+            for sub_file in sub_files:
+                if sub_file.type == FileType.FOLDER.value:
+                    _delete_folder_recursive(sub_file, tenant_id)
+                else:
+                    _delete_single_file(sub_file)
+            FileService.delete(folder)
+
+        def _rm_sync():
+            for file_id in file_ids:
+                e, file = FileService.get_by_id(file_id)
+                if not e or not file:
+                    return get_data_error_result(message="File or Folder not found!")
+                if not file.tenant_id:
+                    return get_data_error_result(message="Tenant not found!")
+                if not check_file_team_permission(file, current_user.id):
+                    return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
+                if file.source_type == FileSource.KNOWLEDGEBASE:
+                    continue
+                if file.type == FileType.FOLDER.value:
+                    _delete_folder_recursive(file, current_user.id)
+                    continue
+                _delete_single_file(file)
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_rm_sync)
     except Exception as e:
         return server_error_response(e)
@@ -346,10 +357,10 @@ async def get(file_id):
         if not check_file_team_permission(file, current_user.id):
             return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
 
-        blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
+        blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
         if not blob:
             b, n = File2DocumentService.get_storage_address(file_id=file_id)
-            blob = settings.STORAGE_IMPL.get(b, n)
+            blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
 
         response = await make_response(blob)
         ext = re.search(r"\.([^.]+)$", file.name.lower())
@@ -444,10 +455,12 @@ async def move():
             },
         )
 
-        for file in files:
-            _move_entry_recursive(file, dest_folder)
+        def _move_sync():
+            for file in files:
+                _move_entry_recursive(file, dest_folder)
+            return get_json_result(data=True)
 
-        return get_json_result(data=True)
+        return await asyncio.to_thread(_move_sync)
     except Exception as e:
         return server_error_response(e)


@@ -17,6 +17,7 @@ import json
 import logging
 import random
 import re
+import asyncio
 
 from quart import request
 import numpy as np
@@ -116,12 +117,22 @@ async def update():
         if kb.pagerank != req.get("pagerank", 0):
             if req.get("pagerank", 0) > 0:
-                settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
-                                             search.index_name(kb.tenant_id), kb.id)
+                await asyncio.to_thread(
+                    settings.docStoreConn.update,
+                    {"kb_id": kb.id},
+                    {PAGERANK_FLD: req["pagerank"]},
+                    search.index_name(kb.tenant_id),
+                    kb.id,
+                )
             else:
                 # Elasticsearch requires PAGERANK_FLD be non-zero!
-                settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
-                                             search.index_name(kb.tenant_id), kb.id)
+                await asyncio.to_thread(
+                    settings.docStoreConn.update,
+                    {"exists": PAGERANK_FLD},
+                    {"remove": PAGERANK_FLD},
+                    search.index_name(kb.tenant_id),
+                    kb.id,
+                )
 
         e, kb = KnowledgebaseService.get_by_id(kb.id)
         if not e:
@@ -224,25 +235,28 @@ async def rm():
                 data=False, message='Only owner of knowledgebase authorized for this operation.',
                 code=RetCode.OPERATING_ERROR)
 
-        for doc in DocumentService.query(kb_id=req["kb_id"]):
-            if not DocumentService.remove_document(doc, kbs[0].tenant_id):
-                return get_data_error_result(
-                    message="Database error (Document removal)!")
-            f2d = File2DocumentService.get_by_document_id(doc.id)
-            if f2d:
-                FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
-            File2DocumentService.delete_by_document_id(doc.id)
-        FileService.filter_delete(
-            [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
-        if not KnowledgebaseService.delete_by_id(req["kb_id"]):
-            return get_data_error_result(
-                message="Database error (Knowledgebase removal)!")
-        for kb in kbs:
-            settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
-            settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
-            if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
-                settings.STORAGE_IMPL.remove_bucket(kb.id)
-        return get_json_result(data=True)
+        def _rm_sync():
+            for doc in DocumentService.query(kb_id=req["kb_id"]):
+                if not DocumentService.remove_document(doc, kbs[0].tenant_id):
+                    return get_data_error_result(
+                        message="Database error (Document removal)!")
+                f2d = File2DocumentService.get_by_document_id(doc.id)
+                if f2d:
+                    FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
+                File2DocumentService.delete_by_document_id(doc.id)
+            FileService.filter_delete(
+                [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
+            if not KnowledgebaseService.delete_by_id(req["kb_id"]):
+                return get_data_error_result(
+                    message="Database error (Knowledgebase removal)!")
+            for kb in kbs:
+                settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
+                settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
+                if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
+                    settings.STORAGE_IMPL.remove_bucket(kb.id)
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_rm_sync)
     except Exception as e:
         return server_error_response(e)
@@ -922,5 +936,3 @@ async def check_embedding():
     if summary["avg_cos_sim"] > 0.9:
         return get_json_result(data={"summary": summary, "results": results})
     return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})


@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import json import json
import re import re
import time import time
@ -787,7 +788,7 @@ Reason:
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results. - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
""" """
ans = chat_mdl.chat( ans = await chat_mdl.async_chat(
prompt, prompt,
[ [
{ {
@ -963,28 +964,30 @@ async def retrieval_test_embedded():
use_kg = req.get("use_kg", False) use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024)) top = int(req.get("top_k", 1024))
langs = req.get("cross_languages", []) langs = req.get("cross_languages", [])
tenant_ids = []
tenant_id = objs[0].tenant_id tenant_id = objs[0].tenant_id
if not tenant_id: if not tenant_id:
return get_error_data_result(message="permission denined.") return get_error_data_result(message="permission denined.")
if req.get("search_id", ""): def _retrieval_sync():
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) local_doc_ids = list(doc_ids) if doc_ids else []
meta_data_filter = search_config.get("meta_data_filter", {}) tenant_ids = []
metas = DocumentService.get_meta_by_kbs(kb_ids) _question = question
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) if req.get("search_id", ""):
filters: dict = gen_meta_filter(chat_mdl, metas, question) search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) meta_data_filter = search_config.get("meta_data_filter", {})
if not doc_ids: metas = DocumentService.get_meta_by_kbs(kb_ids)
doc_ids = None if meta_data_filter.get("method") == "auto":
elif meta_data_filter.get("method") == "manual": chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) filters: dict = gen_meta_filter(chat_mdl, metas, _question)
if meta_data_filter["manual"] and not doc_ids: local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
doc_ids = ["-999"] if not local_doc_ids:
local_doc_ids = None
elif meta_data_filter.get("method") == "manual":
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
if meta_data_filter["manual"] and not local_doc_ids:
local_doc_ids = ["-999"]
try:
tenants = UserTenantService.query(user_id=tenant_id) tenants = UserTenantService.query(user_id=tenant_id)
for kb_id in kb_ids: for kb_id in kb_ids:
for tenant in tenants: for tenant in tenants:
@ -1000,7 +1003,7 @@ async def retrieval_test_embedded():
             return get_error_data_result(message="Knowledgebase not found!")
         if langs:
-            question = cross_languages(kb.tenant_id, None, question, langs)
+            _question = cross_languages(kb.tenant_id, None, _question, langs)

         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
@ -1010,15 +1013,15 @@ async def retrieval_test_embedded():
         if req.get("keyword", False):
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
-            question += keyword_extraction(chat_mdl, question)
+            _question += keyword_extraction(chat_mdl, _question)

-        labels = label_question(question, [kb])
+        labels = label_question(_question, [kb])
         ranks = settings.retriever.retrieval(
-            question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
-            doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
+            _question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
+            local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
         )
         if use_kg:
-            ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl,
+            ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
                                                  LLMBundle(kb.tenant_id, LLMType.CHAT))
             if ck["content_with_weight"]:
                 ranks["chunks"].insert(0, ck)
@ -1028,6 +1031,9 @@ async def retrieval_test_embedded():
         ranks["labels"] = labels
         return get_json_result(data=ranks)
+
+    try:
+        return await asyncio.to_thread(_retrieval_sync)
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
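
The refactor above captures the request-scoped state in a sync closure and runs it via asyncio.to_thread, so the blocking retrieval work no longer stalls the event loop. A minimal self-contained sketch of the same shape (names hypothetical, sleep standing in for real work):

import asyncio
import time

async def handle_request(query: str, doc_ids: list[str] | None):
    def _work_sync():
        # Blocking DB/retrieval work runs in a worker thread and mutates
        # only locals, so the request's arguments stay untouched.
        local_doc_ids = list(doc_ids) if doc_ids else []
        time.sleep(0.1)  # stand-in for retrieval + reranking
        return {"query": query, "doc_ids": local_doc_ids, "chunks": []}

    try:
        # to_thread suspends this coroutine without blocking the loop.
        return await asyncio.to_thread(_work_sync)
    except Exception as e:
        return {"error": str(e)}

print(asyncio.run(handle_request("what is RAG?", None)))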
@ -1064,7 +1070,7 @@ async def related_questions_embedded():
     gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
     prompt = load_prompt("related_question")
-    ans = chat_mdl.chat(
+    ans = await chat_mdl.async_chat(
         prompt,
         [
             {
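
Here, as in the related-term helper earlier, the blocking chat call becomes an awaited async_chat, so the handler yields while the completion is in flight. A runnable sketch of the call-site change (the stub model is hypothetical and mirrors the new system-first signature):

import asyncio

class _StubChatModel:
    # Hypothetical stand-in for LLMBundle so the pattern runs on its own.
    def chat(self, system, history, gen_conf=None):
        return "blocking answer"

    async def async_chat(self, system, history, gen_conf=None):
        await asyncio.sleep(0)  # real impl awaits the provider's async API
        return "non-blocking answer"

async def related_questions(chat_mdl, question: str) -> str:
    # Before: ans = chat_mdl.chat(prompt, [...])  -- blocks the loop.
    # After: awaiting async_chat lets other coroutines progress meanwhile.
    prompt = "Suggest related questions."
    return await chat_mdl.async_chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.9})

if __name__ == "__main__":
    print(asyncio.run(related_questions(_StubChatModel(), "What is RAG?"))))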

View File

@ -719,10 +719,14 @@ class DocumentService(CommonService):
             # only for special task and parsed docs and unfinished
             freeze_progress = special_task_running and doc_progress >= 1 and not finished
             msg = "\n".join(sorted(msg))
+            begin_at = d.get("process_begin_at")
+            if not begin_at:
+                begin_at = datetime.now()
+                # fallback
+                cls.update_by_id(d["id"], {"process_begin_at": begin_at})
             info = {
-                "process_duration": datetime.timestamp(
-                    datetime.now()) -
-                d["process_begin_at"].timestamp(),
+                "process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0),
                 "run": status}
             if prg != 0 and not freeze_progress:
                 info["progress"] = prg

View File

@ -1685,12 +1685,17 @@ class LiteLLMBase(ABC):
             yield ans, tol

-    async def async_chat(self, history, gen_conf, **kwargs):
-        logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
+    async def async_chat(self, system, history, gen_conf, **kwargs):
+        hist = list(history) if history else []
+        if system:
+            if not hist or hist[0].get("role") != "system":
+                hist.insert(0, {"role": "system", "content": system})
+        logging.info("[HISTORY]" + json.dumps(hist, ensure_ascii=False, indent=2))
         if self.model_name.lower().find("qwen3") >= 0:
             kwargs["extra_body"] = {"enable_thinking": False}
-        completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
+        completion_args = self._construct_completion_args(history=hist, stream=False, tools=False, **gen_conf)
         for attempt in range(self.max_retries + 1):
             try:
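
The widened signature lets callers pass the system prompt separately; async_chat copies the history and prepends the system turn only when one is not already present. The insertion logic extracted as a standalone sketch (function name hypothetical):

def with_system(system: str | None, history: list[dict] | None) -> list[dict]:
    # Copy so the caller's history is never mutated.
    hist = list(history) if history else []
    if system and (not hist or hist[0].get("role") != "system"):
        hist.insert(0, {"role": "system", "content": system})
    return hist

assert with_system("Be terse.", [{"role": "user", "content": "hi"}])[0]["role"] == "system"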

View File

@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import datetime
 import json
 import logging
@ -360,6 +361,10 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
     return kwd

+async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
+    return await asyncio.to_thread(analyze_task, chat_mdl, prompt, task_name, tools_description, user_defined_prompts)
+
 def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
     if not tools_description:
         return ""
@ -378,6 +383,10 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
     return json_str, tk_cnt

+async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
+    return await asyncio.to_thread(next_step, chat_mdl, history, tools_description, task_desc, user_defined_prompts)
+
 def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
     tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
     goal = history[1]["content"]
@ -429,6 +438,14 @@ def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
     return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)

+async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
+    return await asyncio.to_thread(reflect, chat_mdl, history, tool_call_res, user_defined_prompts)
+
+async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
+    return await asyncio.to_thread(rank_memories, chat_mdl, goal, sub_goal, tool_call_summaries, user_defined_prompts)
+
 def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
     meta_data_structure = {}
     for key, values in meta_data.items():
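
These thin wrappers keep the sync prompt helpers unchanged while giving async call sites a non-blocking entry point, and independent helpers can then overlap. A minimal sketch of the pattern (names and timings hypothetical):

import asyncio
import time

def slow_helper(name: str) -> str:
    time.sleep(0.1)  # stand-in for a blocking LLM prompt helper
    return name

async def slow_helper_async(name: str) -> str:
    # Same pattern as analyze_task_async / next_step_async: delegate
    # the existing sync function to a worker thread.
    return await asyncio.to_thread(slow_helper, name)

async def main():
    # Independent helpers now run concurrently instead of back to back.
    results = await asyncio.gather(slow_helper_async("reflect"), slow_helper_async("rank"))
    print(results)

asyncio.run(main())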