"""RDBMS (MySQL/PostgreSQL) data source connector for importing data from relational databases."""

import hashlib
import json
import logging
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, Generator, Optional, Union

from common.data_source.config import DocumentSource, INDEX_BATCH_SIZE
from common.data_source.exceptions import (
    ConnectorMissingCredentialError,
    ConnectorValidationError,
)
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
from common.data_source.models import Document


class DatabaseType(str, Enum):
    """Supported database types."""
    MYSQL = "mysql"
    POSTGRESQL = "postgresql"


class RDBMSConnector(LoadConnector, PollConnector):
    """
    RDBMS connector for importing data from MySQL and PostgreSQL databases.

    This connector allows users to:
    1. Connect to a MySQL or PostgreSQL database
    2. Execute a SQL query to extract data
    3. Map columns to content (for vectorization) and metadata
    4. Sync data in batch or incremental mode using a timestamp column
    """

    def __init__(
        self,
        db_type: str,
        host: str,
        port: int,
        database: str,
        query: str,
        content_columns: str,
        metadata_columns: Optional[str] = None,
        id_column: Optional[str] = None,
        timestamp_column: Optional[str] = None,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        """
        Initialize the RDBMS connector.

        Args:
            db_type: Database type ('mysql' or 'postgresql')
            host: Database host
            port: Database port
            database: Database name
            query: SQL query to execute (e.g., "SELECT * FROM products WHERE status = 'active'")
            content_columns: Comma-separated column names to use for document content
            metadata_columns: Comma-separated column names to use as metadata (optional)
            id_column: Column to use as unique document ID (optional, will generate hash if not provided)
            timestamp_column: Column to use for incremental sync (optional, must be datetime/timestamp type)
            batch_size: Number of documents per batch
        """
        self.db_type = DatabaseType(db_type.lower())
        self.host = host.strip()
        self.port = port
        self.database = database.strip()
        self.query = query.strip()
        self.content_columns = [c.strip() for c in content_columns.split(",") if c.strip()]
        self.metadata_columns = [c.strip() for c in (metadata_columns or "").split(",") if c.strip()]
        self.id_column = id_column.strip() if id_column else None
        self.timestamp_column = timestamp_column.strip() if timestamp_column else None
        self.batch_size = batch_size

        self._connection = None
        self._credentials: Dict[str, Any] = {}

    def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
        """Load database credentials."""
        logging.debug(f"Loading credentials for {self.db_type} database: {self.database}")

        required_keys = ["username", "password"]
        for key in required_keys:
            if not credentials.get(key):
                raise ConnectorMissingCredentialError(f"RDBMS ({self.db_type}): missing {key}")

        self._credentials = credentials
        return None
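
    # Illustrative credentials payload accepted by load_credentials; the key names
    # ("username", "password") match the required_keys check above, the values are
    # placeholders:
    #     {"username": "reader", "password": "s3cret"}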

    def _get_connection(self):
        """Create and return a database connection."""
        if self._connection is not None:
            return self._connection

        username = self._credentials.get("username")
        password = self._credentials.get("password")

        if self.db_type == DatabaseType.MYSQL:
            try:
                import mysql.connector
            except ImportError:
                raise ConnectorValidationError(
                    "MySQL connector not installed. Please install mysql-connector-python."
                )
            try:
                self._connection = mysql.connector.connect(
                    host=self.host,
                    port=self.port,
                    database=self.database,
                    user=username,
                    password=password,
                    charset='utf8mb4',
                    use_unicode=True,
                )
            except Exception as e:
                raise ConnectorValidationError(f"Failed to connect to MySQL: {e}")
        elif self.db_type == DatabaseType.POSTGRESQL:
            try:
                import psycopg2
            except ImportError:
                raise ConnectorValidationError(
                    "PostgreSQL connector not installed. Please install psycopg2-binary."
                )
            try:
                self._connection = psycopg2.connect(
                    host=self.host,
                    port=self.port,
                    dbname=self.database,
                    user=username,
                    password=password,
                )
            except Exception as e:
                raise ConnectorValidationError(f"Failed to connect to PostgreSQL: {e}")

        return self._connection
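
    # Driver note: the MySQL branch needs the mysql-connector-python package and the
    # PostgreSQL branch needs psycopg2 (e.g. psycopg2-binary), as the ImportError
    # messages above state; both drivers are imported lazily, so only the one matching
    # db_type has to be installed.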

    def _close_connection(self):
        """Close the database connection."""
        if self._connection is not None:
            try:
                self._connection.close()
            except Exception:
                pass
            self._connection = None

    def _get_tables(self) -> list[str]:
        """Get list of all tables in the database."""
        connection = self._get_connection()
        cursor = connection.cursor()

        try:
            if self.db_type == DatabaseType.MYSQL:
                cursor.execute("SHOW TABLES")
            else:
                cursor.execute(
                    "SELECT table_name FROM information_schema.tables "
                    "WHERE table_schema = 'public' AND table_type = 'BASE TABLE'"
                )
            tables = [row[0] for row in cursor.fetchall()]
            return tables
        finally:
            cursor.close()
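
    # Note: for PostgreSQL, table discovery only lists base tables in the 'public'
    # schema; tables in other schemas are not picked up when no query is configured.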

    def _build_query_with_time_filter(
        self,
        start: Optional[datetime] = None,
        end: Optional[datetime] = None,
    ) -> str:
        """Build the query with optional time filtering for incremental sync."""
        if not self.query:
            return ""  # Will be handled by table discovery

        base_query = self.query.rstrip(";")

        if not self.timestamp_column or (start is None and end is None):
            return base_query

        has_where = "where" in base_query.lower()
        connector = " AND" if has_where else " WHERE"

        time_conditions = []
        if start is not None:
            if self.db_type == DatabaseType.MYSQL:
                time_conditions.append(f"{self.timestamp_column} > '{start.strftime('%Y-%m-%d %H:%M:%S')}'")
            else:
                time_conditions.append(f"{self.timestamp_column} > '{start.isoformat()}'")

        if end is not None:
            if self.db_type == DatabaseType.MYSQL:
                time_conditions.append(f"{self.timestamp_column} <= '{end.strftime('%Y-%m-%d %H:%M:%S')}'")
            else:
                time_conditions.append(f"{self.timestamp_column} <= '{end.isoformat()}'")

        if time_conditions:
            return f"{base_query}{connector} {' AND '.join(time_conditions)}"

        return base_query
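
    # Illustrative output of the time filter, assuming a hypothetical base query
    # "SELECT * FROM products" with timestamp_column="updated_at" on MySQL:
    #     SELECT * FROM products WHERE updated_at > '2024-01-01 00:00:00'
    #     AND updated_at <= '2024-01-02 00:00:00'
    # If the base query already contains a WHERE clause, the conditions are appended
    # with AND instead.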

    def _row_to_document(self, row: Union[tuple, list, Dict[str, Any]], column_names: list) -> Document:
        """Convert a database row to a Document."""
        row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row

        content_parts = []
        for col in self.content_columns:
            if col in row_dict and row_dict[col] is not None:
                value = row_dict[col]
                if isinstance(value, (dict, list)):
                    value = json.dumps(value, ensure_ascii=False)
                # Use brackets around field name to ensure it's distinguishable
                # after chunking (TxtParser strips \n delimiters during merge)
                content_parts.append(f"【{col}】: {value}")

        content = "\n".join(content_parts)

        if self.id_column and self.id_column in row_dict:
            doc_id = f"{self.db_type}:{self.database}:{row_dict[self.id_column]}"
        else:
            content_hash = hashlib.md5(content.encode()).hexdigest()
            doc_id = f"{self.db_type}:{self.database}:{content_hash}"

        metadata = {}
        for col in self.metadata_columns:
            if col in row_dict and row_dict[col] is not None:
                value = row_dict[col]
                if isinstance(value, datetime):
                    value = value.isoformat()
                elif isinstance(value, (dict, list)):
                    value = json.dumps(value, ensure_ascii=False)
                else:
                    value = str(value)
                metadata[col] = value

        doc_updated_at = datetime.now(timezone.utc)
        if self.timestamp_column and self.timestamp_column in row_dict:
            ts_value = row_dict[self.timestamp_column]
            if isinstance(ts_value, datetime):
                if ts_value.tzinfo is None:
                    doc_updated_at = ts_value.replace(tzinfo=timezone.utc)
                else:
                    doc_updated_at = ts_value

        first_content_col = self.content_columns[0] if self.content_columns else "record"
        semantic_id = str(row_dict.get(first_content_col, "database_record"))[:100]

        return Document(
            id=doc_id,
            blob=content.encode("utf-8"),
            source=DocumentSource(self.db_type.value),
            semantic_identifier=semantic_id,
            extension=".txt",
            doc_updated_at=doc_updated_at,
            size_bytes=len(content.encode("utf-8")),
            metadata=metadata if metadata else None,
        )
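
    # Illustrative document content produced for a row when content_columns is
    # "name,description" (values are made up):
    #     【name】: Wireless Mouse
    #     【description】: Ergonomic 2.4 GHz mouse
    # The 【...】 markers keep the fields separable even if chunking later drops the
    # newline between them.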

    def _yield_documents_from_query(
        self,
        query: str,
    ) -> Generator[list[Document], None, None]:
        """Generate documents from a single query."""
        connection = self._get_connection()
        cursor = connection.cursor()

        try:
            logging.info(f"Executing query: {query[:200]}...")
            cursor.execute(query)
            column_names = [desc[0] for desc in cursor.description]

            batch: list[Document] = []
            for row in cursor:
                try:
                    doc = self._row_to_document(row, column_names)
                    batch.append(doc)

                    if len(batch) >= self.batch_size:
                        yield batch
                        batch = []
                except Exception as e:
                    logging.warning(f"Error converting row to document: {e}")
                    continue

            if batch:
                yield batch

        finally:
            try:
                # Consume any rows left unread so the cursor can be closed cleanly
                cursor.fetchall()
            except Exception:
                pass
            cursor.close()

    def _yield_documents(
        self,
        start: Optional[datetime] = None,
        end: Optional[datetime] = None,
    ) -> Generator[list[Document], None, None]:
        """Generate documents from database query results."""
        if self.query:
            query = self._build_query_with_time_filter(start, end)
            yield from self._yield_documents_from_query(query)
        else:
            tables = self._get_tables()
            logging.info(f"No query specified. Loading all {len(tables)} tables: {tables}")
            for table in tables:
                query = f"SELECT * FROM {table}"
                logging.info(f"Loading table: {table}")
                yield from self._yield_documents_from_query(query)

        self._close_connection()

    def load_from_state(self) -> Generator[list[Document], None, None]:
        """Load all documents from the database (full sync)."""
        logging.debug(f"Loading all records from {self.db_type} database: {self.database}")
        return self._yield_documents()

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> Generator[list[Document], None, None]:
        """Poll for new/updated documents since the last sync (incremental sync)."""
        if not self.timestamp_column:
            logging.warning(
                "No timestamp column configured for incremental sync. "
                "Falling back to full sync."
            )
            return self.load_from_state()

        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

        logging.debug(
            f"Polling {self.db_type} database {self.database} "
            f"from {start_datetime} to {end_datetime}"
        )

        return self._yield_documents(start_datetime, end_datetime)
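
    # Example: poll_source(1704067200, 1704153600) converts the epoch-second bounds to
    # 2024-01-01 00:00:00+00:00 and 2024-01-02 00:00:00+00:00 (UTC) before they are
    # applied to the configured timestamp column.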

    def validate_connector_settings(self) -> None:
        """Validate connector settings by testing the connection."""
        if not self._credentials:
            raise ConnectorMissingCredentialError("RDBMS credentials not loaded.")

        if not self.host:
            raise ConnectorValidationError("Database host is required.")

        if not self.database:
            raise ConnectorValidationError("Database name is required.")

        if not self.content_columns:
            raise ConnectorValidationError(
                "At least one content column must be specified."
            )

        try:
            connection = self._get_connection()
            cursor = connection.cursor()

            test_query = "SELECT 1"
            cursor.execute(test_query)
            cursor.fetchone()
            cursor.close()

            logging.info(f"Successfully connected to {self.db_type} database: {self.database}")

        except ConnectorValidationError:
            self._close_connection()
            raise
        except Exception as e:
            self._close_connection()
            raise ConnectorValidationError(
                f"Failed to connect to {self.db_type} database: {str(e)}"
            )
        finally:
            self._close_connection()


if __name__ == "__main__":
    import os

    credentials_dict = {
        "username": os.environ.get("DB_USERNAME", "root"),
        "password": os.environ.get("DB_PASSWORD", ""),
    }

    connector = RDBMSConnector(
        db_type="mysql",
        host=os.environ.get("DB_HOST", "localhost"),
        port=int(os.environ.get("DB_PORT", "3306")),
        database=os.environ.get("DB_NAME", "test"),
        query="SELECT * FROM products LIMIT 10",
        content_columns="name,description",
        metadata_columns="id,category,price",
        id_column="id",
        timestamp_column="updated_at",
    )

    try:
        connector.load_credentials(credentials_dict)
        connector.validate_connector_settings()

        for batch in connector.load_from_state():
            print(f"Batch of {len(batch)} documents:")
            for doc in batch:
                print(f"  - {doc.id}: {doc.semantic_identifier}")
            break

    except Exception as e:
        print(f"Error: {e}")
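
    # A PostgreSQL run of the same smoke test would only need db_type="postgresql",
    # the matching port (typically 5432), and the psycopg2 driver installed; the rest
    # of the configuration is identical.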