diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index 3511845ef..40c47d9b5 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -165,7 +165,7 @@ class Agent(LLM, ToolBase): _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) use_tools = [] ans = "" - for delta_ans, tk in self._react_with_tools_streamly(msg, use_tools): + for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools): ans += delta_ans if ans.find("**ERROR**") >= 0: @@ -185,7 +185,7 @@ class Agent(LLM, ToolBase): _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) answer_without_toolcall = "" use_tools = [] - for delta_ans,_ in self._react_with_tools_streamly(msg, use_tools): + for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools): if delta_ans.find("**ERROR**") >= 0: if self.get_exception_default_value(): self.set_output("content", self.get_exception_default_value()) @@ -208,7 +208,7 @@ class Agent(LLM, ToolBase): ]): yield delta_ans - def _react_with_tools_streamly(self, history: list[dict], use_tools): + def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools): token_count = 0 tool_metas = self.tool_meta hist = deepcopy(history) @@ -221,7 +221,7 @@ class Agent(LLM, ToolBase): def use_tool(name, args): nonlocal hist, use_tools, token_count,last_calling,user_request - print(f"{last_calling=} == {name=}", ) + logging.info(f"{last_calling=} == {name=}") # Summarize of function calling #if all([ # isinstance(self.toolcall_session.get_tool_obj(name), Agent), @@ -275,7 +275,7 @@ class Agent(LLM, ToolBase): else: hist.append({"role": "user", "content": content}) - task_desc = analyze_task(self.chat_mdl, user_request, tool_metas) + task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas) self.callback("analyze_task", {}, task_desc) for _ in range(self._param.max_rounds + 1): response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc) diff --git a/rag/prompts/analyze_task_user.md b/rag/prompts/analyze_task_user.md index 65ddba7a8..499cad575 100644 --- a/rag/prompts/analyze_task_user.md +++ b/rag/prompts/analyze_task_user.md @@ -4,6 +4,9 @@ Task: {{ task }} Context: {{ context }} +**Agent Prompt** +{{ agent_prompt }} + **Analysis Requirements:** 1. Is it just a small talk? (If yes, no further plan or analysis is needed) 2. What is the core objective of the task? diff --git a/rag/prompts/prompts.py b/rag/prompts/prompts.py index aeeb6cb91..75c9369b8 100644 --- a/rag/prompts/prompts.py +++ b/rag/prompts/prompts.py @@ -335,13 +335,13 @@ def form_history(history, limit=-6): return context -def analyze_task(chat_mdl, task_name, tools_description: list[dict]): +def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict]): tools_desc = tool_schema(tools_description) context = "" template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_USER) - - kwd = chat_mdl.chat(ANALYZE_TASK_SYSTEM,[{"role": "user", "content": template.render(task=task_name, context=context, tools_desc=tools_desc)}], {}) + context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc) + kwd = chat_mdl.chat(ANALYZE_TASK_SYSTEM,[{"role": "user", "content": context}], {}) if isinstance(kwd, tuple): kwd = kwd[0] kwd = re.sub(r"^.*", "", kwd, flags=re.DOTALL)