diff --git a/agent/canvas.py b/agent/canvas.py
index 667caec29..5fb7af83f 100644
--- a/agent/canvas.py
+++ b/agent/canvas.py
@@ -416,13 +416,19 @@ class Canvas(Graph):
         loop = asyncio.get_running_loop()
         tasks = []
+
+        def _run_async_in_thread(coro_func, **call_kwargs):
+            return asyncio.run(coro_func(**call_kwargs))
+
         i = f
         while i < t:
             cpn = self.get_component_obj(self.path[i])
             task_fn = None
+            call_kwargs = None
             if cpn.component_name.lower() in ["begin", "userfillup"]:
-                task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
+                call_kwargs = {"inputs": kwargs.get("inputs", {})}
+                task_fn = cpn.invoke
                 i += 1
             else:
                 for _, ele in cpn.get_input_elements().items():
@@ -431,13 +437,18 @@ class Canvas(Graph):
                         t -= 1
                         break
                 else:
-                    task_fn = partial(cpn.invoke, **cpn.get_input())
+                    call_kwargs = cpn.get_input()
+                    task_fn = cpn.invoke
                     i += 1
             if task_fn is None:
                 continue
-            tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
+            invoke_async = getattr(cpn, "invoke_async", None)
+            if invoke_async and asyncio.iscoroutinefunction(invoke_async):
+                tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
+            else:
+                tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
 
         if tasks:
             await asyncio.gather(*tasks)
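The canvas change routes coroutine-based components through the existing thread pool by giving each worker thread its own short-lived event loop via asyncio.run. A minimal, self-contained sketch of that pattern; the names here are illustrative, not from the patch:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    async def fake_component(x):                     # hypothetical coroutine component
        await asyncio.sleep(0.01)
        return x * 2

    def run_async_in_thread(coro_func, **kwargs):
        # The pool thread has no running loop of its own, so asyncio.run()
        # may safely create (and tear down) one per call.
        return asyncio.run(coro_func(**kwargs))

    async def main():
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(max_workers=4) as pool:
            tasks = [loop.run_in_executor(pool, partial(run_async_in_thread, fake_component, x=i))
                     for i in range(3)]
            print(await asyncio.gather(*tasks))      # [0, 2, 4]

    asyncio.run(main())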
+ """ + if self.check_if_canceled("Agent processing"): + return + + if kwargs.get("user_prompt"): + usr_pmt = "" + if kwargs.get("reasoning"): + usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"]) + if kwargs.get("context"): + usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"]) + if usr_pmt: + usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"])) + else: + usr_pmt = str(kwargs["user_prompt"]) + self._param.prompts = [{"role": "user", "content": usr_pmt}] + + if not self.tools: + if self.check_if_canceled("Agent processing"): + return + return await asyncio.to_thread(LLM._invoke, self, **kwargs) + + prompt, msg, user_defined_prompt = self._prepare_prompt_variables() + output_schema = self._get_output_schema() + schema_prompt = "" + if output_schema: + schema = json.dumps(output_schema, ensure_ascii=False, indent=2) + schema_prompt = structured_output_prompt(schema) + + downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else [] + ex = self.exception_handler() + if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema: + self.set_output("content", partial(self.stream_output_with_tools_async, prompt, msg, user_defined_prompt)) + return + + _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) + use_tools = [] + ans = "" + async for delta_ans, tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt): + if self.check_if_canceled("Agent processing"): + return + ans += delta_ans + + if ans.find("**ERROR**") >= 0: + logging.error(f"Agent._chat got error. response: {ans}") + if self.get_exception_default_value(): + self.set_output("content", self.get_exception_default_value()) + else: + self.set_output("_ERROR", ans) + return + + if output_schema: + error = "" + for _ in range(self._param.max_retries + 1): + try: + def clean_formated_answer(ans: str) -> str: + ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) + ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL) + return re.sub(r"```\n*$", "", ans, flags=re.DOTALL) + obj = json_repair.loads(clean_formated_answer(ans)) + self.set_output("structured", obj) + if use_tools: + self.set_output("use_tools", use_tools) + return obj + except Exception: + error = "The answer cannot be parsed as JSON" + ans = self._force_format_to_schema(ans, schema_prompt) + if ans.find("**ERROR**") >= 0: + continue + + self.set_output("_ERROR", error) + return + + self.set_output("content", ans) + if use_tools: + self.set_output("use_tools", use_tools) + return ans + def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}): _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) answer_without_toolcall = "" @@ -261,6 +342,54 @@ class Agent(LLM, ToolBase): if use_tools: self.set_output("use_tools", use_tools) + async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}): + _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) + answer_without_toolcall = "" + use_tools = [] + async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt): + if self.check_if_canceled("Agent streaming"): + return + + if delta_ans.find("**ERROR**") >= 0: + if self.get_exception_default_value(): + self.set_output("content", 
self.get_exception_default_value()) + yield self.get_exception_default_value() + else: + self.set_output("_ERROR", delta_ans) + return + answer_without_toolcall += delta_ans + yield delta_ans + + self.set_output("content", answer_without_toolcall) + if use_tools: + self.set_output("use_tools", use_tools) + + async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""): + """ + Async wrapper that offloads synchronous flow to a thread, yielding results without blocking the event loop. + """ + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + + def worker(): + try: + for delta_ans, tk in self._react_with_tools_streamly(prompt, history, use_tools, user_defined_prompt, schema_prompt=schema_prompt): + asyncio.run_coroutine_threadsafe(queue.put((delta_ans, tk)), loop) + except Exception as e: + asyncio.run_coroutine_threadsafe(queue.put(e), loop) + finally: + asyncio.run_coroutine_threadsafe(queue.put(StopAsyncIteration), loop) + + await asyncio.to_thread(worker) + + while True: + item = await queue.get() + if item is StopAsyncIteration: + break + if isinstance(item, Exception): + raise item + yield item + def _gen_citations(self, text): retrievals = self._canvas.get_reference() retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())} @@ -433,4 +562,3 @@ Respond immediately with your final comprehensive answer. for k in self._param.inputs.keys(): self._param.inputs[k]["value"] = None self._param.debug_inputs = {} - diff --git a/agent/component/base.py b/agent/component/base.py index 0864ccb9e..6ac95e09a 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -14,6 +14,7 @@ # limitations under the License. # +import asyncio import re import time from abc import ABC @@ -445,6 +446,34 @@ class ComponentBase(ABC): self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time")) return self.output() + async def invoke_async(self, **kwargs) -> dict[str, Any]: + """ + Async wrapper for component invocation. + Prefers coroutine `_invoke_async` if present; otherwise falls back to `_invoke`. + Handles timing and error recording consistently with `invoke`. + """ + self.set_output("_created_time", time.perf_counter()) + try: + if self.check_if_canceled("Component processing"): + return + + fn_async = getattr(self, "_invoke_async", None) + if fn_async and asyncio.iscoroutinefunction(fn_async): + await fn_async(**kwargs) + elif asyncio.iscoroutinefunction(self._invoke): + await self._invoke(**kwargs) + else: + await asyncio.to_thread(self._invoke, **kwargs) + except Exception as e: + if self.get_exception_default_value(): + self.set_exception_default_value() + else: + self.set_output("_ERROR", str(e)) + logging.exception(e) + self._param.debug_inputs = {} + self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time")) + return self.output() + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) def _invoke(self, **kwargs): raise NotImplementedError() diff --git a/agent/component/llm.py b/agent/component/llm.py index a29a36860..483fde647 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -13,12 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
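The dispatch order in the new ComponentBase.invoke_async below — prefer a coroutine _invoke_async, accept a coroutine _invoke, otherwise push the blocking _invoke onto a thread — can be read in isolation. A minimal sketch with hypothetical component classes:

    import asyncio

    class SyncComp:
        def _invoke(self, **kw):              # legacy blocking path
            return {"via": "to_thread", **kw}

    class NativeAsyncComp(SyncComp):
        async def _invoke_async(self, **kw):  # preferred coroutine path
            return {"via": "_invoke_async", **kw}

    async def dispatch(cpn, **kwargs):
        fn_async = getattr(cpn, "_invoke_async", None)
        if fn_async and asyncio.iscoroutinefunction(fn_async):
            return await fn_async(**kwargs)
        if asyncio.iscoroutinefunction(cpn._invoke):
            return await cpn._invoke(**kwargs)
        return await asyncio.to_thread(cpn._invoke, **kwargs)

    async def main():
        print(await dispatch(NativeAsyncComp(), x=1))  # {'via': '_invoke_async', 'x': 1}
        print(await dispatch(SyncComp(), x=1))         # {'via': 'to_thread', 'x': 1}

    asyncio.run(main())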
diff --git a/agent/component/base.py b/agent/component/base.py
index 0864ccb9e..6ac95e09a 100644
--- a/agent/component/base.py
+++ b/agent/component/base.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
+import asyncio
 import re
 import time
 from abc import ABC
@@ -445,6 +446,34 @@ class ComponentBase(ABC):
         self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
         return self.output()
 
+    async def invoke_async(self, **kwargs) -> dict[str, Any]:
+        """
+        Async wrapper for component invocation.
+        Prefers coroutine `_invoke_async` if present; otherwise falls back to `_invoke`.
+        Handles timing and error recording consistently with `invoke`.
+        """
+        self.set_output("_created_time", time.perf_counter())
+        try:
+            if self.check_if_canceled("Component processing"):
+                return
+
+            fn_async = getattr(self, "_invoke_async", None)
+            if fn_async and asyncio.iscoroutinefunction(fn_async):
+                await fn_async(**kwargs)
+            elif asyncio.iscoroutinefunction(self._invoke):
+                await self._invoke(**kwargs)
+            else:
+                await asyncio.to_thread(self._invoke, **kwargs)
+        except Exception as e:
+            if self.get_exception_default_value():
+                self.set_exception_default_value()
+            else:
+                self.set_output("_ERROR", str(e))
+            logging.exception(e)
+        self._param.debug_inputs = {}
+        self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
+        return self.output()
+
     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         raise NotImplementedError()
diff --git a/agent/component/llm.py b/agent/component/llm.py
index a29a36860..483fde647 100644
--- a/agent/component/llm.py
+++ b/agent/component/llm.py
@@ -13,12 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import os
 import re
+import threading
 from copy import deepcopy
-from typing import Any, Generator
+from typing import Any, Generator, AsyncGenerator
 import json_repair
 from functools import partial
 from common.constants import LLMType
@@ -171,6 +173,13 @@ class LLM(ComponentBase):
             return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
         return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
 
+    async def _generate_async(self, msg: list[dict], **kwargs) -> str:
+        if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
+            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
+        if self.imgs and hasattr(self.chat_mdl, "async_chat"):
+            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
+        return await asyncio.to_thread(self._generate, msg, **kwargs)
+
     def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
         ans = ""
         last_idx = 0
@@ -205,6 +214,69 @@ class LLM(ComponentBase):
             for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
                 yield delta(txt)
 
+    async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
+        async def delta_wrapper(txt_iter):
+            ans = ""
+            last_idx = 0
+            endswith_think = False
+
+            def delta(txt):
+                nonlocal ans, last_idx, endswith_think
+                delta_ans = txt[last_idx:]
+                ans = txt
+
+                if delta_ans.find("<think>") == 0:
+                    last_idx += len("<think>")
+                    return ""
+                elif delta_ans.find("<think>") > 0:
+                    delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
+                    last_idx += delta_ans.find("<think>")
+                    return delta_ans
+                elif delta_ans.endswith("</think>"):
+                    endswith_think = True
+                elif endswith_think:
+                    endswith_think = False
+                    return ""
+
+                last_idx = len(ans)
+                if ans.endswith("</think>"):
+                    last_idx -= len("</think>")
+                return re.sub(r"(<think>|</think>)", "", delta_ans)
+
+            async for t in txt_iter:
+                yield delta(t)
+
+        if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
+            async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
+                yield t
+            return
+        if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
+            async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
+                yield t
+            return
+
+        # fallback
+        loop = asyncio.get_running_loop()
+        queue: asyncio.Queue = asyncio.Queue()
+
+        def worker():
+            try:
+                for item in self._generate_streamly(msg, **kwargs):
+                    loop.call_soon_threadsafe(queue.put_nowait, item)
+            except Exception as e:
+                loop.call_soon_threadsafe(queue.put_nowait, e)
+            finally:
+                loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
+
+        threading.Thread(target=worker, daemon=True).start()
+        while True:
+            item = await queue.get()
+            if item is StopAsyncIteration:
+                break
+            if isinstance(item, Exception):
+                raise item
+            yield item
+
     async def _stream_output_async(self, prompt, msg):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer = ""
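Note the two thread-to-loop hand-off styles: the fallback above uses loop.call_soon_threadsafe(queue.put_nowait, item), which schedules a plain callback (sufficient for an unbounded queue, where put_nowait cannot block), whereas agent_with_tools.py uses asyncio.run_coroutine_threadsafe(queue.put(item), loop), which schedules a coroutine and returns a future. A minimal sketch of the callback form:

    import asyncio
    import threading

    async def main():
        loop = asyncio.get_running_loop()
        q: asyncio.Queue = asyncio.Queue()

        def producer():
            for i in range(3):
                # put_nowait never blocks on an unbounded queue, so a plain
                # callback suffices; no coroutine scheduling required.
                loop.call_soon_threadsafe(q.put_nowait, i)
            loop.call_soon_threadsafe(q.put_nowait, None)    # sentinel

        threading.Thread(target=producer, daemon=True).start()
        while (item := await q.get()) is not None:
            print(item)

    asyncio.run(main())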
diff --git a/agent/tools/base.py b/agent/tools/base.py
index 791242d59..87e3126b1 100644
--- a/agent/tools/base.py
+++ b/agent/tools/base.py
@@ -17,6 +17,7 @@ import logging
 import re
 import time
 from copy import deepcopy
+import asyncio
 from functools import partial
 from typing import TypedDict, List, Any
 from agent.component.base import ComponentParamBase, ComponentBase
@@ -50,10 +51,14 @@ class LLMToolPluginCallSession(ToolCallSession):
     def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
         assert name in self.tools_map, f"LLM tool {name} does not exist"
         st = timer()
-        if isinstance(self.tools_map[name], MCPToolCallSession):
-            resp = self.tools_map[name].tool_call(name, arguments, 60)
+        tool_obj = self.tools_map[name]
+        if isinstance(tool_obj, MCPToolCallSession):
+            resp = tool_obj.tool_call(name, arguments, 60)
         else:
-            resp = self.tools_map[name].invoke(**arguments)
+            if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
+                resp = asyncio.run(tool_obj.invoke_async(**arguments))
+            else:
+                resp = tool_obj.invoke(**arguments)
         self.callback(name, arguments, resp, elapsed_time=timer()-st)
         return resp
 
@@ -139,6 +144,33 @@ class ToolBase(ComponentBase):
         self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
         return res
 
+    async def invoke_async(self, **kwargs):
+        """
+        Async wrapper for tool invocation.
+        If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
+        Mirrors the exception handling of `invoke`.
+        """
+        if self.check_if_canceled("Tool processing"):
+            return
+
+        self.set_output("_created_time", time.perf_counter())
+        try:
+            fn_async = getattr(self, "_invoke_async", None)
+            if fn_async and asyncio.iscoroutinefunction(fn_async):
+                res = await fn_async(**kwargs)
+            elif asyncio.iscoroutinefunction(self._invoke):
+                res = await self._invoke(**kwargs)
+            else:
+                res = await asyncio.to_thread(self._invoke, **kwargs)
+        except Exception as e:
+            self._param.outputs["_ERROR"] = {"value": str(e)}
+            logging.exception(e)
+            res = str(e)
+        self._param.debug_inputs = []
+
+        self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
+        return res
+
     def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
         chunks = []
         aggs = []
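tool_call is synchronous and now calls asyncio.run, which raises RuntimeError if the calling thread already drives an event loop; it is safe here only because tool calls execute on pool threads without a loop. A defensive guard, if one were wanted — a sketch, not part of the patch:

    import asyncio

    def run_coro_blocking(coro):
        """Run a coroutine to completion from sync code.
        Fails loudly if this thread already drives an event loop,
        instead of nesting loops or deadlocking."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)        # no loop in this thread: safe
        raise RuntimeError("run_coro_blocking() called from a running event loop")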
diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index d5d928342..d96de64d0 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import datetime
 import json
 import re
@@ -147,31 +148,35 @@ async def set():
         d["available_int"] = req["available_int"]
 
     try:
-        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
-        if not tenant_id:
-            return get_data_error_result(message="Tenant not found!")
+        def _set_sync():
+            tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+            if not tenant_id:
+                return get_data_error_result(message="Tenant not found!")
 
-        embd_id = DocumentService.get_embd_id(req["doc_id"])
-        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
+            embd_id = DocumentService.get_embd_id(req["doc_id"])
+            embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
 
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
 
-        if doc.parser_id == ParserType.QA:
-            arr = [
-                t for t in re.split(
-                    r"[\n\t]",
-                    req["content_with_weight"]) if len(t) > 1]
-            q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
-            d = beAdoc(d, q, a, not any(
-                [rag_tokenizer.is_chinese(t) for t in q + a]))
+            _d = d
+            if doc.parser_id == ParserType.QA:
+                arr = [
+                    t for t in re.split(
+                        r"[\n\t]",
+                        req["content_with_weight"]) if len(t) > 1]
+                q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
+                _d = beAdoc(d, q, a, not any(
+                    [rag_tokenizer.is_chinese(t) for t in q + a]))
 
-        v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
-        v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
-        d["q_%d_vec" % len(v)] = v.tolist()
-        settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
-        return get_json_result(data=True)
+            v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
+            v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
+            _d["q_%d_vec" % len(v)] = v.tolist()
+            settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_set_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -182,16 +187,19 @@
 async def switch():
     req = await get_request_json()
     try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
-        for cid in req["chunk_ids"]:
-            if not settings.docStoreConn.update({"id": cid},
-                                                {"available_int": int(req["available_int"])},
-                                                search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
-                                                doc.kb_id):
-                return get_data_error_result(message="Index updating failure")
-        return get_json_result(data=True)
+        def _switch_sync():
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
+            for cid in req["chunk_ids"]:
+                if not settings.docStoreConn.update({"id": cid},
+                                                    {"available_int": int(req["available_int"])},
+                                                    search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                    doc.kb_id):
+                    return get_data_error_result(message="Index updating failure")
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_switch_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -202,20 +210,23 @@
 async def rm():
     req = await get_request_json()
    try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
-        if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
-                                            search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
-                                            doc.kb_id):
-            return get_data_error_result(message="Chunk deleting failure")
-        deleted_chunk_ids = req["chunk_ids"]
-        chunk_number = len(deleted_chunk_ids)
-        DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
-        for cid in deleted_chunk_ids:
-            if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
-                settings.STORAGE_IMPL.rm(doc.kb_id, cid)
-        return get_json_result(data=True)
+        def _rm_sync():
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
+            if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
+                                                search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                doc.kb_id):
+                return get_data_error_result(message="Chunk deleting failure")
+            deleted_chunk_ids = req["chunk_ids"]
+            chunk_number = len(deleted_chunk_ids)
+            DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
+            for cid in deleted_chunk_ids:
+                if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
+                    settings.STORAGE_IMPL.rm(doc.kb_id, cid)
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_rm_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -245,35 +256,38 @@ async def create():
         d["tag_feas"] = req["tag_feas"]
 
     try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
-        d["kb_id"] = [doc.kb_id]
-        d["docnm_kwd"] = doc.name
-        d["title_tks"] = rag_tokenizer.tokenize(doc.name)
-        d["doc_id"] = doc.id
+        def _create_sync():
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
+            d["kb_id"] = [doc.kb_id]
+            d["docnm_kwd"] = doc.name
+            d["title_tks"] = rag_tokenizer.tokenize(doc.name)
+            d["doc_id"] = doc.id
 
-        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
-        if not tenant_id:
-            return get_data_error_result(message="Tenant not found!")
+            tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+            if not tenant_id:
+                return get_data_error_result(message="Tenant not found!")
 
-        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
-        if not e:
-            return get_data_error_result(message="Knowledgebase not found!")
-        if kb.pagerank:
-            d[PAGERANK_FLD] = kb.pagerank
+            e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
+            if not e:
+                return get_data_error_result(message="Knowledgebase not found!")
+            if kb.pagerank:
+                d[PAGERANK_FLD] = kb.pagerank
 
-        embd_id = DocumentService.get_embd_id(req["doc_id"])
-        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
+            embd_id = DocumentService.get_embd_id(req["doc_id"])
+            embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
 
-        v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
-        v = 0.1 * v[0] + 0.9 * v[1]
-        d["q_%d_vec" % len(v)] = v.tolist()
-        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+            v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
+            v = 0.1 * v[0] + 0.9 * v[1]
+            d["q_%d_vec" % len(v)] = v.tolist()
+            settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
 
-        DocumentService.increment_chunk_num(
-            doc.id, doc.kb_id, c, 1, 0)
-        return get_json_result(data={"chunk_id": chunck_id})
+            DocumentService.increment_chunk_num(
+                doc.id, doc.kb_id, c, 1, 0)
+            return get_json_result(data={"chunk_id": chunck_id})
+
+        return await asyncio.to_thread(_create_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -297,25 +311,28 @@ async def retrieval_test():
     use_kg = req.get("use_kg", False)
     top = int(req.get("top_k", 1024))
     langs = req.get("cross_languages", [])
-    tenant_ids = []
+    user_id = current_user.id
 
-    if req.get("search_id", ""):
-        search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
-        meta_data_filter = search_config.get("meta_data_filter", {})
-        metas = DocumentService.get_meta_by_kbs(kb_ids)
-        if meta_data_filter.get("method") == "auto":
-            chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
-            filters: dict = gen_meta_filter(chat_mdl, metas, question)
-            doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
-            if not doc_ids:
-                doc_ids = None
-        elif meta_data_filter.get("method") == "manual":
-            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
-            if meta_data_filter["manual"] and not doc_ids:
-                doc_ids = ["-999"]
+    def _retrieval_sync():
+        local_doc_ids = list(doc_ids) if doc_ids else []
+        tenant_ids = []
 
-    try:
-        tenants = UserTenantService.query(user_id=current_user.id)
+        if req.get("search_id", ""):
+            search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
+            meta_data_filter = search_config.get("meta_data_filter", {})
+            metas = DocumentService.get_meta_by_kbs(kb_ids)
+            if meta_data_filter.get("method") == "auto":
+                chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
+                filters: dict = gen_meta_filter(chat_mdl, metas, question)
+                local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
+                if not local_doc_ids:
+                    local_doc_ids = None
+            elif meta_data_filter.get("method") == "manual":
+                local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
+                if meta_data_filter["manual"] and not local_doc_ids:
+                    local_doc_ids = ["-999"]
+
+        tenants = UserTenantService.query(user_id=user_id)
         for kb_id in kb_ids:
             for tenant in tenants:
                 if KnowledgebaseService.query(
@@ -331,8 +348,9 @@ async def retrieval_test():
         if not e:
             return get_data_error_result(message="Knowledgebase not found!")
 
+        _question = question
         if langs:
-            question = cross_languages(kb.tenant_id, None, question, langs)
+            _question = cross_languages(kb.tenant_id, None, _question, langs)
 
         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
 
@@ -342,19 +360,19 @@ async def retrieval_test():
 
         if req.get("keyword", False):
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
-            question += keyword_extraction(chat_mdl, question)
+            _question += keyword_extraction(chat_mdl, _question)
 
-        labels = label_question(question, [kb])
-        ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
+        labels = label_question(_question, [kb])
+        ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
                                              float(req.get("similarity_threshold", 0.0)),
                                              float(req.get("vector_similarity_weight", 0.3)),
                                              top,
-                                             doc_ids, rerank_mdl=rerank_mdl,
+                                             local_doc_ids, rerank_mdl=rerank_mdl,
                                              highlight=req.get("highlight", False),
                                              rank_feature=labels
                                              )
         if use_kg:
-            ck = settings.kg_retriever.retrieval(question,
+            ck = settings.kg_retriever.retrieval(_question,
                                                  tenant_ids,
                                                  kb_ids,
                                                  embd_mdl,
@@ -367,6 +385,9 @@ async def retrieval_test():
 
         ranks["labels"] = labels
         return get_json_result(data=ranks)
+
+    try:
+        return await asyncio.to_thread(_retrieval_sync)
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
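The refactor shape used throughout these endpoint files — close over the request data in a _xxx_sync function and await asyncio.to_thread(...) — keeps blocking ORM, embedding, and doc-store calls off the event loop while preserving the early-return error style. A minimal sketch of the same shape in a bare Quart app, with a hypothetical route and lookup:

    import asyncio
    import time
    from quart import Quart

    app = Quart(__name__)

    def _lookup_sync(doc_id: str) -> dict:
        time.sleep(0.2)                  # stand-in for blocking ORM / search calls
        return {"doc_id": doc_id}

    @app.get("/docs/<doc_id>")
    async def get_doc(doc_id):
        # The closure runs on a worker thread; the event loop stays free
        # to serve other requests in the meantime.
        return await asyncio.to_thread(_lookup_sync, doc_id)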
diff --git a/api/apps/connector_app.py b/api/apps/connector_app.py
index 49d8005a6..44d3a3344 100644
--- a/api/apps/connector_app.py
+++ b/api/apps/connector_app.py
@@ -168,10 +168,12 @@ async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, sou
     status = "success" if success else "error"
     auto_close = "window.close();" if success else ""
     escaped_message = escape(message)
+    # Drive: ragflow-google-drive-oauth
+    # Gmail: ragflow-gmail-oauth
+    payload_type = f"ragflow-{source}-oauth"
     payload_json = json.dumps(
         {
-            # TODO(google-oauth): include connector type (drive/gmail) in payload type if needed
-            "type": f"ragflow-google-{source}-oauth",
+            "type": payload_type,
             "status": status,
             "flowId": flow_id or "",
             "message": message,
diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index e2be0532d..89630e4a4 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -462,7 +462,7 @@ async def related_questions():
     if "parameter" in gen_conf:
         del gen_conf["parameter"]
     prompt = load_prompt("related_question")
-    ans = chat_mdl.chat(
+    ans = await chat_mdl.async_chat(
         prompt,
         [
             {
diff --git a/api/apps/document_app.py b/api/apps/document_app.py
index a56f11317..ba52bd61c 100644
--- a/api/apps/document_app.py
+++ b/api/apps/document_app.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 #
+import asyncio
 import json
 import os.path
 import pathlib
@@ -72,7 +73,7 @@ async def upload():
         if not check_kb_team_permission(kb, current_user.id):
             return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-    err, files = FileService.upload_document(kb, file_objs, current_user.id)
+    err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
     if err:
         return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
 
@@ -390,7 +391,7 @@ async def rm():
         if not DocumentService.accessible4deletion(doc_id, current_user.id):
             return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-    errors = FileService.delete_docs(doc_ids, current_user.id)
+    errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
     if errors:
         return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
 
@@ -403,44 +404,48 @@ async def rm():
 @validate_request("doc_ids", "run")
 async def run():
     req = await get_request_json()
-    for doc_id in req["doc_ids"]:
-        if not DocumentService.accessible(doc_id, current_user.id):
-            return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
     try:
-        kb_table_num_map = {}
-        for id in req["doc_ids"]:
-            info = {"run": str(req["run"]), "progress": 0}
-            if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
-                info["progress_msg"] = ""
-                info["chunk_num"] = 0
-                info["token_num"] = 0
+        def _run_sync():
+            for doc_id in req["doc_ids"]:
+                if not DocumentService.accessible(doc_id, current_user.id):
+                    return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-            tenant_id = DocumentService.get_tenant_id(id)
-            if not tenant_id:
-                return get_data_error_result(message="Tenant not found!")
-            e, doc = DocumentService.get_by_id(id)
-            if not e:
-                return get_data_error_result(message="Document not found!")
+            kb_table_num_map = {}
+            for id in req["doc_ids"]:
+                info = {"run": str(req["run"]), "progress": 0}
+                if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
+                    info["progress_msg"] = ""
+                    info["chunk_num"] = 0
+                    info["token_num"] = 0
 
-            if str(req["run"]) == TaskStatus.CANCEL.value:
-                if str(doc.run) == TaskStatus.RUNNING.value:
-                    cancel_all_task_of(id)
-                else:
-                    return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
-            if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
-                DocumentService.clear_chunk_num_when_rerun(doc.id)
+                tenant_id = DocumentService.get_tenant_id(id)
+                if not tenant_id:
+                    return get_data_error_result(message="Tenant not found!")
+                e, doc = DocumentService.get_by_id(id)
+                if not e:
+                    return get_data_error_result(message="Document not found!")
 
-            DocumentService.update_by_id(id, info)
-            if req.get("delete", False):
-                TaskService.filter_delete([Task.doc_id == id])
-                if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                    settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
+                if str(req["run"]) == TaskStatus.CANCEL.value:
+                    if str(doc.run) == TaskStatus.RUNNING.value:
+                        cancel_all_task_of(id)
+                    else:
+                        return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
+                if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
+                    DocumentService.clear_chunk_num_when_rerun(doc.id)
 
-            if str(req["run"]) == TaskStatus.RUNNING.value:
-                doc = doc.to_dict()
-                DocumentService.run(tenant_id, doc, kb_table_num_map)
+                DocumentService.update_by_id(id, info)
+                if req.get("delete", False):
+                    TaskService.filter_delete([Task.doc_id == id])
+                    if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                        settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
 
-        return get_json_result(data=True)
+                if str(req["run"]) == TaskStatus.RUNNING.value:
+                    doc_dict = doc.to_dict()
+                    DocumentService.run(tenant_id, doc_dict, kb_table_num_map)
+
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_run_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -450,45 +455,49 @@ async def run():
 @validate_request("doc_id", "name")
 async def rename():
     req = await get_request_json()
-    if not DocumentService.accessible(req["doc_id"], current_user.id):
-        return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
     try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
-        if not e:
-            return get_data_error_result(message="Document not found!")
-        if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
-            return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
-        if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
-            return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
+        def _rename_sync():
+            if not DocumentService.accessible(req["doc_id"], current_user.id):
+                return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-        for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
-            if d.name == req["name"]:
-                return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
+            e, doc = DocumentService.get_by_id(req["doc_id"])
+            if not e:
+                return get_data_error_result(message="Document not found!")
+            if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
+                return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
+            if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
+                return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
 
-        if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
-            return get_data_error_result(message="Database error (Document rename)!")
+            for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
+                if d.name == req["name"]:
+                    return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
 
-        informs = File2DocumentService.get_by_document_id(req["doc_id"])
-        if informs:
-            e, file = FileService.get_by_id(informs[0].file_id)
-            FileService.update_by_id(file.id, {"name": req["name"]})
+            if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
+                return get_data_error_result(message="Database error (Document rename)!")
 
-        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
-        title_tks = rag_tokenizer.tokenize(req["name"])
-        es_body = {
-            "docnm_kwd": req["name"],
-            "title_tks": title_tks,
-            "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
-        }
-        if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-            settings.docStoreConn.update(
-                {"doc_id": req["doc_id"]},
-                es_body,
-                search.index_name(tenant_id),
-                doc.kb_id,
-            )
+            informs = File2DocumentService.get_by_document_id(req["doc_id"])
+            if informs:
+                e, file = FileService.get_by_id(informs[0].file_id)
+                FileService.update_by_id(file.id, {"name": req["name"]})
+
+            tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+            title_tks = rag_tokenizer.tokenize(req["name"])
+            es_body = {
+                "docnm_kwd": req["name"],
+                "title_tks": title_tks,
+                "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
+            }
+            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                settings.docStoreConn.update(
+                    {"doc_id": req["doc_id"]},
+                    es_body,
+                    search.index_name(tenant_id),
+                    doc.kb_id,
+                )
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_rename_sync)
 
-        return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
 
@@ -502,7 +511,8 @@ async def get(doc_id):
             return get_data_error_result(message="Document not found!")
 
         b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
-        response = await make_response(settings.STORAGE_IMPL.get(b, n))
+        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
+        response = await make_response(data)
 
         ext = re.search(r"\.([^.]+)$", doc.name.lower())
         ext = ext.group(1) if ext else None
@@ -523,8 +533,7 @@ async def get(doc_id):
 async def download_attachment(attachment_id):
     try:
         ext = request.args.get("ext", "markdown")
-        data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
-        # data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
+        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
 
         response = await make_response(data)
         response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
@@ -596,7 +605,8 @@ async def get_image(image_id):
         if len(arr) != 2:
             return get_data_error_result(message="Image not found.")
         bkt, nm = image_id.split("-")
-        response = await make_response(settings.STORAGE_IMPL.get(bkt, nm))
+        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
+        response = await make_response(data)
         response.headers.set("Content-Type", "image/JPEG")
         return response
     except Exception as e:
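An alternative to wrapping every settings.STORAGE_IMPL.get(...) call site with asyncio.to_thread would be a thin async facade over the blocking storage client. A sketch, assuming only a get/put/obj_exist interface on the wrapped object:

    import asyncio

    class AsyncStorageFacade:
        """Thin async facade over a blocking object-store client,
        offloading each call to a worker thread."""
        def __init__(self, impl):
            self._impl = impl

        async def get(self, bucket: str, name: str) -> bytes:
            return await asyncio.to_thread(self._impl.get, bucket, name)

        async def put(self, bucket: str, name: str, blob: bytes) -> None:
            await asyncio.to_thread(self._impl.put, bucket, name, blob)

        async def obj_exist(self, bucket: str, name: str) -> bool:
            return await asyncio.to_thread(self._impl.obj_exist, bucket, name)

This centralizes the offloading policy in one place instead of at each call site, at the cost of touching every consumer once.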
diff --git a/api/apps/file_app.py b/api/apps/file_app.py
index bbb5b3ddb..1ce5d4cae 100644
--- a/api/apps/file_app.py
+++ b/api/apps/file_app.py
@@ -14,6 +14,7 @@
 # limitations under the License
 #
 import logging
+import asyncio
 import os
 import pathlib
 import re
@@ -61,9 +62,10 @@ async def upload():
         e, pf_folder = FileService.get_by_id(pf_id)
         if not e:
             return get_data_error_result(message="Can't find this folder!")
-    for file_obj in file_objs:
+
+    async def _handle_single_file(file_obj):
         MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
-        if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id):
+        if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
             return get_data_error_result(message="Exceed the maximum file number of a free user!")
         # split file name path
@@ -75,35 +77,36 @@ async def upload():
         file_len = len(file_obj_names)
 
         # get folder
-        file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
+        file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
         len_id_list = len(file_id_list)
 
         # create folder
         if file_len != len_id_list:
-            e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
+            e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
             if not e:
                 return get_data_error_result(message="Folder not found!")
-            last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
+            last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
                                                     len_id_list)
         else:
-            e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
+            e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
             if not e:
                 return get_data_error_result(message="Folder not found!")
-            last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
+            last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
                                                     len_id_list)
 
         # file type
         filetype = filename_type(file_obj_names[file_len - 1])
         location = file_obj_names[file_len - 1]
-        while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
+        while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
             location += "_"
-        blob = file_obj.read()
-        filename = duplicate_name(
+        blob = await asyncio.to_thread(file_obj.read)
+        filename = await asyncio.to_thread(
+            duplicate_name,
             FileService.query,
             name=file_obj_names[file_len - 1],
             parent_id=last_folder.id)
-        settings.STORAGE_IMPL.put(last_folder.id, location, blob)
-        file = {
+        await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
+        file_data = {
             "id": get_uuid(),
             "parent_id": last_folder.id,
             "tenant_id": current_user.id,
@@ -113,8 +116,13 @@ async def upload():
             "location": location,
             "size": len(blob),
         }
-        file = FileService.insert(file)
-        file_res.append(file.to_json())
+        inserted = await asyncio.to_thread(FileService.insert, file_data)
+        return inserted.to_json()
+
+        for file_obj in file_objs:
+            res = await _handle_single_file(file_obj)
+            if not isinstance(res, dict):
+                # _handle_single_file returned an error response; surface it
+                # immediately instead of appending it to the success payload
+                return res
+            file_res.append(res)
+
         return get_json_result(data=file_res)
     except Exception as e:
         return server_error_response(e)
@@ -242,55 +250,58 @@ async def rm():
     req = await get_request_json()
     file_ids = req["file_ids"]
 
-    def _delete_single_file(file):
-        try:
-            if file.location:
-                settings.STORAGE_IMPL.rm(file.parent_id, file.location)
-        except Exception as e:
-            logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
-
-        informs = File2DocumentService.get_by_file_id(file.id)
-        for inform in informs:
-            doc_id = inform.document_id
-            e, doc = DocumentService.get_by_id(doc_id)
-            if e and doc:
-                tenant_id = DocumentService.get_tenant_id(doc_id)
-                if tenant_id:
-                    DocumentService.remove_document(doc, tenant_id)
-        File2DocumentService.delete_by_file_id(file.id)
-
-        FileService.delete(file)
-
-    def _delete_folder_recursive(folder, tenant_id):
-        sub_files = FileService.list_all_files_by_parent_id(folder.id)
-        for sub_file in sub_files:
-            if sub_file.type == FileType.FOLDER.value:
-                _delete_folder_recursive(sub_file, tenant_id)
-            else:
-                _delete_single_file(sub_file)
-
-        FileService.delete(folder)
-
     try:
-        for file_id in file_ids:
-            e, file = FileService.get_by_id(file_id)
-            if not e or not file:
-                return get_data_error_result(message="File or Folder not found!")
-            if not file.tenant_id:
-                return get_data_error_result(message="Tenant not found!")
-            if not check_file_team_permission(file, current_user.id):
-                return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
+        def _delete_single_file(file):
+            try:
+                if file.location:
+                    settings.STORAGE_IMPL.rm(file.parent_id, file.location)
+            except Exception as e:
+                logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
 
-            if file.source_type == FileSource.KNOWLEDGEBASE:
-                continue
+            informs = File2DocumentService.get_by_file_id(file.id)
+            for inform in informs:
+                doc_id = inform.document_id
+                e, doc = DocumentService.get_by_id(doc_id)
+                if e and doc:
+                    tenant_id = DocumentService.get_tenant_id(doc_id)
+                    if tenant_id:
+                        DocumentService.remove_document(doc, tenant_id)
+            File2DocumentService.delete_by_file_id(file.id)
 
-            if file.type == FileType.FOLDER.value:
-                _delete_folder_recursive(file, current_user.id)
-                continue
+            FileService.delete(file)
 
-            _delete_single_file(file)
+        def _delete_folder_recursive(folder, tenant_id):
+            sub_files = FileService.list_all_files_by_parent_id(folder.id)
+            for sub_file in sub_files:
+                if sub_file.type == FileType.FOLDER.value:
+                    _delete_folder_recursive(sub_file, tenant_id)
+                else:
+                    _delete_single_file(sub_file)
 
-        return get_json_result(data=True)
+            FileService.delete(folder)
+
+        def _rm_sync():
+            for file_id in file_ids:
+                e, file = FileService.get_by_id(file_id)
+                if not e or not file:
+                    return get_data_error_result(message="File or Folder not found!")
+                if not file.tenant_id:
+                    return get_data_error_result(message="Tenant not found!")
+                if not check_file_team_permission(file, current_user.id):
+                    return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
+
+                if file.source_type == FileSource.KNOWLEDGEBASE:
+                    continue
+
+                if file.type == FileType.FOLDER.value:
+                    _delete_folder_recursive(file, current_user.id)
+                    continue
+
+                _delete_single_file(file)
+
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_rm_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -346,10 +357,10 @@ async def get(file_id):
         if not check_file_team_permission(file, current_user.id):
             return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
 
-        blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
+        blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
         if not blob:
             b, n = File2DocumentService.get_storage_address(file_id=file_id)
-            blob = settings.STORAGE_IMPL.get(b, n)
+            blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
 
         response = await make_response(blob)
         ext = re.search(r"\.([^.]+)$", file.name.lower())
@@ -444,10 +455,12 @@ async def move():
                 },
             )
 
-        for file in files:
-            _move_entry_recursive(file, dest_folder)
+        def _move_sync():
+            for file in files:
+                _move_entry_recursive(file, dest_folder)
+            return get_json_result(data=True)
 
-        return get_json_result(data=True)
+        return await asyncio.to_thread(_move_sync)
     except Exception as e:
         return server_error_response(e)
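In the rewritten upload(), _handle_single_file can return either a file dict or an error response, which the calling loop must distinguish (handled above with an isinstance check). An alternative sketch that propagates failures as a domain exception instead, avoiding type checks on return values — hypothetical names throughout:

    import asyncio

    class UploadError(Exception):
        """Domain error carrying a user-facing message."""

    def too_many_files() -> bool:          # hypothetical quota check
        return False

    async def handle_single_file(name: str) -> dict:
        if too_many_files():
            raise UploadError("Exceed the maximum file number of a free user!")
        return {"name": name}

    async def main():
        results = []
        try:
            for name in ("a.txt", "b.txt"):
                results.append(await handle_single_file(name))
        except UploadError as e:
            print("error:", e)             # map to get_data_error_result(...) in a real handler
            return
        print(results)

    asyncio.run(main())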
diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py
index 7ff01cc19..5d3dee0b9 100644
--- a/api/apps/kb_app.py
+++ b/api/apps/kb_app.py
@@ -17,6 +17,7 @@ import json
 import logging
 import random
 import re
+import asyncio
 from quart import request
 import numpy as np
@@ -116,12 +117,22 @@ async def update():
 
         if kb.pagerank != req.get("pagerank", 0):
             if req.get("pagerank", 0) > 0:
-                settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
-                                             search.index_name(kb.tenant_id), kb.id)
+                await asyncio.to_thread(
+                    settings.docStoreConn.update,
+                    {"kb_id": kb.id},
+                    {PAGERANK_FLD: req["pagerank"]},
+                    search.index_name(kb.tenant_id),
+                    kb.id,
+                )
             else:
                 # Elasticsearch requires PAGERANK_FLD be non-zero!
-                settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
-                                             search.index_name(kb.tenant_id), kb.id)
+                await asyncio.to_thread(
+                    settings.docStoreConn.update,
+                    {"exists": PAGERANK_FLD},
+                    {"remove": PAGERANK_FLD},
+                    search.index_name(kb.tenant_id),
+                    kb.id,
+                )
 
         e, kb = KnowledgebaseService.get_by_id(kb.id)
         if not e:
@@ -224,25 +235,28 @@ async def rm():
                 data=False,
                 message='Only owner of knowledgebase authorized for this operation.',
                 code=RetCode.OPERATING_ERROR)
-        for doc in DocumentService.query(kb_id=req["kb_id"]):
-            if not DocumentService.remove_document(doc, kbs[0].tenant_id):
+        def _rm_sync():
+            for doc in DocumentService.query(kb_id=req["kb_id"]):
+                if not DocumentService.remove_document(doc, kbs[0].tenant_id):
+                    return get_data_error_result(
+                        message="Database error (Document removal)!")
+                f2d = File2DocumentService.get_by_document_id(doc.id)
+                if f2d:
+                    FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
+                File2DocumentService.delete_by_document_id(doc.id)
+            FileService.filter_delete(
+                [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
+            if not KnowledgebaseService.delete_by_id(req["kb_id"]):
                 return get_data_error_result(
-                    message="Database error (Document removal)!")
-            f2d = File2DocumentService.get_by_document_id(doc.id)
-            if f2d:
-                FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
-            File2DocumentService.delete_by_document_id(doc.id)
-        FileService.filter_delete(
-            [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
-        if not KnowledgebaseService.delete_by_id(req["kb_id"]):
-            return get_data_error_result(
-                message="Database error (Knowledgebase removal)!")
-        for kb in kbs:
-            settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
-            settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
-            if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
-                settings.STORAGE_IMPL.remove_bucket(kb.id)
-        return get_json_result(data=True)
+                    message="Database error (Knowledgebase removal)!")
+            for kb in kbs:
+                settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
+                settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
+                if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
+                    settings.STORAGE_IMPL.remove_bucket(kb.id)
+            return get_json_result(data=True)
+
+        return await asyncio.to_thread(_rm_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -922,5 +936,3 @@ async def check_embedding():
     if summary["avg_cos_sim"] > 0.9:
         return get_json_result(data={"summary": summary, "results": results})
     return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})
-
-
diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index 6276877a2..e94f14fcc 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import re
 import time
@@ -787,7 +788,7 @@ Reason:
 - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
""" - ans = chat_mdl.chat( + ans = await chat_mdl.async_chat( prompt, [ { @@ -963,28 +964,30 @@ async def retrieval_test_embedded(): use_kg = req.get("use_kg", False) top = int(req.get("top_k", 1024)) langs = req.get("cross_languages", []) - tenant_ids = [] - tenant_id = objs[0].tenant_id if not tenant_id: return get_error_data_result(message="permission denined.") - if req.get("search_id", ""): - search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) - meta_data_filter = search_config.get("meta_data_filter", {}) - metas = DocumentService.get_meta_by_kbs(kb_ids) - if meta_data_filter.get("method") == "auto": - chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) - filters: dict = gen_meta_filter(chat_mdl, metas, question) - doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not doc_ids: - doc_ids = None - elif meta_data_filter.get("method") == "manual": - doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) - if meta_data_filter["manual"] and not doc_ids: - doc_ids = ["-999"] + def _retrieval_sync(): + local_doc_ids = list(doc_ids) if doc_ids else [] + tenant_ids = [] + _question = question + + if req.get("search_id", ""): + search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) + meta_data_filter = search_config.get("meta_data_filter", {}) + metas = DocumentService.get_meta_by_kbs(kb_ids) + if meta_data_filter.get("method") == "auto": + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) + filters: dict = gen_meta_filter(chat_mdl, metas, _question) + local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not local_doc_ids: + local_doc_ids = None + elif meta_data_filter.get("method") == "manual": + local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) + if meta_data_filter["manual"] and not local_doc_ids: + local_doc_ids = ["-999"] - try: tenants = UserTenantService.query(user_id=tenant_id) for kb_id in kb_ids: for tenant in tenants: @@ -1000,7 +1003,7 @@ async def retrieval_test_embedded(): return get_error_data_result(message="Knowledgebase not found!") if langs: - question = cross_languages(kb.tenant_id, None, question, langs) + _question = cross_languages(kb.tenant_id, None, _question, langs) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) @@ -1010,15 +1013,15 @@ async def retrieval_test_embedded(): if req.get("keyword", False): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) - question += keyword_extraction(chat_mdl, question) + _question += keyword_extraction(chat_mdl, _question) - labels = label_question(question, [kb]) + labels = label_question(_question, [kb]) ranks = settings.retriever.retrieval( - question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, - doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels + _question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, + local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels ) if use_kg: - ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl, + ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: 
ranks["chunks"].insert(0, ck) @@ -1028,6 +1031,9 @@ async def retrieval_test_embedded(): ranks["labels"] = labels return get_json_result(data=ranks) + + try: + return await asyncio.to_thread(_retrieval_sync) except Exception as e: if str(e).find("not_found") > 0: return get_json_result(data=False, message="No chunk found! Check the chunk status please!", @@ -1064,7 +1070,7 @@ async def related_questions_embedded(): gen_conf = search_config.get("llm_setting", {"temperature": 0.9}) prompt = load_prompt("related_question") - ans = chat_mdl.chat( + ans = await chat_mdl.async_chat( prompt, [ { diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 7b7ef53ec..395dcad83 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -719,10 +719,14 @@ class DocumentService(CommonService): # only for special task and parsed docs and unfinished freeze_progress = special_task_running and doc_progress >= 1 and not finished msg = "\n".join(sorted(msg)) + begin_at = d.get("process_begin_at") + if not begin_at: + begin_at = datetime.now() + # fallback + cls.update_by_id(d["id"], {"process_begin_at": begin_at}) + info = { - "process_duration": datetime.timestamp( - datetime.now()) - - d["process_begin_at"].timestamp(), + "process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0), "run": status} if prg != 0 and not freeze_progress: info["progress"] = prg diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 1b7140a2b..e69ff1868 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1685,12 +1685,17 @@ class LiteLLMBase(ABC): yield ans, tol - async def async_chat(self, history, gen_conf, **kwargs): - logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) + async def async_chat(self, system, history, gen_conf, **kwargs): + hist = list(history) if history else [] + if system: + if not hist or hist[0].get("role") != "system": + hist.insert(0, {"role": "system", "content": system}) + + logging.info("[HISTORY]" + json.dumps(hist, ensure_ascii=False, indent=2)) if self.model_name.lower().find("qwen3") >= 0: kwargs["extra_body"] = {"enable_thinking": False} - completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf) + completion_args = self._construct_completion_args(history=hist, stream=False, tools=False, **gen_conf) for attempt in range(self.max_retries + 1): try: diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index dd33d885e..fa3f84679 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py
index dd33d885e..fa3f84679 100644
--- a/rag/prompts/generator.py
+++ b/rag/prompts/generator.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import datetime
 import json
 import logging
@@ -360,6 +361,10 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
     return kwd
 
 
+async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
+    return await asyncio.to_thread(analyze_task, chat_mdl, prompt, task_name, tools_description, user_defined_prompts)
+
+
 def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
     if not tools_description:
         return ""
@@ -378,6 +383,10 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc,
     return json_str, tk_cnt
 
 
+async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
+    return await asyncio.to_thread(next_step, chat_mdl, history, tools_description, task_desc, user_defined_prompts)
+
+
 def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
     tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
     goal = history[1]["content"]
@@ -429,6 +438,14 @@ def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[st
     return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
 
 
+async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
+    return await asyncio.to_thread(reflect, chat_mdl, history, tool_call_res, user_defined_prompts)
+
+
+async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
+    return await asyncio.to_thread(rank_memories, chat_mdl, goal, sub_goal, tool_call_summaries, user_defined_prompts)
+
+
 def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
     meta_data_structure = {}
     for key, values in meta_data.items():
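Since the *_async helpers in rag/prompts/generator.py are thin asyncio.to_thread wrappers, each should return exactly what its sync counterpart returns, and that property is cheap to verify with a stub. A minimal check in the same shape — names here are illustrative, not from the repo:

    import asyncio

    def slow_helper(x):                  # stand-in for a sync prompt helper
        return x.upper()

    async def slow_helper_async(x):
        return await asyncio.to_thread(slow_helper, x)

    def test_wrapper_matches_sync():
        assert asyncio.run(slow_helper_async("ok")) == slow_helper("ok")

    test_wrapper_matches_sync()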