From 906969fe4e92def8f9b0f7f9ff156f8493fc6f1a Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Tue, 9 Sep 2025 19:45:10 +0800
Subject: [PATCH] Fix: exesql issue. (#9995)

### What problem does this PR solve?

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 agent/component/agent_with_tools.py |  2 +-
 agent/tools/exesql.py               | 15 ++++++++++++++-
 api/apps/chunk_app.py               |  4 ++++
 api/apps/sdk/session.py             |  3 +++
 rag/prompts/prompts.py              | 13 ++++++++-----
 5 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py
index 4019c89ae..6b57fa120 100644
--- a/agent/component/agent_with_tools.py
+++ b/agent/component/agent_with_tools.py
@@ -166,7 +166,7 @@ class Agent(LLM, ToolBase):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         use_tools = []
         ans = ""
-        for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools):
+        for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
             ans += delta_ans
 
         if ans.find("**ERROR**") >= 0:
diff --git a/agent/tools/exesql.py b/agent/tools/exesql.py
index 48f9e3b74..317941713 100644
--- a/agent/tools/exesql.py
+++ b/agent/tools/exesql.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
 import os
 import re
 from abc import ABC
@@ -93,8 +94,20 @@ class ExeSQL(ToolBase, ABC):
         sql = kwargs.get("sql")
         if not sql:
             raise Exception("SQL for `ExeSQL` MUST not be empty.")
-        sqls = sql.split(";")
 
+        vars = self.get_input_elements_from_text(sql)
+        args = {}
+        for k, o in vars.items():
+            args[k] = o["value"]
+            if not isinstance(args[k], str):
+                try:
+                    args[k] = json.dumps(args[k], ensure_ascii=False)
+                except Exception:
+                    args[k] = str(args[k])
+            self.set_input_value(k, args[k])
+        sql = self.string_format(sql, args)
+
+        sqls = sql.split(";")
         if self._param.db_type in ["mysql", "mariadb"]:
             db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host, port=self._param.port, password=self._param.password)
diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index e790b4e2d..9b4c341b6 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -291,6 +291,10 @@ def retrieval_test():
     kb_ids = req["kb_id"]
     if isinstance(kb_ids, str):
         kb_ids = [kb_ids]
+    if not kb_ids:
+        return get_json_result(data=False, message='Please specify dataset firstly.',
+                               code=settings.RetCode.DATA_ERROR)
+
     doc_ids = req.get("doc_ids", [])
     use_kg = req.get("use_kg", False)
     top = int(req.get("top_k", 1024))
diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index 2a2f534e7..80e45a778 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -941,6 +941,9 @@ def retrieval_test_embedded():
     kb_ids = req["kb_id"]
     if isinstance(kb_ids, str):
         kb_ids = [kb_ids]
+    if not kb_ids:
+        return get_json_result(data=False, message='Please specify dataset firstly.',
+                               code=settings.RetCode.DATA_ERROR)
     doc_ids = req.get("doc_ids", [])
     similarity_threshold = float(req.get("similarity_threshold", 0.0))
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
diff --git a/rag/prompts/prompts.py b/rag/prompts/prompts.py
index 5077d51f4..b15019b6f 100644
--- a/rag/prompts/prompts.py
+++ b/rag/prompts/prompts.py
@@ -158,7 +158,7 @@ PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip
 
 
 def citation_prompt(user_defined_prompts: dict={}) -> str:
-    template = PROMPT_JINJA_ENV.from_string(CITATION_PROMPT_TEMPLATE)
+    template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE))
     return template.render()
 
 
@@ -343,9 +343,12 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
     tools_desc = tool_schema(tools_description)
     context = ""
 
-    template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_USER)
+    if user_defined_prompts.get("task_analysis"):
+        template = PROMPT_JINJA_ENV.from_string(user_defined_prompts["task_analysis"])
+    else:
+        template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
     context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
-    kwd = chat_mdl.chat(ANALYZE_TASK_SYSTEM,[{"role": "user", "content": context}], {})
+    kwd = chat_mdl.chat(context, [{"role": "user", "content": "Please analyze it."}])
     if isinstance(kwd, tuple):
         kwd = kwd[0]
     kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@@ -358,7 +361,7 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc,
     if not tools_description:
         return ""
     desc = tool_schema(tools_description)
-    template = PROMPT_JINJA_ENV.from_string(NEXT_STEP)
+    template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
     user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
     hist = deepcopy(history)
     if hist[-1]["role"] == "user":
@@ -375,7 +378,7 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc,
 def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
     tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
     goal = history[1]["content"]
-    template = PROMPT_JINJA_ENV.from_string(REFLECT)
+    template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
     user_prompt = template.render(goal=goal, tool_calls=tool_calls)
     hist = deepcopy(history)
     if hist[-1]["role"] == "user":
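
Reviewer note: the `ExeSQL` change above resolves placeholder variables in the SQL text before splitting on `;`, JSON-encoding non-string values (falling back to `str`) so the rendered SQL stays plain text. Below is a minimal, self-contained sketch of that coerce-then-substitute behaviour; `resolve_placeholders` and the `{name}` placeholder syntax are illustrative stand-ins, not the component's actual `get_input_elements_from_text`/`string_format` API.

```python
import json
import re

# Illustrative stand-in for ExeSQL's variable substitution (not the real RAGFlow API):
# each "{name}" placeholder in the SQL is replaced by its upstream value, with
# non-string values JSON-encoded (falling back to str) so the final SQL is plain text.
def resolve_placeholders(sql: str, values: dict) -> str:
    def render(value):
        if isinstance(value, str):
            return value
        try:
            return json.dumps(value, ensure_ascii=False)
        except Exception:
            return str(value)

    def substitute(match):
        key = match.group(1)
        # Leave unknown placeholders untouched, mirroring a conservative fallback.
        return render(values[key]) if key in values else match.group(0)

    return re.sub(r"\{([^{}]+)\}", substitute, sql)


if __name__ == "__main__":
    sql = "SELECT * FROM orders WHERE region = '{begin@region}' LIMIT {begin@limit}"
    print(resolve_placeholders(sql, {"begin@region": "EMEA", "begin@limit": 10}))
    # SELECT * FROM orders WHERE region = 'EMEA' LIMIT 10
```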
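
Similarly, the `rag/prompts/prompts.py` changes let callers override the built-in templates through `user_defined_prompts` keys (`citation_guidelines`, `task_analysis`, `plan_generation`, `reflection`). A small sketch of that fallback pattern, using a made-up default template rather than RAGFlow's real `NEXT_STEP` text:

```python
import jinja2

PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)

NEXT_STEP = "Decide the next tool for: {{ task }}"  # stand-in default template


def next_step_prompt(task: str, user_defined_prompts: dict = {}) -> str:
    # Same override pattern as the patch: prefer the user-supplied template,
    # fall back to the built-in default.
    template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
    return template.render(task=task)


print(next_step_prompt("summarize sales"))                                        # default NEXT_STEP
print(next_step_prompt("summarize sales", {"plan_generation": "Plan for {{ task }}!"}))  # user override
```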