Add component ExeSQL (#1966)

### What problem does this PR solve?

#1965 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
H
2024-08-16 12:36:53 +08:00
committed by GitHub
parent b4ef50bdb5
commit 644f68de97
6 changed files with 153 additions and 0 deletions

View File

@ -21,6 +21,7 @@ from .deepl import DeepL, DeepLParam
from .github import GitHub, GitHubParam from .github import GitHub, GitHubParam
from .baidufanyi import BaiduFanyi, BaiduFanyiParam from .baidufanyi import BaiduFanyi, BaiduFanyiParam
from .qweather import QWeather, QWeatherParam from .qweather import QWeather, QWeatherParam
from .exesql import ExeSQL, ExeSQLParam
def component_class(class_name): def component_class(class_name):
m = importlib.import_module("agent.component") m = importlib.import_module("agent.component")

85
agent/component/exesql.py Normal file
View File

@ -0,0 +1,85 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
import pandas as pd
from peewee import MySQLDatabase, PostgresqlDatabase
from agent.component.base import ComponentBase, ComponentParamBase
class ExeSQLParam(ComponentParamBase):
"""
Define the ExeSQL component parameters.
"""
def __init__(self):
super().__init__()
self.db_type = "mysql"
self.database = ""
self.username = ""
self.host = ""
self.port = 3306
self.password = ""
self.loop = 3
self.top_n = 30
def check(self):
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb'])
self.check_empty(self.database, "Database name")
self.check_empty(self.username, "database username")
self.check_empty(self.host, "IP Address")
self.check_positive_integer(self.port, "IP Port")
self.check_empty(self.password, "Database password")
self.check_positive_integer(self.top_n, "Number of records")
class ExeSQL(ComponentBase, ABC):
component_name = "ExeSQL"
def _run(self, history, **kwargs):
if not hasattr(self, "_loop"):
setattr(self, "_loop", 0)
if self._loop >= self._param.loop:
self._loop = 0
raise Exception("Maximum loop time exceeds. Can't query the correct data via sql statement.")
self._loop += 1
ans = self.get_input()
ans = "".join(ans["content"]) if "content" in ans else ""
if not ans:
return ExeSQL.be_output("SQL statement not found!")
if self._param.db_type in ["mysql", "mariadb"]:
db = MySQLDatabase(self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)
elif self._param.db_type == 'postgresql':
db = PostgresqlDatabase(self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)
try:
db.connect()
query = db.execute_sql(ans)
sql_res = [{"content": rec + "\n"} for rec in [str(i) for i in query.fetchall()]]
db.close()
except Exception as e:
return ExeSQL.be_output("**Error**:" + str(e))
if not sql_res:
return ExeSQL.be_output("No record in the database!")
sql_res.insert(0, {"content": "Number of records retrieved from the database is " + str(len(sql_res)) + "\n"})
df = pd.DataFrame(sql_res[0:self._param.top_n + 1])
return ExeSQL.be_output(df.to_markdown())

View File

@ -0,0 +1,43 @@
{
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["exesql:0"],
"upstream": ["begin", "exesql:0"]
},
"exesql:0": {
"obj": {
"component_name": "ExeSQL",
"params": {
"database": "rag_flow",
"username": "root",
"host": "mysql",
"port": 3306,
"password": "infini_rag_flow",
"top_n": 3
}
},
"downstream": ["answer:0"],
"upstream": ["answer:0"]
}
},
"history": [],
"messages": [],
"reference": {},
"path": [],
"answer": []
}

View File

@ -21,6 +21,7 @@ from api.db.services.canvas_service import CanvasTemplateService, UserCanvasServ
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request from api.utils.api_utils import get_json_result, server_error_response, validate_request
from agent.canvas import Canvas from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
@manager.route('/templates', methods=['GET']) @manager.route('/templates', methods=['GET'])
@ -158,3 +159,22 @@ def reset():
return get_json_result(data=req["dsl"]) return get_json_result(data=req["dsl"])
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/test_db_connect', methods=['POST'])
@validate_request("db_type", "database", "username", "host", "port", "password")
@login_required
def test_db_connect():
req = request.json
try:
if req["db_type"] in ["mysql", "mariadb"]:
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
password=req["password"])
elif req["db_type"] == 'postgresql':
db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
password=req["password"])
db.connect()
db.close()
return get_json_result(retmsg="Database Connection Successful!")
except Exception as e:
return server_error_response(str(e))

View File

@ -53,6 +53,7 @@ peewee==3.17.1
Pillow==10.3.0 Pillow==10.3.0
pipreqs==0.5.0 pipreqs==0.5.0
protobuf==5.27.2 protobuf==5.27.2
psycopg2-binary==2.9.9
pyclipper==1.3.0.post5 pyclipper==1.3.0.post5
pycryptodomex==3.20.0 pycryptodomex==3.20.0
pypdf==4.3.0 pypdf==4.3.0
@ -73,6 +74,7 @@ setuptools==70.0.0
Shapely==2.0.5 Shapely==2.0.5
six==1.16.0 six==1.16.0
StrEnum==0.4.15 StrEnum==0.4.15
tabulate==0.9.0
tika==2.6.0 tika==2.6.0
tiktoken==0.6.0 tiktoken==0.6.0
torch==2.3.0 torch==2.3.0

View File

@ -160,3 +160,5 @@ editdistance==0.8.1
markdown_to_json==2.1.1 markdown_to_json==2.1.1
scholarly==1.7.11 scholarly==1.7.11
deepl==1.18.0 deepl==1.18.0
psycopg2-binary==2.9.9
tabulate-0.9.0