diff --git a/agent/canvas.py b/agent/canvas.py index 79cdcd2cb..986b4f9c9 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -484,7 +484,7 @@ class Canvas: threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"])) return [th.result() for th in threads] - def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any): + def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None): agent_ids = agent_id.split("-->") agent_name = self.get_component_name(agent_ids[0]) path = agent_name if len(agent_ids) < 2 else agent_name+"-->"+"-->".join(agent_ids[1:]) @@ -493,16 +493,16 @@ class Canvas: if bin: obj = json.loads(bin.encode("utf-8")) if obj[-1]["component_id"] == agent_ids[0]: - obj[-1]["trace"].append({"path": path, "tool_name": func_name, "arguments": params, "result": result}) + obj[-1]["trace"].append({"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}) else: obj.append({ "component_id": agent_ids[0], - "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result}] + "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}] }) else: obj = [{ "component_id": agent_ids[0], - "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result}] + "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}] }] REDIS_CONN.set_obj(f"{self.task_id}-{self.message_id}-logs", obj, 60*10) except Exception as e: diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index f0369c0a6..52a6900ec 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -22,7 +22,7 @@ from functools import partial from typing import Any import json_repair - +from 
timeit import default_timer as timer from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta from api.db.services.llm_service import LLMBundle from api.db.services.tenant_llm_service import TenantLLMService @@ -215,8 +215,9 @@ class Agent(LLM, ToolBase): hist = deepcopy(history) last_calling = "" if len(hist) > 3: + st = timer() user_request = full_question(messages=history, chat_mdl=self.chat_mdl) - self.callback("Multi-turn conversation optimization", {}, user_request) + self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st) else: user_request = history[-1]["content"] @@ -263,12 +264,13 @@ class Agent(LLM, ToolBase): if not need2cite or cited: return + st = timer() txt = "" for delta_ans in self._gen_citations(entire_txt): yield delta_ans, 0 txt += delta_ans - self.callback("gen_citations", {}, txt) + self.callback("gen_citations", {}, txt, elapsed_time=timer()-st) def append_user_content(hist, content): if hist[-1]["role"] == "user": @@ -276,8 +278,9 @@ class Agent(LLM, ToolBase): else: hist.append({"role": "user", "content": content}) + st = timer() task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas) - self.callback("analyze_task", {}, task_desc) + self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st) for _ in range(self._param.max_rounds + 1): response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc) # self.callback("next_step", {}, str(response)[:256]+"...") @@ -303,9 +306,10 @@ class Agent(LLM, ToolBase): thr.append(executor.submit(use_tool, name, args)) + st = timer() reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr]) append_user_content(hist, reflection) - self.callback("reflection", {}, str(reflection)) + self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st) except Exception as e: logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}") diff --git a/agent/tools/base.py 
b/agent/tools/base.py
index 5dbe26da3..8e6b78dd0 100644
--- a/agent/tools/base.py
+++ b/agent/tools/base.py
@@ -24,6 +24,7 @@ from api.utils import hash_str2int
 from rag.llm.chat_model import ToolCallSession
 from rag.prompts.prompts import kb_prompt
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession
+from timeit import default_timer as timer


 class ToolParameter(TypedDict):
@@ -49,12 +50,13 @@ class LLMToolPluginCallSession(ToolCallSession):

     def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
         assert name in self.tools_map, f"LLM tool {name} does not exist"
+        st = timer()
         if isinstance(self.tools_map[name], MCPToolCallSession):
             resp = self.tools_map[name].tool_call(name, arguments, 60)
         else:
             resp = self.tools_map[name].invoke(**arguments)

-        self.callback(name, arguments, resp)
+        self.callback(name, arguments, resp, elapsed_time=timer()-st)
         return resp

     def get_tool_obj(self, name):
diff --git a/agent/tools/exesql.py b/agent/tools/exesql.py
index ca03dfc21..48f9e3b74 100644
--- a/agent/tools/exesql.py
+++ b/agent/tools/exesql.py
@@ -79,6 +79,17 @@ class ExeSQL(ToolBase, ABC):

     @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
     def _invoke(self, **kwargs):
+
+        def convert_decimals(obj):
+            from decimal import Decimal
+            if isinstance(obj, Decimal):
+                return float(obj)  # or str(obj)
+            elif isinstance(obj, dict):
+                return {k: convert_decimals(v) for k, v in obj.items()}
+            elif isinstance(obj, list):
+                return [convert_decimals(item) for item in obj]
+            return obj
+
         sql = kwargs.get("sql")
         if not sql:
             raise Exception("SQL for `ExeSQL` MUST not be empty.")
@@ -122,7 +133,11 @@ class ExeSQL(ToolBase, ABC):
             single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.max_records)])
             single_res.columns = [i[0] for i in cursor.description]
-            sql_res.append(single_res.to_dict(orient='records'))
+            for col in single_res.columns:
+                if pd.api.types.is_datetime64_any_dtype(single_res[col]):
+                    single_res[col] = single_res[col].dt.strftime('%Y-%m-%d')
+
+
sql_res.append(convert_decimals(single_res.to_dict(orient='records'))) formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f")) self.set_output("json", sql_res) @@ -130,4 +145,4 @@ class ExeSQL(ToolBase, ABC): return self.output("formalized_content") def thoughts(self) -> str: - return "Query sent—waiting for the data." \ No newline at end of file + return "Query sent—waiting for the data." diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 97e508960..9f41ed03e 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -40,7 +40,7 @@ from rag.app.resume import forbidden_select_fields4resume from rag.app.tag import label_question from rag.nlp.search import index_name from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in -from rag.prompts.prompts import gen_meta_filter +from rag.prompts.prompts import gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY from rag.utils import num_tokens_from_string, rmSpace from rag.utils.tavily_conn import Tavily @@ -723,6 +723,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id) max_tokens = chat_mdl.max_length tenant_ids = list(set([kb.tenant_id for kb in kbs])) + kbinfos = retriever.retrieval( question = question, embd_mdl=embd_mdl, @@ -740,26 +741,12 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): ) knowledges = kb_prompt(kbinfos, max_tokens) - prompt = """ - Role: You're a smart assistant. Your name is Miss R. - Task: Summarize the information from knowledge bases and answer user's question. - Requirements and restriction: - - DO NOT make things up, especially for numbers. - - If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided. - - Answer with markdown format text. 
- - Answer in language of user's question. - - DO NOT make things up, especially for numbers. + sys_prompt = PROMPT_JINJA_ENV.from_string(ASK_SUMMARY).render(knowledge="\n".join(knowledges)) - ### Information from knowledge bases - %s - - The above is information from knowledge bases. - - """ % "\n".join(knowledges) msg = [{"role": "user", "content": question}] def decorate_answer(answer): - nonlocal knowledges, kbinfos, prompt + nonlocal knowledges, kbinfos, sys_prompt answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] @@ -777,7 +764,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): return {"answer": answer, "reference": refs} answer = "" - for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}): + for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}): answer = ans yield {"answer": answer, "reference": {}} yield decorate_answer(answer) diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index ddfc387bb..e76624113 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -611,10 +611,6 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。 if re.match(f"^{dels}$", sub_sec): continue add_chunk(sub_sec, image) - - for img in images: - if isinstance(img, Image.Image): - img.close() return cks, result_images diff --git a/rag/prompts/ask_summary.md b/rag/prompts/ask_summary.md new file mode 100644 index 000000000..2074e9c3a --- /dev/null +++ b/rag/prompts/ask_summary.md @@ -0,0 +1,14 @@ +Role: You're a smart assistant. Your name is Miss R. +Task: Summarize the information from knowledge bases and answer user's question. +Requirements and restriction: + - DO NOT make things up, especially for numbers. 
+ - If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided. + - Answer with markdown format text. + - Answer in language of user's question. + - DO NOT make things up, especially for numbers. + +### Information from knowledge bases + +{{ knowledge }} + +The above is information from knowledge bases. diff --git a/rag/prompts/prompts.py b/rag/prompts/prompts.py index c49c92ea9..1e6255b35 100644 --- a/rag/prompts/prompts.py +++ b/rag/prompts/prompts.py @@ -150,6 +150,7 @@ REFLECT = load_prompt("reflect") SUMMARY4MEMORY = load_prompt("summary4memory") RANK_MEMORY = load_prompt("rank_memory") META_FILTER = load_prompt("meta_filter") +ASK_SUMMARY = load_prompt("ask_summary") PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)