From 4e6b84bb41895ab7c8514bf54a18698dbec29e44 Mon Sep 17 00:00:00 2001 From: buua436 <66937541+buua436@users.noreply.github.com> Date: Mon, 13 Oct 2025 13:57:40 +0800 Subject: [PATCH] Feat: add trino support (#10512) ### What problem does this PR solve? issue: [#10296](https://github.com/infiniflow/ragflow/issues/10296) change: - ExeSQL: support connecting to Trino. - Validation: password can be empty only when db_type === "trino"; all other database types keep the existing requirement (non-empty). ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- agent/tools/exesql.py | 44 ++++++++++++++++++- api/apps/canvas_app.py | 43 ++++++++++++++++++ .../agent/form/exesql-form/use-submit-form.ts | 23 +++++++--- web/src/pages/agent/options.ts | 1 + 4 files changed, 104 insertions(+), 7 deletions(-) diff --git a/agent/tools/exesql.py b/agent/tools/exesql.py index 2e1cc24bf..d93745323 100644 --- a/agent/tools/exesql.py +++ b/agent/tools/exesql.py @@ -53,12 +53,13 @@ class ExeSQLParam(ToolParamBase): self.max_records = 1024 def check(self): - self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2']) + self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2', 'trino']) self.check_empty(self.database, "Database name") self.check_empty(self.username, "database username") self.check_empty(self.host, "IP Address") self.check_positive_integer(self.port, "IP Port") - self.check_empty(self.password, "Database password") + if self.db_type != "trino": + self.check_empty(self.password, "Database password") self.check_positive_integer(self.max_records, "Maximum number of records") if self.database == "rag_flow": if self.host == "ragflow-mysql": @@ -123,6 +124,45 @@ class ExeSQL(ToolBase, ABC): r'PWD=' + self._param.password ) db = pyodbc.connect(conn_str) + elif self._param.db_type == 'trino': + try: + import trino + from trino.auth import BasicAuthentication + except Exception: + raise Exception("Missing dependency 'trino'. Please install: pip install trino") + + def _parse_catalog_schema(db: str): + if not db: + return None, None + if "." in db: + c, s = db.split(".", 1) + elif "/" in db: + c, s = db.split("/", 1) + else: + c, s = db, "default" + return c, s + + catalog, schema = _parse_catalog_schema(self._param.database) + if not catalog: + raise Exception("For Trino, `database` must be 'catalog.schema' or at least 'catalog'.") + + http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http" + auth = None + if http_scheme == "https" and self._param.password: + auth = BasicAuthentication(self._param.username, self._param.password) + + try: + db = trino.dbapi.connect( + host=self._param.host, + port=int(self._param.port or 8080), + user=self._param.username or "ragflow", + catalog=catalog, + schema=schema or "default", + http_scheme=http_scheme, + auth=auth + ) + except Exception as e: + raise Exception("Database Connection Failed! \n" + str(e)) elif self._param.db_type == 'IBM DB2': import ibm_db conn_str = ( diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index c3d4dd824..7cdeff097 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -409,6 +409,49 @@ def test_db_connect(): ibm_db.fetch_assoc(stmt) ibm_db.close(conn) return get_json_result(data="Database Connection Successful!") + elif req["db_type"] == 'trino': + def _parse_catalog_schema(db: str): + if not db: + return None, None + if "." in db: + c, s = db.split(".", 1) + elif "/" in db: + c, s = db.split("/", 1) + else: + c, s = db, "default" + return c, s + try: + import trino + import os + from trino.auth import BasicAuthentication + except Exception: + return server_error_response("Missing dependency 'trino'. Please install: pip install trino") + + catalog, schema = _parse_catalog_schema(req["database"]) + if not catalog: + return server_error_response("For Trino, 'database' must be 'catalog.schema' or at least 'catalog'.") + + http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http" + + auth = None + if http_scheme == "https" and req.get("password"): + auth = BasicAuthentication(req.get("username") or "ragflow", req["password"]) + + conn = trino.dbapi.connect( + host=req["host"], + port=int(req["port"] or 8080), + user=req["username"] or "ragflow", + catalog=catalog, + schema=schema or "default", + http_scheme=http_scheme, + auth=auth + ) + cur = conn.cursor() + cur.execute("SELECT 1") + cur.fetchall() + cur.close() + conn.close() + return get_json_result(data="Database Connection Successful!") else: return server_error_response("Unsupported database type.") if req["db_type"] != 'mssql': diff --git a/web/src/pages/agent/form/exesql-form/use-submit-form.ts b/web/src/pages/agent/form/exesql-form/use-submit-form.ts index 8be69c7b0..b7141a2cf 100644 --- a/web/src/pages/agent/form/exesql-form/use-submit-form.ts +++ b/web/src/pages/agent/form/exesql-form/use-submit-form.ts @@ -8,14 +8,27 @@ export const ExeSQLFormSchema = { username: z.string().min(1), host: z.string().min(1), port: z.number(), - password: z.string().min(1), + password: z.string().optional().or(z.literal('')), max_records: z.number(), }; -export const FormSchema = z.object({ - sql: z.string().optional(), - ...ExeSQLFormSchema, -}); +export const FormSchema = z + .object({ + sql: z.string().optional(), + ...ExeSQLFormSchema, + }) + .superRefine((v, ctx) => { + if ( + v.db_type !== 'trino' && + !(v.password && v.password.trim().length > 0) + ) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + path: ['password'], + message: 'String must contain at least 1 character(s)', + }); + } + }); export function useSubmitForm() { const { testDbConnect, loading } = useTestDbConnect(); diff --git a/web/src/pages/agent/options.ts b/web/src/pages/agent/options.ts index f231ec450..9d68ec70b 100644 --- a/web/src/pages/agent/options.ts +++ b/web/src/pages/agent/options.ts @@ -2139,6 +2139,7 @@ export const ExeSQLOptions = [ 'mariadb', 'mssql', 'IBM DB2', + 'trino', ].map((x) => ({ label: upperFirst(x), value: x,