Feat: add trino support (#10512)

### What problem does this PR solve?
issue:
[#10296](https://github.com/infiniflow/ragflow/issues/10296)
change:
- ExeSQL: support connecting to Trino.
- Validation: password can be empty only when db_type === "trino";
  all other database types keep the existing requirement (non-empty).

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
buua436
2025-10-13 13:57:40 +08:00
committed by GitHub
parent 65c3f0406c
commit 4e6b84bb41
4 changed files with 104 additions and 7 deletions

View File

@ -53,12 +53,13 @@ class ExeSQLParam(ToolParamBase):
self.max_records = 1024
def check(self):
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2'])
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2', 'trino'])
self.check_empty(self.database, "Database name")
self.check_empty(self.username, "database username")
self.check_empty(self.host, "IP Address")
self.check_positive_integer(self.port, "IP Port")
self.check_empty(self.password, "Database password")
if self.db_type != "trino":
self.check_empty(self.password, "Database password")
self.check_positive_integer(self.max_records, "Maximum number of records")
if self.database == "rag_flow":
if self.host == "ragflow-mysql":
@ -123,6 +124,45 @@ class ExeSQL(ToolBase, ABC):
r'PWD=' + self._param.password
)
db = pyodbc.connect(conn_str)
elif self._param.db_type == 'trino':
try:
import trino
from trino.auth import BasicAuthentication
except Exception:
raise Exception("Missing dependency 'trino'. Please install: pip install trino")
def _parse_catalog_schema(db: str):
if not db:
return None, None
if "." in db:
c, s = db.split(".", 1)
elif "/" in db:
c, s = db.split("/", 1)
else:
c, s = db, "default"
return c, s
catalog, schema = _parse_catalog_schema(self._param.database)
if not catalog:
raise Exception("For Trino, `database` must be 'catalog.schema' or at least 'catalog'.")
http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
auth = None
if http_scheme == "https" and self._param.password:
auth = BasicAuthentication(self._param.username, self._param.password)
try:
db = trino.dbapi.connect(
host=self._param.host,
port=int(self._param.port or 8080),
user=self._param.username or "ragflow",
catalog=catalog,
schema=schema or "default",
http_scheme=http_scheme,
auth=auth
)
except Exception as e:
raise Exception("Database Connection Failed! \n" + str(e))
elif self._param.db_type == 'IBM DB2':
import ibm_db
conn_str = (

View File

@ -409,6 +409,49 @@ def test_db_connect():
ibm_db.fetch_assoc(stmt)
ibm_db.close(conn)
return get_json_result(data="Database Connection Successful!")
elif req["db_type"] == 'trino':
def _parse_catalog_schema(db: str):
if not db:
return None, None
if "." in db:
c, s = db.split(".", 1)
elif "/" in db:
c, s = db.split("/", 1)
else:
c, s = db, "default"
return c, s
try:
import trino
import os
from trino.auth import BasicAuthentication
except Exception:
return server_error_response("Missing dependency 'trino'. Please install: pip install trino")
catalog, schema = _parse_catalog_schema(req["database"])
if not catalog:
return server_error_response("For Trino, 'database' must be 'catalog.schema' or at least 'catalog'.")
http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
auth = None
if http_scheme == "https" and req.get("password"):
auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
conn = trino.dbapi.connect(
host=req["host"],
port=int(req["port"] or 8080),
user=req["username"] or "ragflow",
catalog=catalog,
schema=schema or "default",
http_scheme=http_scheme,
auth=auth
)
cur = conn.cursor()
cur.execute("SELECT 1")
cur.fetchall()
cur.close()
conn.close()
return get_json_result(data="Database Connection Successful!")
else:
return server_error_response("Unsupported database type.")
if req["db_type"] != 'mssql':

View File

@ -8,14 +8,27 @@ export const ExeSQLFormSchema = {
username: z.string().min(1),
host: z.string().min(1),
port: z.number(),
password: z.string().min(1),
password: z.string().optional().or(z.literal('')),
max_records: z.number(),
};
export const FormSchema = z.object({
sql: z.string().optional(),
...ExeSQLFormSchema,
});
export const FormSchema = z
.object({
sql: z.string().optional(),
...ExeSQLFormSchema,
})
.superRefine((v, ctx) => {
if (
v.db_type !== 'trino' &&
!(v.password && v.password.trim().length > 0)
) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
path: ['password'],
message: 'String must contain at least 1 character(s)',
});
}
});
export function useSubmitForm() {
const { testDbConnect, loading } = useTestDbConnect();

View File

@ -2139,6 +2139,7 @@ export const ExeSQLOptions = [
'mariadb',
'mssql',
'IBM DB2',
'trino',
].map((x) => ({
label: upperFirst(x),
value: x,