From 564277736aa26125380b4524b73cbf29cd6c0c79 Mon Sep 17 00:00:00 2001 From: liuhua <10215101452@stu.ecnu.edu.cn> Date: Tue, 31 Dec 2024 19:58:56 +0800 Subject: [PATCH] Update exesql component for agent (#4307) ### What problem does this PR solve? Update exesql component for agent ### Type of change - [x] Refactoring --------- Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn> --- agent/component/exesql.py | 88 ++++++++++++++++++++++++--------------- 1 file changed, 55 insertions(+), 33 deletions(-) diff --git a/agent/component/exesql.py b/agent/component/exesql.py index 253dac5af..dd210ad36 100644 --- a/agent/component/exesql.py +++ b/agent/component/exesql.py @@ -18,11 +18,11 @@ import re import pandas as pd import pymysql import psycopg2 -from agent.component.base import ComponentBase, ComponentParamBase +from agent.component import GenerateParam, Generate import pyodbc import logging -class ExeSQLParam(ComponentParamBase): +class ExeSQLParam(GenerateParam): """ Define the ExeSQL component parameters. """ @@ -39,6 +39,7 @@ class ExeSQLParam(ComponentParamBase): self.top_n = 30 def check(self): + super().check() self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb', 'mssql']) self.check_empty(self.database, "Database name") self.check_empty(self.username, "database username") @@ -53,25 +54,14 @@ class ExeSQLParam(ComponentParamBase): raise ValueError("The host is not accessible.") -class ExeSQL(ComponentBase, ABC): +class ExeSQL(Generate, ABC): component_name = "ExeSQL" - def _run(self, history, **kwargs): - if not hasattr(self, "_loop"): - setattr(self, "_loop", 0) - if self._loop >= self._param.loop: - self._loop = 0 - raise Exception("Maximum loop time exceeds. Can't query the correct data via SQL statement.") - self._loop += 1 - - ans = self.get_input() - ans = "".join([str(a) for a in ans["content"]]) if "content" in ans else "" - - # improve the information extraction, most llm return results in markdown format ```sql query ``` + def _refactor(self,ans): match = re.search(r"```sql\s*(.*?)\s*```", ans, re.DOTALL) if match: ans = match.group(1) # Query content - print(ans) + return ans else: print("no markdown") ans = re.sub(r'^.*?SELECT ', 'SELECT ', (ans), flags=re.IGNORECASE) @@ -79,7 +69,12 @@ class ExeSQL(ComponentBase, ABC): ans = re.sub(r';[^;]*$', r';', ans) if not ans: raise Exception("SQL statement not found!") + return ans + def _run(self, history, **kwargs): + ans = self.get_input() + ans = "".join([str(a) for a in ans["content"]]) if "content" in ans else "" + ans = self._refactor(ans) logging.info("db_type: ",self._param.db_type) if self._param.db_type in ["mysql", "mariadb"]: db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host, @@ -100,25 +95,52 @@ class ExeSQL(ComponentBase, ABC): cursor = db.cursor() except Exception as e: raise Exception("Database Connection Failed! \n" + str(e)) + if not hasattr(self, "_loop"): + setattr(self, "_loop", 0) + self._loop += 1 + input_list=re.split(r';', ans.replace(r"\n", " ")) sql_res = [] - for single_sql in re.split(r';', ans.replace(r"\n", " ")): - if not single_sql: - continue - try: - logging.info("single_sql: ",single_sql) - cursor.execute(single_sql) - if cursor.rowcount == 0: - sql_res.append({"content": "\nTotal: 0\n No record in the database!"}) - continue - single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.top_n)]) - single_res.columns = [i[0] for i in cursor.description] - sql_res.append({"content": "\nTotal: " + str(cursor.rowcount) + "\n" + single_res.to_markdown()}) - except Exception as e: - sql_res.append({"content": "**Error**:" + str(e) + "\nError SQL Statement:" + single_sql}) - pass + for i in range(len(input_list)): + single_sql=input_list[i] + while self._loop <= self._param.loop: + self._loop+=1 + if not single_sql: + break + try: + logging.info("single_sql: ", single_sql) + cursor.execute(single_sql) + if cursor.rowcount == 0: + sql_res.append({"content": "\nTotal: 0\n No record in the database!"}) + break + single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.top_n)]) + single_res.columns = [i[0] for i in cursor.description] + sql_res.append({"content": "\nTotal: " + str(cursor.rowcount) + "\n" + single_res.to_markdown()}) + break + except Exception as e: + single_sql = self._regenerate_sql(single_sql, str(e), **kwargs) + single_sql = self._refactor(single_sql) + if self._loop > self._param.loop: + raise Exception("Maximum loop time exceeds. Can't query the correct data via SQL statement.") db.close() - if not sql_res: return ExeSQL.be_output("") - return pd.DataFrame(sql_res) + + def _regenerate_sql(self, failed_sql, error_message,**kwargs): + prompt = f''' + ## You are the Repair SQL Statement Helper, please modify the original SQL statement based on the SQL query error report. + ## The original SQL statement is as follows:{failed_sql}. + ## The contents of the SQL query error report is as follows:{error_message}. + ## Answer only the modified SQL statement. Please do not give any explanation, just answer the code. +''' + self._param.prompt=prompt + response = Generate._run(self, [], **kwargs) + try: + regenerated_sql = response.loc[0,"content"] + return regenerated_sql + except Exception as e: + logging.error(f"Failed to regenerate SQL: {e}") + return None + + def debug(self, **kwargs): + return self._run([], **kwargs) \ No newline at end of file