diff --git a/agent/component/__init__.py b/agent/component/__init__.py
index ccd59f9e9..c05f531f3 100644
--- a/agent/component/__init__.py
+++ b/agent/component/__init__.py
@@ -21,6 +21,7 @@ from .deepl import DeepL, DeepLParam
 from .github import GitHub, GitHubParam
 from .baidufanyi import BaiduFanyi, BaiduFanyiParam
 from .qweather import QWeather, QWeatherParam
+from .exesql import ExeSQL, ExeSQLParam
 
 def component_class(class_name):
     m = importlib.import_module("agent.component")
diff --git a/agent/component/exesql.py b/agent/component/exesql.py
new file mode 100644
index 000000000..79d61d3ae
--- /dev/null
+++ b/agent/component/exesql.py
@@ -0,0 +1,103 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+from abc import ABC
+
+import pandas as pd
+from peewee import MySQLDatabase, PostgresqlDatabase
+from agent.component.base import ComponentBase, ComponentParamBase
+
+
+class ExeSQLParam(ComponentParamBase):
+    """
+    Define the ExeSQL component parameters.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.db_type = "mysql"
+        self.database = ""
+        self.username = ""
+        self.host = ""
+        self.port = 3306
+        self.password = ""
+        # Number of times ExeSQL._run may be entered before it raises.
+        self.loop = 3
+        # Maximum number of result rows included in the component output.
+        self.top_n = 30
+
+    def check(self):
+        """Validate the connection settings and the record limit."""
+        self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb'])
+        self.check_empty(self.database, "Database name")
+        self.check_empty(self.username, "database username")
+        self.check_empty(self.host, "IP Address")
+        self.check_positive_integer(self.port, "IP Port")
+        self.check_empty(self.password, "Database password")
+        self.check_positive_integer(self.top_n, "Number of records")
+
+
+class ExeSQL(ComponentBase, ABC):
+    """Run the SQL statement received as input and output the rows as Markdown."""
+
+    component_name = "ExeSQL"
+
+    def _run(self, history, **kwargs):
+        """Execute the SQL text returned by get_input() on the configured database.
+
+        Returns the be_output() wrapping of a Markdown table (a record-count line
+        followed by at most top_n rows), or of a plain message when there is no
+        statement, no record, or the query raises.
+        Raises Exception when _run has already been entered `loop` times.
+        """
+        if not hasattr(self, "_loop"):
+            setattr(self, "_loop", 0)
+        if self._loop >= self._param.loop:
+            self._loop = 0
+            raise Exception("Maximum loop time exceeds. Can't query the correct data via sql statement.")
+        self._loop += 1
+
+        ans = self.get_input()
+        ans = "".join(ans["content"]) if "content" in ans else ""
+        if not ans:
+            return ExeSQL.be_output("SQL statement not found!")
+
+        if self._param.db_type in ["mysql", "mariadb"]:
+            db = MySQLDatabase(self._param.database, user=self._param.username, host=self._param.host,
+                               port=self._param.port, password=self._param.password)
+        elif self._param.db_type == 'postgresql':
+            db = PostgresqlDatabase(self._param.database, user=self._param.username, host=self._param.host,
+                                    port=self._param.port, password=self._param.password)
+        else:
+            # Without this branch `db` would be unbound for an unexpected db_type.
+            return ExeSQL.be_output("**Error**:Unsupported database type: " + str(self._param.db_type))
+
+        try:
+            db.connect()
+            try:
+                query = db.execute_sql(ans)
+                sql_res = [{"content": str(rec) + "\n"} for rec in query.fetchall()]
+            finally:
+                # Close the connection even when the statement fails.
+                db.close()
+        except Exception as e:
+            return ExeSQL.be_output("**Error**:" + str(e))
+
+        if not sql_res:
+            return ExeSQL.be_output("No record in the database!")
+
+        sql_res.insert(0, {"content": "Number of records retrieved from the database is " + str(len(sql_res)) + "\n"})
+        df = pd.DataFrame(sql_res[0:self._param.top_n + 1])
+        return ExeSQL.be_output(df.to_markdown())
diff --git a/agent/test/dsl_examples/exesql.json b/agent/test/dsl_examples/exesql.json
new file mode 100644
index 000000000..7e26587c4
--- /dev/null
+++ b/agent/test/dsl_examples/exesql.json
@@ -0,0 +1,43 @@
+{
+  "components": {
+      "begin": {
+          "obj":{
+              "component_name": "Begin",
+              "params": {
+                "prologue": "Hi there!"
+              }
+          },
+          "downstream": ["answer:0"],
+          "upstream": []
+      },
+      "answer:0": {
+          "obj": {
+              "component_name": "Answer",
+              "params": {}
+          },
+          "downstream": ["exesql:0"],
+          "upstream": ["begin", "exesql:0"]
+      },
+      "exesql:0": {
+          "obj": {
+              "component_name": "ExeSQL",
+              "params": {
+                "database": "rag_flow",
+                "username": "root",
+                "host": "mysql",
+                "port": 3306,
+                "password": "infini_rag_flow",
+                "top_n": 3
+              }
+          },
+          "downstream": ["answer:0"],
+          "upstream": ["answer:0"]
+      }
+  },
+  "history": [],
+  "messages": [],
+  "reference": {},
+  "path": [],
+  "answer": []
+}
+
diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py
index 0f17b8ebf..8584df9cf 100644
--- a/api/apps/canvas_app.py
+++ b/api/apps/canvas_app.py
@@ -21,6 +21,7 @@ from api.db.services.canvas_service import CanvasTemplateService, UserCanvasServ
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result, server_error_response, validate_request
 from agent.canvas import Canvas
+from peewee import MySQLDatabase, PostgresqlDatabase
 
 
 @manager.route('/templates', methods=['GET'])
@@ -158,3 +159,27 @@ def reset():
         return get_json_result(data=req["dsl"])
     except Exception as e:
         return server_error_response(e)
+
+
+@manager.route('/test_db_connect', methods=['POST'])
+@validate_request("db_type", "database", "username", "host", "port", "password")
+@login_required
+def test_db_connect():
+    """Try to open and close a connection using the posted database settings."""
+    req = request.json
+    try:
+        if req["db_type"] in ["mysql", "mariadb"]:
+            db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
+                               password=req["password"])
+        elif req["db_type"] == 'postgresql':
+            db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
+                                    password=req["password"])
+        else:
+            # Without this branch `db` would be unbound below.
+            raise ValueError("Unsupported database type: " + str(req["db_type"]))
+        db.connect()
+        db.close()
+        return get_json_result(retmsg="Database Connection Successful!")
+    except Exception as e:
+        # Pass the exception itself, as reset() above does.
+        return server_error_response(e)
diff --git a/requirements.txt b/requirements.txt
index 30f92db2d..c8921d496 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -53,6 +53,7 @@ peewee==3.17.1
 Pillow==10.3.0
 pipreqs==0.5.0
 protobuf==5.27.2
+psycopg2-binary==2.9.9
 pyclipper==1.3.0.post5
 pycryptodomex==3.20.0
 pypdf==4.3.0
@@ -73,6 +74,7 @@ setuptools==70.0.0
 Shapely==2.0.5
 six==1.16.0
 StrEnum==0.4.15
+tabulate==0.9.0
 tika==2.6.0
 tiktoken==0.6.0
 torch==2.3.0
diff --git a/requirements_arm.txt b/requirements_arm.txt
index 58e62d51d..5e4dfc7e8 100644
--- a/requirements_arm.txt
+++ b/requirements_arm.txt
@@ -160,3 +160,5 @@ editdistance==0.8.1
 markdown_to_json==2.1.1
 scholarly==1.7.11
 deepl==1.18.0
+psycopg2-binary==2.9.9
+tabulate==0.9.0