Update exesql component for agent (#4307)

### What problem does this PR solve?

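Update the ExeSQL agent component: `ExeSQL` now extends `Generate` (and `ExeSQLParam` extends `GenerateParam`), the SQL-extraction logic is moved into a `_refactor` helper, and a SQL statement that fails to execute is sent back to the LLM for repair and retried until the configured loop limit is reached.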

### Type of change

- [x] Refactoring

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
This commit is contained in:
liuhua
2024-12-31 19:58:56 +08:00
committed by GitHub
parent 061a22588a
commit 564277736a

View File

@ -18,11 +18,11 @@ import re
import pandas as pd import pandas as pd
import pymysql import pymysql
import psycopg2 import psycopg2
from agent.component.base import ComponentBase, ComponentParamBase from agent.component import GenerateParam, Generate
import pyodbc import pyodbc
import logging import logging
class ExeSQLParam(ComponentParamBase): class ExeSQLParam(GenerateParam):
""" """
Define the ExeSQL component parameters. Define the ExeSQL component parameters.
""" """
@ -39,6 +39,7 @@ class ExeSQLParam(ComponentParamBase):
self.top_n = 30 self.top_n = 30
def check(self): def check(self):
super().check()
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb', 'mssql']) self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb', 'mssql'])
self.check_empty(self.database, "Database name") self.check_empty(self.database, "Database name")
self.check_empty(self.username, "database username") self.check_empty(self.username, "database username")
@ -53,25 +54,14 @@ class ExeSQLParam(ComponentParamBase):
raise ValueError("The host is not accessible.") raise ValueError("The host is not accessible.")
class ExeSQL(ComponentBase, ABC): class ExeSQL(Generate, ABC):
component_name = "ExeSQL" component_name = "ExeSQL"
def _run(self, history, **kwargs): def _refactor(self,ans):
if not hasattr(self, "_loop"):
setattr(self, "_loop", 0)
if self._loop >= self._param.loop:
self._loop = 0
raise Exception("Maximum loop time exceeds. Can't query the correct data via SQL statement.")
self._loop += 1
ans = self.get_input()
ans = "".join([str(a) for a in ans["content"]]) if "content" in ans else ""
# improve the information extraction, most llm return results in markdown format ```sql query ```
match = re.search(r"```sql\s*(.*?)\s*```", ans, re.DOTALL) match = re.search(r"```sql\s*(.*?)\s*```", ans, re.DOTALL)
if match: if match:
ans = match.group(1) # Query content ans = match.group(1) # Query content
print(ans) return ans
else: else:
print("no markdown") print("no markdown")
ans = re.sub(r'^.*?SELECT ', 'SELECT ', (ans), flags=re.IGNORECASE) ans = re.sub(r'^.*?SELECT ', 'SELECT ', (ans), flags=re.IGNORECASE)
@ -79,7 +69,12 @@ class ExeSQL(ComponentBase, ABC):
ans = re.sub(r';[^;]*$', r';', ans) ans = re.sub(r';[^;]*$', r';', ans)
if not ans: if not ans:
raise Exception("SQL statement not found!") raise Exception("SQL statement not found!")
return ans
def _run(self, history, **kwargs):
ans = self.get_input()
ans = "".join([str(a) for a in ans["content"]]) if "content" in ans else ""
ans = self._refactor(ans)
logging.info("db_type: ",self._param.db_type) logging.info("db_type: ",self._param.db_type)
if self._param.db_type in ["mysql", "mariadb"]: if self._param.db_type in ["mysql", "mariadb"]:
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host, db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
@ -100,25 +95,52 @@ class ExeSQL(ComponentBase, ABC):
cursor = db.cursor() cursor = db.cursor()
except Exception as e: except Exception as e:
raise Exception("Database Connection Failed! \n" + str(e)) raise Exception("Database Connection Failed! \n" + str(e))
if not hasattr(self, "_loop"):
setattr(self, "_loop", 0)
self._loop += 1
input_list=re.split(r';', ans.replace(r"\n", " "))
sql_res = [] sql_res = []
for single_sql in re.split(r';', ans.replace(r"\n", " ")): for i in range(len(input_list)):
if not single_sql: single_sql=input_list[i]
continue while self._loop <= self._param.loop:
try: self._loop+=1
logging.info("single_sql: ",single_sql) if not single_sql:
cursor.execute(single_sql) break
if cursor.rowcount == 0: try:
sql_res.append({"content": "\nTotal: 0\n No record in the database!"}) logging.info("single_sql: ", single_sql)
continue cursor.execute(single_sql)
single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.top_n)]) if cursor.rowcount == 0:
single_res.columns = [i[0] for i in cursor.description] sql_res.append({"content": "\nTotal: 0\n No record in the database!"})
sql_res.append({"content": "\nTotal: " + str(cursor.rowcount) + "\n" + single_res.to_markdown()}) break
except Exception as e: single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.top_n)])
sql_res.append({"content": "**Error**:" + str(e) + "\nError SQL Statement:" + single_sql}) single_res.columns = [i[0] for i in cursor.description]
pass sql_res.append({"content": "\nTotal: " + str(cursor.rowcount) + "\n" + single_res.to_markdown()})
break
except Exception as e:
single_sql = self._regenerate_sql(single_sql, str(e), **kwargs)
single_sql = self._refactor(single_sql)
if self._loop > self._param.loop:
raise Exception("Maximum loop time exceeds. Can't query the correct data via SQL statement.")
db.close() db.close()
if not sql_res: if not sql_res:
return ExeSQL.be_output("") return ExeSQL.be_output("")
return pd.DataFrame(sql_res) return pd.DataFrame(sql_res)
def _regenerate_sql(self, failed_sql, error_message,**kwargs):
    """Ask the LLM to repair a SQL statement that failed to execute.

    A repair prompt is built from the failing statement and the database
    error report, installed on ``self._param.prompt`` for the duration of
    the ``Generate._run`` call, and the previous prompt is restored
    afterwards so the repair prompt does not leak into later runs.

    Args:
        failed_sql: The SQL statement that raised an error.
        error_message: The text of the error reported for that statement.
        **kwargs: Forwarded unchanged to ``Generate._run``.

    Returns:
        The value at ``response.loc[0, "content"]`` of the generation
        result, or ``None`` when that value cannot be read from the result.
    """
    prompt = f'''
## You are the Repair SQL Statement Helper, please modify the original SQL statement based on the SQL query error report.
## The original SQL statement is as follows:{failed_sql}.
## The contents of the SQL query error report is as follows:{error_message}.
## Answer only the modified SQL statement. Please do not give any explanation, just answer the code.
'''
    # Remember the configured prompt: the repair prompt is only meant for
    # this single generation call (presumably Generate._run reads
    # self._param.prompt -- confirm against the Generate implementation).
    saved_prompt = getattr(self._param, "prompt", "")
    self._param.prompt = prompt
    try:
        response = Generate._run(self, [], **kwargs)
    finally:
        # Restore even when generation raises, so the component is never
        # left holding the repair prompt.
        self._param.prompt = saved_prompt
    try:
        return response.loc[0, "content"]
    except Exception as e:
        # The result did not have the expected row/column; report it and
        # let the caller decide what to do with the missing statement.
        logging.error("Failed to regenerate SQL: %s", e)
        return None
def debug(self, **kwargs):
    """Run the component once with an empty history.

    All keyword arguments are forwarded unchanged to ``_run`` and its
    result is returned as-is.
    """
    empty_history = []
    result = self._run(empty_history, **kwargs)
    return result