mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Feat: add trino support (#10512)
### What problem does this PR solve? issue: [#10296](https://github.com/infiniflow/ragflow/issues/10296) change: - ExeSQL: support connecting to Trino. - Validation: password can be empty only when db_type === "trino"; all other database types keep the existing requirement (non-empty). ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -53,12 +53,13 @@ class ExeSQLParam(ToolParamBase):
|
||||
self.max_records = 1024
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2'])
|
||||
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2', 'trino'])
|
||||
self.check_empty(self.database, "Database name")
|
||||
self.check_empty(self.username, "database username")
|
||||
self.check_empty(self.host, "IP Address")
|
||||
self.check_positive_integer(self.port, "IP Port")
|
||||
self.check_empty(self.password, "Database password")
|
||||
if self.db_type != "trino":
|
||||
self.check_empty(self.password, "Database password")
|
||||
self.check_positive_integer(self.max_records, "Maximum number of records")
|
||||
if self.database == "rag_flow":
|
||||
if self.host == "ragflow-mysql":
|
||||
@ -123,6 +124,45 @@ class ExeSQL(ToolBase, ABC):
|
||||
r'PWD=' + self._param.password
|
||||
)
|
||||
db = pyodbc.connect(conn_str)
|
||||
elif self._param.db_type == 'trino':
|
||||
try:
|
||||
import trino
|
||||
from trino.auth import BasicAuthentication
|
||||
except Exception:
|
||||
raise Exception("Missing dependency 'trino'. Please install: pip install trino")
|
||||
|
||||
def _parse_catalog_schema(db: str):
|
||||
if not db:
|
||||
return None, None
|
||||
if "." in db:
|
||||
c, s = db.split(".", 1)
|
||||
elif "/" in db:
|
||||
c, s = db.split("/", 1)
|
||||
else:
|
||||
c, s = db, "default"
|
||||
return c, s
|
||||
|
||||
catalog, schema = _parse_catalog_schema(self._param.database)
|
||||
if not catalog:
|
||||
raise Exception("For Trino, `database` must be 'catalog.schema' or at least 'catalog'.")
|
||||
|
||||
http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
|
||||
auth = None
|
||||
if http_scheme == "https" and self._param.password:
|
||||
auth = BasicAuthentication(self._param.username, self._param.password)
|
||||
|
||||
try:
|
||||
db = trino.dbapi.connect(
|
||||
host=self._param.host,
|
||||
port=int(self._param.port or 8080),
|
||||
user=self._param.username or "ragflow",
|
||||
catalog=catalog,
|
||||
schema=schema or "default",
|
||||
http_scheme=http_scheme,
|
||||
auth=auth
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception("Database Connection Failed! \n" + str(e))
|
||||
elif self._param.db_type == 'IBM DB2':
|
||||
import ibm_db
|
||||
conn_str = (
|
||||
|
||||
@ -409,6 +409,49 @@ def test_db_connect():
|
||||
ibm_db.fetch_assoc(stmt)
|
||||
ibm_db.close(conn)
|
||||
return get_json_result(data="Database Connection Successful!")
|
||||
elif req["db_type"] == 'trino':
|
||||
def _parse_catalog_schema(db: str):
|
||||
if not db:
|
||||
return None, None
|
||||
if "." in db:
|
||||
c, s = db.split(".", 1)
|
||||
elif "/" in db:
|
||||
c, s = db.split("/", 1)
|
||||
else:
|
||||
c, s = db, "default"
|
||||
return c, s
|
||||
try:
|
||||
import trino
|
||||
import os
|
||||
from trino.auth import BasicAuthentication
|
||||
except Exception:
|
||||
return server_error_response("Missing dependency 'trino'. Please install: pip install trino")
|
||||
|
||||
catalog, schema = _parse_catalog_schema(req["database"])
|
||||
if not catalog:
|
||||
return server_error_response("For Trino, 'database' must be 'catalog.schema' or at least 'catalog'.")
|
||||
|
||||
http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
|
||||
|
||||
auth = None
|
||||
if http_scheme == "https" and req.get("password"):
|
||||
auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
|
||||
|
||||
conn = trino.dbapi.connect(
|
||||
host=req["host"],
|
||||
port=int(req["port"] or 8080),
|
||||
user=req["username"] or "ragflow",
|
||||
catalog=catalog,
|
||||
schema=schema or "default",
|
||||
http_scheme=http_scheme,
|
||||
auth=auth
|
||||
)
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT 1")
|
||||
cur.fetchall()
|
||||
cur.close()
|
||||
conn.close()
|
||||
return get_json_result(data="Database Connection Successful!")
|
||||
else:
|
||||
return server_error_response("Unsupported database type.")
|
||||
if req["db_type"] != 'mssql':
|
||||
|
||||
@ -8,14 +8,27 @@ export const ExeSQLFormSchema = {
|
||||
username: z.string().min(1),
|
||||
host: z.string().min(1),
|
||||
port: z.number(),
|
||||
password: z.string().min(1),
|
||||
password: z.string().optional().or(z.literal('')),
|
||||
max_records: z.number(),
|
||||
};
|
||||
|
||||
export const FormSchema = z.object({
|
||||
sql: z.string().optional(),
|
||||
...ExeSQLFormSchema,
|
||||
});
|
||||
export const FormSchema = z
|
||||
.object({
|
||||
sql: z.string().optional(),
|
||||
...ExeSQLFormSchema,
|
||||
})
|
||||
.superRefine((v, ctx) => {
|
||||
if (
|
||||
v.db_type !== 'trino' &&
|
||||
!(v.password && v.password.trim().length > 0)
|
||||
) {
|
||||
ctx.addIssue({
|
||||
code: z.ZodIssueCode.custom,
|
||||
path: ['password'],
|
||||
message: 'String must contain at least 1 character(s)',
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
export function useSubmitForm() {
|
||||
const { testDbConnect, loading } = useTestDbConnect();
|
||||
|
||||
@ -2139,6 +2139,7 @@ export const ExeSQLOptions = [
|
||||
'mariadb',
|
||||
'mssql',
|
||||
'IBM DB2',
|
||||
'trino',
|
||||
].map((x) => ({
|
||||
label: upperFirst(x),
|
||||
value: x,
|
||||
|
||||
Reference in New Issue
Block a user