mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Feat: add trino support (#10512)
### What problem does this PR solve? issue: [#10296](https://github.com/infiniflow/ragflow/issues/10296) change: - ExeSQL: support connecting to Trino. - Validation: password can be empty only when db_type === "trino"; all other database types keep the existing requirement (non-empty). ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -53,12 +53,13 @@ class ExeSQLParam(ToolParamBase):
|
|||||||
self.max_records = 1024
|
self.max_records = 1024
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2'])
|
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2', 'trino'])
|
||||||
self.check_empty(self.database, "Database name")
|
self.check_empty(self.database, "Database name")
|
||||||
self.check_empty(self.username, "database username")
|
self.check_empty(self.username, "database username")
|
||||||
self.check_empty(self.host, "IP Address")
|
self.check_empty(self.host, "IP Address")
|
||||||
self.check_positive_integer(self.port, "IP Port")
|
self.check_positive_integer(self.port, "IP Port")
|
||||||
self.check_empty(self.password, "Database password")
|
if self.db_type != "trino":
|
||||||
|
self.check_empty(self.password, "Database password")
|
||||||
self.check_positive_integer(self.max_records, "Maximum number of records")
|
self.check_positive_integer(self.max_records, "Maximum number of records")
|
||||||
if self.database == "rag_flow":
|
if self.database == "rag_flow":
|
||||||
if self.host == "ragflow-mysql":
|
if self.host == "ragflow-mysql":
|
||||||
@ -123,6 +124,45 @@ class ExeSQL(ToolBase, ABC):
|
|||||||
r'PWD=' + self._param.password
|
r'PWD=' + self._param.password
|
||||||
)
|
)
|
||||||
db = pyodbc.connect(conn_str)
|
db = pyodbc.connect(conn_str)
|
||||||
|
elif self._param.db_type == 'trino':
|
||||||
|
try:
|
||||||
|
import trino
|
||||||
|
from trino.auth import BasicAuthentication
|
||||||
|
except Exception:
|
||||||
|
raise Exception("Missing dependency 'trino'. Please install: pip install trino")
|
||||||
|
|
||||||
|
def _parse_catalog_schema(db: str):
|
||||||
|
if not db:
|
||||||
|
return None, None
|
||||||
|
if "." in db:
|
||||||
|
c, s = db.split(".", 1)
|
||||||
|
elif "/" in db:
|
||||||
|
c, s = db.split("/", 1)
|
||||||
|
else:
|
||||||
|
c, s = db, "default"
|
||||||
|
return c, s
|
||||||
|
|
||||||
|
catalog, schema = _parse_catalog_schema(self._param.database)
|
||||||
|
if not catalog:
|
||||||
|
raise Exception("For Trino, `database` must be 'catalog.schema' or at least 'catalog'.")
|
||||||
|
|
||||||
|
http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
|
||||||
|
auth = None
|
||||||
|
if http_scheme == "https" and self._param.password:
|
||||||
|
auth = BasicAuthentication(self._param.username, self._param.password)
|
||||||
|
|
||||||
|
try:
|
||||||
|
db = trino.dbapi.connect(
|
||||||
|
host=self._param.host,
|
||||||
|
port=int(self._param.port or 8080),
|
||||||
|
user=self._param.username or "ragflow",
|
||||||
|
catalog=catalog,
|
||||||
|
schema=schema or "default",
|
||||||
|
http_scheme=http_scheme,
|
||||||
|
auth=auth
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception("Database Connection Failed! \n" + str(e))
|
||||||
elif self._param.db_type == 'IBM DB2':
|
elif self._param.db_type == 'IBM DB2':
|
||||||
import ibm_db
|
import ibm_db
|
||||||
conn_str = (
|
conn_str = (
|
||||||
|
|||||||
@ -409,6 +409,49 @@ def test_db_connect():
|
|||||||
ibm_db.fetch_assoc(stmt)
|
ibm_db.fetch_assoc(stmt)
|
||||||
ibm_db.close(conn)
|
ibm_db.close(conn)
|
||||||
return get_json_result(data="Database Connection Successful!")
|
return get_json_result(data="Database Connection Successful!")
|
||||||
|
elif req["db_type"] == 'trino':
|
||||||
|
def _parse_catalog_schema(db: str):
|
||||||
|
if not db:
|
||||||
|
return None, None
|
||||||
|
if "." in db:
|
||||||
|
c, s = db.split(".", 1)
|
||||||
|
elif "/" in db:
|
||||||
|
c, s = db.split("/", 1)
|
||||||
|
else:
|
||||||
|
c, s = db, "default"
|
||||||
|
return c, s
|
||||||
|
try:
|
||||||
|
import trino
|
||||||
|
import os
|
||||||
|
from trino.auth import BasicAuthentication
|
||||||
|
except Exception:
|
||||||
|
return server_error_response("Missing dependency 'trino'. Please install: pip install trino")
|
||||||
|
|
||||||
|
catalog, schema = _parse_catalog_schema(req["database"])
|
||||||
|
if not catalog:
|
||||||
|
return server_error_response("For Trino, 'database' must be 'catalog.schema' or at least 'catalog'.")
|
||||||
|
|
||||||
|
http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
|
||||||
|
|
||||||
|
auth = None
|
||||||
|
if http_scheme == "https" and req.get("password"):
|
||||||
|
auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
|
||||||
|
|
||||||
|
conn = trino.dbapi.connect(
|
||||||
|
host=req["host"],
|
||||||
|
port=int(req["port"] or 8080),
|
||||||
|
user=req["username"] or "ragflow",
|
||||||
|
catalog=catalog,
|
||||||
|
schema=schema or "default",
|
||||||
|
http_scheme=http_scheme,
|
||||||
|
auth=auth
|
||||||
|
)
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute("SELECT 1")
|
||||||
|
cur.fetchall()
|
||||||
|
cur.close()
|
||||||
|
conn.close()
|
||||||
|
return get_json_result(data="Database Connection Successful!")
|
||||||
else:
|
else:
|
||||||
return server_error_response("Unsupported database type.")
|
return server_error_response("Unsupported database type.")
|
||||||
if req["db_type"] != 'mssql':
|
if req["db_type"] != 'mssql':
|
||||||
|
|||||||
@ -8,14 +8,27 @@ export const ExeSQLFormSchema = {
|
|||||||
username: z.string().min(1),
|
username: z.string().min(1),
|
||||||
host: z.string().min(1),
|
host: z.string().min(1),
|
||||||
port: z.number(),
|
port: z.number(),
|
||||||
password: z.string().min(1),
|
password: z.string().optional().or(z.literal('')),
|
||||||
max_records: z.number(),
|
max_records: z.number(),
|
||||||
};
|
};
|
||||||
|
|
||||||
export const FormSchema = z.object({
|
export const FormSchema = z
|
||||||
sql: z.string().optional(),
|
.object({
|
||||||
...ExeSQLFormSchema,
|
sql: z.string().optional(),
|
||||||
});
|
...ExeSQLFormSchema,
|
||||||
|
})
|
||||||
|
.superRefine((v, ctx) => {
|
||||||
|
if (
|
||||||
|
v.db_type !== 'trino' &&
|
||||||
|
!(v.password && v.password.trim().length > 0)
|
||||||
|
) {
|
||||||
|
ctx.addIssue({
|
||||||
|
code: z.ZodIssueCode.custom,
|
||||||
|
path: ['password'],
|
||||||
|
message: 'String must contain at least 1 character(s)',
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
export function useSubmitForm() {
|
export function useSubmitForm() {
|
||||||
const { testDbConnect, loading } = useTestDbConnect();
|
const { testDbConnect, loading } = useTestDbConnect();
|
||||||
|
|||||||
@ -2139,6 +2139,7 @@ export const ExeSQLOptions = [
|
|||||||
'mariadb',
|
'mariadb',
|
||||||
'mssql',
|
'mssql',
|
||||||
'IBM DB2',
|
'IBM DB2',
|
||||||
|
'trino',
|
||||||
].map((x) => ({
|
].map((x) => ({
|
||||||
label: upperFirst(x),
|
label: upperFirst(x),
|
||||||
value: x,
|
value: x,
|
||||||
|
|||||||
Reference in New Issue
Block a user