Fix: exesql issue. (#9995)

### What problem does this PR solve?

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu
2025-09-09 19:45:10 +08:00
committed by GitHub
parent 776ea078a6
commit 906969fe4e
5 changed files with 30 additions and 7 deletions

View File

@ -166,7 +166,7 @@ class Agent(LLM, ToolBase):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
use_tools = []
ans = ""
for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools):
for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
ans += delta_ans
if ans.find("**ERROR**") >= 0:

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import re
from abc import ABC
@ -93,8 +94,20 @@ class ExeSQL(ToolBase, ABC):
sql = kwargs.get("sql")
if not sql:
raise Exception("SQL for `ExeSQL` MUST not be empty.")
sqls = sql.split(";")
vars = self.get_input_elements_from_text(sql)
args = {}
for k, o in vars.items():
args[k] = o["value"]
if not isinstance(args[k], str):
try:
args[k] = json.dumps(args[k], ensure_ascii=False)
except Exception:
args[k] = str(args[k])
self.set_input_value(k, args[k])
sql = self.string_format(sql, args)
sqls = sql.split(";")
if self._param.db_type in ["mysql", "mariadb"]:
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)

View File

@ -291,6 +291,10 @@ def retrieval_test():
kb_ids = req["kb_id"]
if isinstance(kb_ids, str):
kb_ids = [kb_ids]
if not kb_ids:
return get_json_result(data=False, message='Please specify dataset firstly.',
code=settings.RetCode.DATA_ERROR)
doc_ids = req.get("doc_ids", [])
use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024))

View File

@ -941,6 +941,9 @@ def retrieval_test_embedded():
kb_ids = req["kb_id"]
if isinstance(kb_ids, str):
kb_ids = [kb_ids]
if not kb_ids:
return get_json_result(data=False, message='Please specify dataset firstly.',
code=settings.RetCode.DATA_ERROR)
doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.0))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))

View File

@ -158,7 +158,7 @@ PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip
def citation_prompt(user_defined_prompts: dict={}) -> str:
template = PROMPT_JINJA_ENV.from_string(CITATION_PROMPT_TEMPLATE)
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE))
return template.render()
@ -343,9 +343,12 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
tools_desc = tool_schema(tools_description)
context = ""
template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_USER)
if user_defined_prompts.get("task_analysis"):
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts["task_analysis"])
else:
template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
kwd = chat_mdl.chat(ANALYZE_TASK_SYSTEM,[{"role": "user", "content": context}], {})
kwd = chat_mdl.chat(context, [{"role": "user", "content": "Please analyze it."}])
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@ -358,7 +361,7 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc,
if not tools_description:
return ""
desc = tool_schema(tools_description)
template = PROMPT_JINJA_ENV.from_string(NEXT_STEP)
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
hist = deepcopy(history)
if hist[-1]["role"] == "user":
@ -375,7 +378,7 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc,
def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
goal = history[1]["content"]
template = PROMPT_JINJA_ENV.from_string(REFLECT)
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
user_prompt = template.render(goal=goal, tool_calls=tool_calls)
hist = deepcopy(history)
if hist[-1]["role"] == "user":