feat(tools): add Elasticsearch to OceanBase migration tool (#12927)
### What problem does this PR solve?

Fixes https://github.com/infiniflow/ragflow/issues/12774

Adds a CLI tool for migrating RAGFlow data from Elasticsearch to OceanBase, so users can switch their document storage backend:

- Automatic discovery and migration of all `ragflow_*` indices
- Schema conversion with automatic detection of the vector dimension
- Batch processing with progress tracking and resume capability
- Data consistency validation and migration report generation

**Note**: Due to network issues, I was unable to pull the required Docker images (Elasticsearch, OceanBase) to run the full end-to-end verification. The unit tests pass (output below); I will complete the e2e verification when network conditions allow and submit a follow-up PR if any fixes are needed.

```bash
============================= test session starts ==============================
platform darwin -- Python 3.13.6, pytest-9.0.2, pluggy-1.6.0
rootdir: /Users/sevenc/code/ai/oceanbase/ragflow/tools/es-to-oceanbase-migration
configfile: pyproject.toml
testpaths: tests
plugins: anyio-4.12.1, asyncio-1.3.0, cov-7.0.0
collected 86 items

tests/test_progress.py::TestMigrationProgress::test_create_basic_progress PASSED [ 1%]
tests/test_progress.py::TestMigrationProgress::test_create_progress_with_counts PASSED [ 2%]
tests/test_progress.py::TestMigrationProgress::test_progress_default_values PASSED [ 3%]
tests/test_progress.py::TestMigrationProgress::test_progress_status_values PASSED [ 4%]
tests/test_progress.py::TestProgressManager::test_create_progress_manager PASSED [ 5%]
tests/test_progress.py::TestProgressManager::test_create_progress_manager_creates_dir PASSED [ 6%]
tests/test_progress.py::TestProgressManager::test_create_progress PASSED [ 8%]
tests/test_progress.py::TestProgressManager::test_save_and_load_progress PASSED [ 9%]
tests/test_progress.py::TestProgressManager::test_load_nonexistent_progress PASSED [ 10%]
tests/test_progress.py::TestProgressManager::test_delete_progress PASSED [ 11%]
tests/test_progress.py::TestProgressManager::test_update_progress PASSED [ 12%]
tests/test_progress.py::TestProgressManager::test_update_progress_multiple_batches PASSED [ 13%]
tests/test_progress.py::TestProgressManager::test_mark_completed PASSED [ 15%]
tests/test_progress.py::TestProgressManager::test_mark_failed PASSED [ 16%]
tests/test_progress.py::TestProgressManager::test_mark_paused PASSED [ 17%]
tests/test_progress.py::TestProgressManager::test_can_resume_running PASSED [ 18%]
tests/test_progress.py::TestProgressManager::test_can_resume_paused PASSED [ 19%]
tests/test_progress.py::TestProgressManager::test_can_resume_completed PASSED [ 20%]
tests/test_progress.py::TestProgressManager::test_can_resume_nonexistent PASSED [ 22%]
tests/test_progress.py::TestProgressManager::test_get_resume_info PASSED [ 23%]
tests/test_progress.py::TestProgressManager::test_get_resume_info_nonexistent PASSED [ 24%]
tests/test_progress.py::TestProgressManager::test_progress_file_path PASSED [ 25%]
tests/test_progress.py::TestProgressManager::test_progress_file_content PASSED [ 26%]
tests/test_schema.py::TestRAGFlowSchemaConverter::test_analyze_ragflow_mapping PASSED [ 27%]
tests/test_schema.py::TestRAGFlowSchemaConverter::test_detect_vector_size PASSED [ 29%]
tests/test_schema.py::TestRAGFlowSchemaConverter::test_unknown_fields PASSED [ 30%]
tests/test_schema.py::TestRAGFlowSchemaConverter::test_get_column_definitions PASSED [ 31%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_basic_document PASSED [ 32%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_with_vector PASSED [ 33%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_array_fields PASSED [ 34%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_json_fields PASSED [ 36%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_unknown_fields_to_extra PASSED [ 37%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_kb_id_list PASSED [ 38%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_content_with_weight_dict PASSED [ 39%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_batch PASSED [ 40%]
tests/test_schema.py::TestVectorFieldPattern::test_valid_patterns PASSED [ 41%]
tests/test_schema.py::TestVectorFieldPattern::test_invalid_patterns PASSED [ 43%]
tests/test_schema.py::TestVectorFieldPattern::test_extract_dimension PASSED [ 44%]
tests/test_schema.py::TestConstants::test_array_columns PASSED [ 45%]
tests/test_schema.py::TestConstants::test_json_columns PASSED [ 46%]
tests/test_schema.py::TestConstants::test_ragflow_columns_completeness PASSED [ 47%]
tests/test_schema.py::TestConstants::test_fts_columns PASSED [ 48%]
tests/test_schema.py::TestConstants::test_ragflow_columns_types PASSED [ 50%]
tests/test_schema.py::TestRAGFlowSchemaConverterEdgeCases::test_empty_mapping PASSED [ 51%]
tests/test_schema.py::TestRAGFlowSchemaConverterEdgeCases::test_mapping_without_properties PASSED [ 52%]
tests/test_schema.py::TestRAGFlowSchemaConverterEdgeCases::test_multiple_vector_fields PASSED [ 53%]
tests/test_schema.py::TestRAGFlowSchemaConverterEdgeCases::test_get_column_definitions_without_analysis PASSED [ 54%]
tests/test_schema.py::TestRAGFlowSchemaConverterEdgeCases::test_get_vector_fields PASSED [ 55%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_empty_document PASSED [ 56%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_document_without_source PASSED [ 58%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_boolean_to_integer PASSED [ 59%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_invalid_integer PASSED [ 60%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_float_field PASSED [ 61%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_array_with_special_characters PASSED [ 62%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_already_json_array PASSED [ 63%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_single_value_to_array PASSED [ 65%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_detect_vector_fields_from_document PASSED [ 66%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_with_default_values PASSED [ 67%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_list_content PASSED [ 68%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_batch_empty PASSED [ 69%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_existing_extra_field_merged PASSED [ 70%]
tests/test_verify.py::TestVerificationResult::test_create_basic_result PASSED [ 72%]
tests/test_verify.py::TestVerificationResult::test_result_default_values PASSED [ 73%]
tests/test_verify.py::TestVerificationResult::test_result_with_counts PASSED [ 74%]
tests/test_verify.py::TestMigrationVerifier::test_verify_counts_match PASSED [ 75%]
tests/test_verify.py::TestMigrationVerifier::test_verify_counts_mismatch PASSED [ 76%]
tests/test_verify.py::TestMigrationVerifier::test_verify_samples_all_match PASSED [ 77%]
tests/test_verify.py::TestMigrationVerifier::test_verify_samples_some_missing PASSED [ 79%]
tests/test_verify.py::TestMigrationVerifier::test_verify_samples_data_mismatch PASSED [ 80%]
tests/test_verify.py::TestMigrationVerifier::test_values_equal_none_values PASSED [ 81%]
tests/test_verify.py::TestMigrationVerifier::test_values_equal_array_columns PASSED [ 82%]
tests/test_verify.py::TestMigrationVerifier::test_values_equal_json_columns PASSED [ 83%]
tests/test_verify.py::TestMigrationVerifier::test_values_equal_kb_id_list PASSED [ 84%]
tests/test_verify.py::TestMigrationVerifier::test_values_equal_content_with_weight_dict PASSED [ 86%]
tests/test_verify.py::TestMigrationVerifier::test_determine_result_passed PASSED [ 87%]
tests/test_verify.py::TestMigrationVerifier::test_determine_result_failed_count PASSED [ 88%]
tests/test_verify.py::TestMigrationVerifier::test_determine_result_failed_samples PASSED [ 89%]
tests/test_verify.py::TestMigrationVerifier::test_generate_report PASSED [ 90%]
tests/test_verify.py::TestMigrationVerifier::test_generate_report_with_missing PASSED [ 91%]
tests/test_verify.py::TestMigrationVerifier::test_generate_report_with_mismatches PASSED [ 93%]
tests/test_verify.py::TestValueComparison::test_string_comparison PASSED [ 94%]
tests/test_verify.py::TestValueComparison::test_integer_comparison PASSED [ 95%]
tests/test_verify.py::TestValueComparison::test_float_comparison PASSED [ 96%]
tests/test_verify.py::TestValueComparison::test_boolean_comparison PASSED [ 97%]
tests/test_verify.py::TestValueComparison::test_empty_array_comparison PASSED [ 98%]
tests/test_verify.py::TestValueComparison::test_nested_json_comparison PASSED [100%]
======================= 86 passed, 88 warnings in 0.66s ========================
```

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
41 tools/es-to-oceanbase-migration/src/es_ob_migration/__init__.py Normal file
@@ -0,0 +1,41 @@
```python
"""
RAGFlow ES to OceanBase Migration Tool

A CLI tool for migrating RAGFlow data from Elasticsearch 8+ to OceanBase,
supporting schema conversion, vector data mapping, batch import, and resume capability.

This tool is specifically designed for RAGFlow's data structure.
"""

__version__ = "0.1.0"

from .migrator import ESToOceanBaseMigrator
from .es_client import ESClient
from .ob_client import OBClient
from .schema import RAGFlowSchemaConverter, RAGFlowDataConverter, RAGFLOW_COLUMNS
from .verify import MigrationVerifier, VerificationResult
from .progress import ProgressManager, MigrationProgress

# Backwards compatibility aliases
SchemaConverter = RAGFlowSchemaConverter
DataConverter = RAGFlowDataConverter

__all__ = [
    # Main classes
    "ESToOceanBaseMigrator",
    "ESClient",
    "OBClient",
    # Schema
    "RAGFlowSchemaConverter",
    "RAGFlowDataConverter",
    "RAGFLOW_COLUMNS",
    # Verification
    "MigrationVerifier",
    "VerificationResult",
    # Progress
    "ProgressManager",
    "MigrationProgress",
    # Aliases
    "SchemaConverter",
    "DataConverter",
]
```
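For orientation, the following is a minimal sketch of how the exported classes are meant to compose, based only on the public API above and the `migrate()` signature shown later in `migrator.py`. The import path assumes the package is installed as `es_ob_migration`; the connection parameters and index name are placeholders, not values from this PR.

```python
# Hypothetical usage sketch of the package-level API (hosts, ports, and names are assumptions).
from es_ob_migration import ESClient, OBClient, ESToOceanBaseMigrator

es = ESClient(host="localhost", port=9200)
ob = OBClient(host="localhost", port=2881, user="root@test", database="test")

migrator = ESToOceanBaseMigrator(es_client=es, ob_client=ob)
result = migrator.migrate(
    es_index="ragflow_example",   # placeholder index name
    ob_table="ragflow_example",
    batch_size=1000,
    verify_after=True,
)
print(result["success"], result["migrated_documents"])

es.close()
ob.close()
```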
574 tools/es-to-oceanbase-migration/src/es_ob_migration/cli.py Normal file
@@ -0,0 +1,574 @@
```python
"""
CLI entry point for RAGFlow ES to OceanBase migration tool.
"""

import json
import logging
import sys

import click
from rich.console import Console
from rich.table import Table
from rich.logging import RichHandler

from .es_client import ESClient
from .ob_client import OBClient
from .migrator import ESToOceanBaseMigrator
from .verify import MigrationVerifier
from .schema import RAGFLOW_COLUMNS

console = Console()


def setup_logging(verbose: bool = False):
    """Setup logging configuration."""
    level = logging.DEBUG if verbose else logging.INFO
    logging.basicConfig(
        level=level,
        format="%(message)s",
        datefmt="[%X]",
        handlers=[RichHandler(rich_tracebacks=True, console=console)],
    )


@click.group()
@click.option("-v", "--verbose", is_flag=True, help="Enable verbose logging")
@click.pass_context
def main(ctx, verbose):
    """RAGFlow ES to OceanBase Migration Tool.

    Migrate RAGFlow data from Elasticsearch 8+ to OceanBase with schema conversion,
    vector data mapping, batch import, and resume capability.

    This tool is specifically designed for RAGFlow's data structure.
    """
    ctx.ensure_object(dict)
    ctx.obj["verbose"] = verbose
    setup_logging(verbose)


@main.command()
@click.option("--es-host", default="localhost", help="Elasticsearch host")
@click.option("--es-port", default=9200, type=int, help="Elasticsearch port")
@click.option("--es-user", default=None, help="Elasticsearch username")
@click.option("--es-password", default=None, help="Elasticsearch password")
@click.option("--es-api-key", default=None, help="Elasticsearch API key")
@click.option("--ob-host", default="localhost", help="OceanBase host")
@click.option("--ob-port", default=2881, type=int, help="OceanBase port")
@click.option("--ob-user", default="root@test", help="OceanBase user (format: user@tenant)")
@click.option("--ob-password", default="", help="OceanBase password")
@click.option("--ob-database", default="test", help="OceanBase database")
@click.option("--index", "-i", default=None, help="Source ES index name (omit to migrate all ragflow_* indices)")
@click.option("--table", "-t", default=None, help="Target OceanBase table name (omit to use same name as index)")
@click.option("--batch-size", default=1000, type=int, help="Batch size for migration")
@click.option("--resume", is_flag=True, help="Resume from previous progress")
@click.option("--verify/--no-verify", default=True, help="Verify after migration")
@click.option("--progress-dir", default=".migration_progress", help="Progress file directory")
@click.pass_context
def migrate(
    ctx,
    es_host,
    es_port,
    es_user,
    es_password,
    es_api_key,
    ob_host,
    ob_port,
    ob_user,
    ob_password,
    ob_database,
    index,
    table,
    batch_size,
    resume,
    verify,
    progress_dir,
):
    """Run RAGFlow data migration from Elasticsearch to OceanBase.

    If --index is omitted, all indices starting with 'ragflow_' will be migrated.
    If --table is omitted, the same name as the source index will be used.
    """
    console.print("[bold]RAGFlow ES to OceanBase Migration[/]")

    try:
        # Initialize ES client first to discover indices if needed
        es_client = ESClient(
            host=es_host,
            port=es_port,
            username=es_user,
            password=es_password,
            api_key=es_api_key,
        )

        ob_client = OBClient(
            host=ob_host,
            port=ob_port,
            user=ob_user,
            password=ob_password,
            database=ob_database,
        )

        # Determine indices to migrate
        if index:
            # Single index specified
            indices_to_migrate = [(index, table if table else index)]
        else:
            # Auto-discover all ragflow_* indices
            console.print(f"\n[cyan]Discovering RAGFlow indices...[/]")
            ragflow_indices = es_client.list_ragflow_indices()

            if not ragflow_indices:
                console.print("[yellow]No ragflow_* indices found in Elasticsearch[/]")
                sys.exit(0)

            # Each index maps to a table with the same name
            indices_to_migrate = [(idx, idx) for idx in ragflow_indices]

            console.print(f"[green]Found {len(indices_to_migrate)} RAGFlow indices:[/]")
            for idx, _ in indices_to_migrate:
                doc_count = es_client.count_documents(idx)
                console.print(f" - {idx} ({doc_count:,} documents)")
            console.print()

        # Initialize migrator
        migrator = ESToOceanBaseMigrator(
            es_client=es_client,
            ob_client=ob_client,
            progress_dir=progress_dir,
        )

        # Track overall results
        total_success = 0
        total_failed = 0
        results = []

        # Migrate each index
        for es_index, ob_table in indices_to_migrate:
            console.print(f"\n[bold blue]{'='*60}[/]")
            console.print(f"[bold]Migrating: {es_index} -> {ob_database}.{ob_table}[/]")
            console.print(f"[bold blue]{'='*60}[/]")

            result = migrator.migrate(
                es_index=es_index,
                ob_table=ob_table,
                batch_size=batch_size,
                resume=resume,
                verify_after=verify,
            )

            results.append(result)
            if result["success"]:
                total_success += 1
            else:
                total_failed += 1

        # Summary for multiple indices
        if len(indices_to_migrate) > 1:
            console.print(f"\n[bold]{'='*60}[/]")
            console.print(f"[bold]Migration Summary[/]")
            console.print(f"[bold]{'='*60}[/]")
            console.print(f" Total indices: {len(indices_to_migrate)}")
            console.print(f" [green]Successful: {total_success}[/]")
            if total_failed > 0:
                console.print(f" [red]Failed: {total_failed}[/]")

        # Exit code based on results
        if total_failed == 0:
            console.print("\n[bold green]All migrations completed successfully![/]")
            sys.exit(0)
        else:
            console.print(f"\n[bold red]{total_failed} migration(s) failed[/]")
            sys.exit(1)

    except Exception as e:
        console.print(f"[bold red]Error: {e}[/]")
        if ctx.obj.get("verbose"):
            console.print_exception()
        sys.exit(1)
    finally:
        # Cleanup
        if "es_client" in locals():
            es_client.close()
        if "ob_client" in locals():
            ob_client.close()


@main.command()
@click.option("--es-host", default="localhost", help="Elasticsearch host")
@click.option("--es-port", default=9200, type=int, help="Elasticsearch port")
@click.option("--es-user", default=None, help="Elasticsearch username")
@click.option("--es-password", default=None, help="Elasticsearch password")
@click.option("--index", "-i", required=True, help="ES index name")
@click.option("--output", "-o", default=None, help="Output file (JSON)")
@click.pass_context
def schema(ctx, es_host, es_port, es_user, es_password, index, output):
    """Preview RAGFlow schema analysis from ES mapping."""
    try:
        es_client = ESClient(
            host=es_host,
            port=es_port,
            username=es_user,
            password=es_password,
        )

        # Dummy OB client for schema preview
        ob_client = None

        migrator = ESToOceanBaseMigrator(es_client, ob_client if ob_client else OBClient.__new__(OBClient))
        # Directly use schema converter
        from .schema import RAGFlowSchemaConverter
        converter = RAGFlowSchemaConverter()

        es_mapping = es_client.get_index_mapping(index)
        analysis = converter.analyze_es_mapping(es_mapping)
        column_defs = converter.get_column_definitions()

        # Display analysis
        console.print(f"\n[bold]ES Index Analysis: {index}[/]\n")

        # Known RAGFlow fields
        console.print(f"[green]Known RAGFlow fields:[/] {len(analysis['known_fields'])}")

        # Vector fields
        if analysis['vector_fields']:
            console.print(f"\n[cyan]Vector fields detected:[/]")
            for vf in analysis['vector_fields']:
                console.print(f" - {vf['name']} (dimension: {vf['dimension']})")

        # Unknown fields
        if analysis['unknown_fields']:
            console.print(f"\n[yellow]Unknown fields (will be stored in 'extra'):[/]")
            for uf in analysis['unknown_fields']:
                console.print(f" - {uf}")

        # Display RAGFlow column schema
        console.print(f"\n[bold]RAGFlow OceanBase Schema ({len(column_defs)} columns):[/]\n")

        table = Table(title="Column Definitions")
        table.add_column("Column Name", style="cyan")
        table.add_column("OB Type", style="green")
        table.add_column("Nullable", style="yellow")
        table.add_column("Special", style="magenta")

        for col in column_defs[:20]:  # Show first 20
            special = []
            if col.get("is_primary"):
                special.append("PK")
            if col.get("index"):
                special.append("IDX")
            if col.get("is_array"):
                special.append("ARRAY")
            if col.get("is_vector"):
                special.append("VECTOR")

            table.add_row(
                col["name"],
                col["ob_type"],
                "Yes" if col.get("nullable", True) else "No",
                ", ".join(special) if special else "-",
            )

        if len(column_defs) > 20:
            table.add_row("...", f"({len(column_defs) - 20} more)", "", "")

        console.print(table)

        # Save to file if requested
        if output:
            preview = {
                "es_index": index,
                "es_mapping": es_mapping,
                "analysis": analysis,
                "ob_columns": column_defs,
            }
            with open(output, "w") as f:
                json.dump(preview, f, indent=2, default=str)
            console.print(f"\nSchema saved to {output}")

    except Exception as e:
        console.print(f"[bold red]Error: {e}[/]")
        if ctx.obj.get("verbose"):
            console.print_exception()
        sys.exit(1)
    finally:
        if "es_client" in locals():
            es_client.close()


@main.command()
@click.option("--es-host", default="localhost", help="Elasticsearch host")
@click.option("--es-port", default=9200, type=int, help="Elasticsearch port")
@click.option("--ob-host", default="localhost", help="OceanBase host")
@click.option("--ob-port", default=2881, type=int, help="OceanBase port")
@click.option("--ob-user", default="root@test", help="OceanBase user")
@click.option("--ob-password", default="", help="OceanBase password")
@click.option("--ob-database", default="test", help="OceanBase database")
@click.option("--index", "-i", required=True, help="Source ES index name")
@click.option("--table", "-t", required=True, help="Target OceanBase table name")
@click.option("--sample-size", default=100, type=int, help="Sample size for verification")
@click.pass_context
def verify(
    ctx,
    es_host,
    es_port,
    ob_host,
    ob_port,
    ob_user,
    ob_password,
    ob_database,
    index,
    table,
    sample_size,
):
    """Verify migration data consistency."""
    try:
        es_client = ESClient(host=es_host, port=es_port)
        ob_client = OBClient(
            host=ob_host,
            port=ob_port,
            user=ob_user,
            password=ob_password,
            database=ob_database,
        )

        verifier = MigrationVerifier(es_client, ob_client)
        result = verifier.verify(
            index, table,
            sample_size=sample_size,
        )

        console.print(verifier.generate_report(result))

        sys.exit(0 if result.passed else 1)

    except Exception as e:
        console.print(f"[bold red]Error: {e}[/]")
        if ctx.obj.get("verbose"):
            console.print_exception()
        sys.exit(1)
    finally:
        if "es_client" in locals():
            es_client.close()
        if "ob_client" in locals():
            ob_client.close()


@main.command("list-indices")
@click.option("--es-host", default="localhost", help="Elasticsearch host")
@click.option("--es-port", default=9200, type=int, help="Elasticsearch port")
@click.option("--es-user", default=None, help="Elasticsearch username")
@click.option("--es-password", default=None, help="Elasticsearch password")
@click.pass_context
def list_indices(ctx, es_host, es_port, es_user, es_password):
    """List all RAGFlow indices (ragflow_*) in Elasticsearch."""
    try:
        es_client = ESClient(
            host=es_host,
            port=es_port,
            username=es_user,
            password=es_password,
        )

        console.print(f"\n[bold]RAGFlow Indices in Elasticsearch ({es_host}:{es_port})[/]\n")

        indices = es_client.list_ragflow_indices()

        if not indices:
            console.print("[yellow]No ragflow_* indices found[/]")
            return

        table = Table(title="RAGFlow Indices")
        table.add_column("Index Name", style="cyan")
        table.add_column("Document Count", style="green", justify="right")
        table.add_column("Type", style="yellow")

        total_docs = 0
        for idx in indices:
            doc_count = es_client.count_documents(idx)
            total_docs += doc_count

            # Determine index type
            if idx.startswith("ragflow_doc_meta_"):
                idx_type = "Metadata"
            elif idx.startswith("ragflow_"):
                idx_type = "Document Chunks"
            else:
                idx_type = "Unknown"

            table.add_row(idx, f"{doc_count:,}", idx_type)

        table.add_row("", "", "")
        table.add_row("[bold]Total[/]", f"[bold]{total_docs:,}[/]", f"[bold]{len(indices)} indices[/]")

        console.print(table)

    except Exception as e:
        console.print(f"[bold red]Error: {e}[/]")
        if ctx.obj.get("verbose"):
            console.print_exception()
        sys.exit(1)
    finally:
        if "es_client" in locals():
            es_client.close()


@main.command("list-kb")
@click.option("--es-host", default="localhost", help="Elasticsearch host")
@click.option("--es-port", default=9200, type=int, help="Elasticsearch port")
@click.option("--es-user", default=None, help="Elasticsearch username")
@click.option("--es-password", default=None, help="Elasticsearch password")
@click.option("--index", "-i", required=True, help="ES index name")
@click.pass_context
def list_kb(ctx, es_host, es_port, es_user, es_password, index):
    """List all knowledge bases in an ES index."""
    try:
        es_client = ESClient(
            host=es_host,
            port=es_port,
            username=es_user,
            password=es_password,
        )

        console.print(f"\n[bold]Knowledge Bases in index: {index}[/]\n")

        # Get kb_id aggregation
        agg_result = es_client.aggregate_field(index, "kb_id")
        buckets = agg_result.get("buckets", [])

        if not buckets:
            console.print("[yellow]No knowledge bases found[/]")
            return

        table = Table(title="Knowledge Bases")
        table.add_column("KB ID", style="cyan")
        table.add_column("Document Count", style="green", justify="right")

        total_docs = 0
        for bucket in buckets:
            table.add_row(
                bucket["key"],
                f"{bucket['doc_count']:,}",
            )
            total_docs += bucket["doc_count"]

        table.add_row("", "")
        table.add_row("[bold]Total[/]", f"[bold]{total_docs:,}[/]")

        console.print(table)
        console.print(f"\nTotal knowledge bases: {len(buckets)}")

    except Exception as e:
        console.print(f"[bold red]Error: {e}[/]")
        if ctx.obj.get("verbose"):
            console.print_exception()
        sys.exit(1)
    finally:
        if "es_client" in locals():
            es_client.close()


@main.command()
@click.option("--es-host", default="localhost", help="Elasticsearch host")
@click.option("--es-port", default=9200, type=int, help="Elasticsearch port")
@click.option("--ob-host", default="localhost", help="OceanBase host")
@click.option("--ob-port", default=2881, type=int, help="OceanBase port")
@click.option("--ob-user", default="root@test", help="OceanBase user")
@click.option("--ob-password", default="", help="OceanBase password")
@click.pass_context
def status(ctx, es_host, es_port, ob_host, ob_port, ob_user, ob_password):
    """Check connection status to ES and OceanBase."""
    console.print("[bold]Connection Status[/]\n")

    # Check ES
    try:
        es_client = ESClient(host=es_host, port=es_port)
        health = es_client.health_check()
        info = es_client.get_cluster_info()
        console.print(f"[green]Elasticsearch ({es_host}:{es_port}): Connected[/]")
        console.print(f" Cluster: {health.get('cluster_name')}")
        console.print(f" Status: {health.get('status')}")
        console.print(f" Version: {info.get('version', {}).get('number', 'unknown')}")

        # List indices
        indices = es_client.list_indices("*")
        console.print(f" Indices: {len(indices)}")

        es_client.close()
    except Exception as e:
        console.print(f"[red]Elasticsearch ({es_host}:{es_port}): Failed[/]")
        console.print(f" Error: {e}")

    console.print()

    # Check OceanBase
    try:
        ob_client = OBClient(
            host=ob_host,
            port=ob_port,
            user=ob_user,
            password=ob_password,
        )
        if ob_client.health_check():
            version = ob_client.get_version()
            console.print(f"[green]OceanBase ({ob_host}:{ob_port}): Connected[/]")
            console.print(f" Version: {version}")
        else:
            console.print(f"[red]OceanBase ({ob_host}:{ob_port}): Health check failed[/]")
        ob_client.close()
    except Exception as e:
        console.print(f"[red]OceanBase ({ob_host}:{ob_port}): Failed[/]")
        console.print(f" Error: {e}")


@main.command()
@click.option("--es-host", default="localhost", help="Elasticsearch host")
@click.option("--es-port", default=9200, type=int, help="Elasticsearch port")
@click.option("--index", "-i", required=True, help="ES index name")
@click.option("--size", "-n", default=5, type=int, help="Number of samples")
@click.pass_context
def sample(ctx, es_host, es_port, index, size):
    """Show sample documents from ES index."""
    try:
        es_client = ESClient(host=es_host, port=es_port)

        docs = es_client.get_sample_documents(index, size)

        console.print(f"\n[bold]Sample documents from {index}[/]")
        console.print()

        for i, doc in enumerate(docs, 1):
            console.print(f"[bold cyan]Document {i}[/]")
            console.print(f" _id: {doc.get('_id')}")
            console.print(f" kb_id: {doc.get('kb_id')}")
            console.print(f" doc_id: {doc.get('doc_id')}")
            console.print(f" docnm_kwd: {doc.get('docnm_kwd')}")

            # Check for vector fields
            vector_fields = [k for k in doc.keys() if k.startswith("q_") and k.endswith("_vec")]
            if vector_fields:
                for vf in vector_fields:
                    vec = doc.get(vf)
                    if vec:
                        console.print(f" {vf}: [{len(vec)} dimensions]")

            content = doc.get("content_with_weight", "")
            if content:
                if isinstance(content, dict):
                    content = json.dumps(content, ensure_ascii=False)
                preview = content[:100] + "..." if len(str(content)) > 100 else content
                console.print(f" content: {preview}")

            console.print()

        es_client.close()

    except Exception as e:
        console.print(f"[bold red]Error: {e}[/]")
        if ctx.obj.get("verbose"):
            console.print_exception()
        sys.exit(1)


if __name__ == "__main__":
    main()
```
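The commands defined above can also be exercised in-process with click's test runner, which is useful for quick smoke checks without a shell. A minimal sketch; the host and port values are assumptions, and the import path assumes the package is installed as `es_ob_migration`.

```python
# Sketch: invoking the CLI group in-process via click's testing helper.
from click.testing import CliRunner

from es_ob_migration.cli import main

runner = CliRunner()
# List ragflow_* indices on a local Elasticsearch (host/port are placeholder assumptions).
result = runner.invoke(main, ["list-indices", "--es-host", "localhost", "--es-port", "9200"])
print(result.exit_code)   # 0 on success; SystemExit from the commands is captured by the runner
print(result.output)
```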
292 tools/es-to-oceanbase-migration/src/es_ob_migration/es_client.py Normal file
@@ -0,0 +1,292 @@
```python
"""
Elasticsearch 8+ Client for RAGFlow data migration.
"""

import logging
from typing import Any, Iterator

from elasticsearch import Elasticsearch

logger = logging.getLogger(__name__)


class ESClient:
    """Elasticsearch client wrapper for RAGFlow migration operations."""

    def __init__(
        self,
        host: str = "localhost",
        port: int = 9200,
        username: str | None = None,
        password: str | None = None,
        api_key: str | None = None,
        use_ssl: bool = False,
        verify_certs: bool = True,
    ):
        """
        Initialize ES client.

        Args:
            host: ES host address
            port: ES port
            username: Basic auth username
            password: Basic auth password
            api_key: API key for authentication
            use_ssl: Whether to use SSL
            verify_certs: Whether to verify SSL certificates
        """
        self.host = host
        self.port = port

        # Build connection URL
        scheme = "https" if use_ssl else "http"
        url = f"{scheme}://{host}:{port}"

        # Build connection arguments
        conn_args: dict[str, Any] = {
            "hosts": [url],
            "verify_certs": verify_certs,
        }

        if api_key:
            conn_args["api_key"] = api_key
        elif username and password:
            conn_args["basic_auth"] = (username, password)

        self.client = Elasticsearch(**conn_args)
        logger.info(f"Connected to Elasticsearch at {url}")

    def health_check(self) -> dict[str, Any]:
        """Check cluster health."""
        return self.client.cluster.health().body

    def get_cluster_info(self) -> dict[str, Any]:
        """Get cluster information."""
        return self.client.info().body

    def list_indices(self, pattern: str = "*") -> list[str]:
        """List all indices matching pattern."""
        response = self.client.indices.get(index=pattern)
        return list(response.keys())

    def list_ragflow_indices(self) -> list[str]:
        """
        List all RAGFlow-related indices.

        Returns indices matching patterns:
        - ragflow_* (document chunks)
        - ragflow_doc_meta_* (document metadata)

        Returns:
            List of RAGFlow index names
        """
        try:
            # Get all ragflow_* indices
            ragflow_indices = self.list_indices("ragflow_*")
            return sorted(ragflow_indices)
        except Exception:
            # If no indices match, return empty list
            return []

    def get_index_mapping(self, index_name: str) -> dict[str, Any]:
        """
        Get index mapping.

        Args:
            index_name: Name of the index

        Returns:
            Index mapping dictionary
        """
        response = self.client.indices.get_mapping(index=index_name)
        return response[index_name]["mappings"]

    def get_index_settings(self, index_name: str) -> dict[str, Any]:
        """Get index settings."""
        response = self.client.indices.get_settings(index=index_name)
        return response[index_name]["settings"]

    def count_documents(self, index_name: str) -> int:
        """Count documents in an index."""
        response = self.client.count(index=index_name)
        return response["count"]

    def count_documents_with_filter(
        self,
        index_name: str,
        filters: dict[str, Any]
    ) -> int:
        """
        Count documents with filter conditions.

        Args:
            index_name: Index name
            filters: Filter conditions (e.g., {"kb_id": "xxx"})

        Returns:
            Document count
        """
        # Build bool query with filters
        must_clauses = []
        for field, value in filters.items():
            if isinstance(value, list):
                must_clauses.append({"terms": {field: value}})
            else:
                must_clauses.append({"term": {field: value}})

        query = {
            "bool": {
                "must": must_clauses
            }
        } if must_clauses else {"match_all": {}}

        response = self.client.count(index=index_name, query=query)
        return response["count"]

    def aggregate_field(
        self,
        index_name: str,
        field: str,
        size: int = 10000,
    ) -> dict[str, Any]:
        """
        Aggregate field values (like getting all unique kb_ids).

        Args:
            index_name: Index name
            field: Field to aggregate
            size: Max number of buckets

        Returns:
            Aggregation result with buckets
        """
        response = self.client.search(
            index=index_name,
            size=0,
            aggs={
                "field_values": {
                    "terms": {
                        "field": field,
                        "size": size,
                    }
                }
            }
        )
        return response["aggregations"]["field_values"]

    def scroll_documents(
        self,
        index_name: str,
        batch_size: int = 1000,
        query: dict[str, Any] | None = None,
        sort_field: str = "_doc",
    ) -> Iterator[list[dict[str, Any]]]:
        """
        Scroll through all documents in an index using search_after (ES 8+).

        This is the recommended approach for ES 8+ instead of scroll API.
        Uses search_after for efficient deep pagination.

        Args:
            index_name: Name of the index
            batch_size: Number of documents per batch
            query: Optional query filter
            sort_field: Field to sort by (default: _doc for efficiency)

        Yields:
            Batches of documents
        """
        search_body: dict[str, Any] = {
            "size": batch_size,
            "sort": [{sort_field: "asc"}, {"_id": "asc"}],
        }

        if query:
            search_body["query"] = query
        else:
            search_body["query"] = {"match_all": {}}

        # Initial search
        response = self.client.search(index=index_name, body=search_body)
        hits = response["hits"]["hits"]

        while hits:
            # Extract documents with _id
            documents = []
            for hit in hits:
                doc = hit["_source"].copy()
                doc["_id"] = hit["_id"]
                if "_score" in hit:
                    doc["_score"] = hit["_score"]
                documents.append(doc)

            yield documents

            # Check if there are more results
            if len(hits) < batch_size:
                break

            # Get search_after value from last hit
            search_after = hits[-1]["sort"]
            search_body["search_after"] = search_after

            response = self.client.search(index=index_name, body=search_body)
            hits = response["hits"]["hits"]

    def get_document(self, index_name: str, doc_id: str) -> dict[str, Any] | None:
        """Get a single document by ID."""
        try:
            response = self.client.get(index=index_name, id=doc_id)
            doc = response["_source"].copy()
            doc["_id"] = response["_id"]
            return doc
        except Exception:
            return None

    def get_sample_documents(
        self,
        index_name: str,
        size: int = 10,
        query: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """
        Get sample documents from an index.

        Args:
            index_name: Index name
            size: Number of samples
            query: Optional query filter
        """
        search_body = {
            "query": query if query else {"match_all": {}},
            "size": size
        }

        response = self.client.search(index=index_name, body=search_body)
        documents = []
        for hit in response["hits"]["hits"]:
            doc = hit["_source"].copy()
            doc["_id"] = hit["_id"]
            documents.append(doc)
        return documents

    def get_document_ids(
        self,
        index_name: str,
        size: int = 1000,
        query: dict[str, Any] | None = None,
    ) -> list[str]:
        """Get list of document IDs."""
        search_body = {
            "query": query if query else {"match_all": {}},
            "size": size,
            "_source": False,
        }

        response = self.client.search(index=index_name, body=search_body)
        return [hit["_id"] for hit in response["hits"]["hits"]]

    def close(self):
        """Close the ES client connection."""
        self.client.close()
        logger.info("Elasticsearch connection closed")
```
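`scroll_documents` yields whole batches rather than single documents, so a caller simply drains the iterator. A minimal consumption sketch using only methods defined above; the connection parameters and index name are placeholders.

```python
# Sketch: draining the search_after-based batch iterator (names/hosts are placeholder assumptions).
from es_ob_migration.es_client import ESClient

client = ESClient(host="localhost", port=9200)

total = 0
for batch in client.scroll_documents("ragflow_example_index", batch_size=500):
    total += len(batch)  # each item is the document _source plus an "_id" key
print(f"read {total} documents")

client.close()
```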
370 tools/es-to-oceanbase-migration/src/es_ob_migration/migrator.py Normal file
@@ -0,0 +1,370 @@
```python
"""
RAGFlow-specific migration orchestrator from Elasticsearch to OceanBase.
"""

import logging
import time
from typing import Any, Callable

from rich.console import Console
from rich.progress import (
    Progress,
    SpinnerColumn,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
    TimeRemainingColumn,
)

from .es_client import ESClient
from .ob_client import OBClient
from .schema import RAGFlowSchemaConverter, RAGFlowDataConverter, VECTOR_FIELD_PATTERN
from .progress import ProgressManager, MigrationProgress
from .verify import MigrationVerifier

logger = logging.getLogger(__name__)
console = Console()


class ESToOceanBaseMigrator:
    """
    RAGFlow-specific migration orchestrator.

    This migrator is designed specifically for RAGFlow's data structure,
    handling the fixed schema and vector embeddings correctly.
    """

    def __init__(
        self,
        es_client: ESClient,
        ob_client: OBClient,
        progress_dir: str = ".migration_progress",
    ):
        """
        Initialize migrator.

        Args:
            es_client: Elasticsearch client
            ob_client: OceanBase client
            progress_dir: Directory for progress files
        """
        self.es_client = es_client
        self.ob_client = ob_client
        self.progress_manager = ProgressManager(progress_dir)
        self.schema_converter = RAGFlowSchemaConverter()

    def migrate(
        self,
        es_index: str,
        ob_table: str,
        batch_size: int = 1000,
        resume: bool = False,
        verify_after: bool = True,
        on_progress: Callable[[int, int], None] | None = None,
    ) -> dict[str, Any]:
        """
        Execute full migration from ES to OceanBase for RAGFlow data.

        Args:
            es_index: Source Elasticsearch index
            ob_table: Target OceanBase table
            batch_size: Documents per batch
            resume: Resume from previous progress
            verify_after: Run verification after migration
            on_progress: Progress callback (migrated, total)

        Returns:
            Migration result dictionary
        """
        start_time = time.time()
        result = {
            "success": False,
            "es_index": es_index,
            "ob_table": ob_table,
            "total_documents": 0,
            "migrated_documents": 0,
            "failed_documents": 0,
            "duration_seconds": 0,
            "verification": None,
            "error": None,
        }

        progress: MigrationProgress | None = None

        try:
            # Step 1: Check connections
            console.print("[bold blue]Step 1: Checking connections...[/]")
            self._check_connections()

            # Step 2: Analyze ES index
            console.print("\n[bold blue]Step 2: Analyzing ES index...[/]")
            analysis = self._analyze_es_index(es_index)

            # Auto-detect vector size from ES mapping
            vector_size = 768  # Default fallback
            if analysis["vector_fields"]:
                vector_size = analysis["vector_fields"][0]["dimension"]
                console.print(f" [green]Auto-detected vector dimension: {vector_size}[/]")
            else:
                console.print(f" [yellow]No vector fields found, using default: {vector_size}[/]")
            console.print(f" Known RAGFlow fields: {len(analysis['known_fields'])}")
            if analysis["unknown_fields"]:
                console.print(f" [yellow]Unknown fields (will be stored in 'extra'): {analysis['unknown_fields']}[/]")

            # Step 3: Get total document count
            total_docs = self.es_client.count_documents(es_index)
            console.print(f" Total documents: {total_docs:,}")

            result["total_documents"] = total_docs

            if total_docs == 0:
                console.print("[yellow]No documents to migrate[/]")
                result["success"] = True
                return result

            # Step 4: Handle resume or fresh start
            if resume and self.progress_manager.can_resume(es_index, ob_table):
                console.print("\n[bold yellow]Resuming from previous progress...[/]")
                progress = self.progress_manager.load_progress(es_index, ob_table)
                console.print(
                    f" Previously migrated: {progress.migrated_documents:,} documents"
                )
            else:
                # Fresh start - check if table already exists
                if self.ob_client.table_exists(ob_table):
                    raise RuntimeError(
                        f"Table '{ob_table}' already exists in OceanBase. "
                        f"Migration aborted to prevent data conflicts. "
                        f"Please drop the table manually or use a different table name."
                    )

                progress = self.progress_manager.create_progress(
                    es_index, ob_table, total_docs
                )

            # Step 5: Create table if needed
            if not progress.table_created:
                console.print("\n[bold blue]Step 3: Creating OceanBase table...[/]")
                if not self.ob_client.table_exists(ob_table):
                    self.ob_client.create_ragflow_table(
                        table_name=ob_table,
                        vector_size=vector_size,
                        create_indexes=True,
                        create_fts_indexes=True,
                    )
                    console.print(f" Created table '{ob_table}' with RAGFlow schema")
                else:
                    console.print(f" Table '{ob_table}' already exists")
                    # Check and add vector column if needed
                    self.ob_client.add_vector_column(ob_table, vector_size)

                progress.table_created = True
                progress.indexes_created = True
                progress.schema_converted = True
                self.progress_manager.save_progress(progress)

            # Step 6: Migrate data
            console.print("\n[bold blue]Step 4: Migrating data...[/]")
            data_converter = RAGFlowDataConverter()

            migrated = self._migrate_data(
                es_index=es_index,
                ob_table=ob_table,
                data_converter=data_converter,
                progress=progress,
                batch_size=batch_size,
                on_progress=on_progress,
            )

            result["migrated_documents"] = migrated
            result["failed_documents"] = progress.failed_documents

            # Step 7: Mark completed
            self.progress_manager.mark_completed(progress)

            # Step 8: Verify (optional)
            if verify_after:
                console.print("\n[bold blue]Step 5: Verifying migration...[/]")
                verifier = MigrationVerifier(self.es_client, self.ob_client)
                verification = verifier.verify(
                    es_index, ob_table,
                    primary_key="id"
                )
                result["verification"] = {
                    "passed": verification.passed,
                    "message": verification.message,
                    "es_count": verification.es_count,
                    "ob_count": verification.ob_count,
                    "sample_match_rate": verification.sample_match_rate,
                }
                console.print(verifier.generate_report(verification))

            result["success"] = True
            result["duration_seconds"] = time.time() - start_time

            console.print(
                f"\n[bold green]Migration completed successfully![/]"
                f"\n Total: {result['total_documents']:,} documents"
                f"\n Migrated: {result['migrated_documents']:,} documents"
                f"\n Failed: {result['failed_documents']:,} documents"
                f"\n Duration: {result['duration_seconds']:.1f} seconds"
            )

        except KeyboardInterrupt:
            console.print("\n[bold yellow]Migration interrupted by user[/]")
            if progress:
                self.progress_manager.mark_paused(progress)
            result["error"] = "Interrupted by user"

        except Exception as e:
            logger.exception("Migration failed")
            if progress:
                self.progress_manager.mark_failed(progress, str(e))
            result["error"] = str(e)
            console.print(f"\n[bold red]Migration failed: {e}[/]")

        return result

    def _check_connections(self):
        """Verify connections to both databases."""
        # Check ES
        es_health = self.es_client.health_check()
        if es_health.get("status") not in ("green", "yellow"):
            raise RuntimeError(f"ES cluster unhealthy: {es_health}")
        console.print(f" ES cluster status: {es_health.get('status')}")

        # Check OceanBase
        if not self.ob_client.health_check():
            raise RuntimeError("OceanBase connection failed")

        ob_version = self.ob_client.get_version()
        console.print(f" OceanBase connection: OK (version: {ob_version})")

    def _analyze_es_index(self, es_index: str) -> dict[str, Any]:
        """Analyze ES index structure for RAGFlow compatibility."""
        es_mapping = self.es_client.get_index_mapping(es_index)
        return self.schema_converter.analyze_es_mapping(es_mapping)

    def _migrate_data(
        self,
        es_index: str,
        ob_table: str,
        data_converter: RAGFlowDataConverter,
        progress: MigrationProgress,
        batch_size: int,
        on_progress: Callable[[int, int], None] | None,
    ) -> int:
        """Migrate data in batches."""
        total = progress.total_documents
        migrated = progress.migrated_documents

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TimeRemainingColumn(),
            console=console,
        ) as pbar:
            task = pbar.add_task(
                "Migrating...",
                total=total,
                completed=migrated,
            )

            batch_count = 0
            for batch in self.es_client.scroll_documents(es_index, batch_size):
                batch_count += 1

                # Convert batch to OceanBase format
                ob_rows = data_converter.convert_batch(batch)

                # Insert batch
                try:
                    inserted = self.ob_client.insert_batch(ob_table, ob_rows)
                    migrated += inserted

                    # Update progress
                    last_ids = [doc.get("_id", doc.get("id", "")) for doc in batch]
                    self.progress_manager.update_progress(
                        progress,
                        migrated_count=inserted,
                        last_batch_ids=last_ids,
                    )

                    # Update progress bar
                    pbar.update(task, completed=migrated)

                    # Callback
                    if on_progress:
                        on_progress(migrated, total)

                    # Log periodically
                    if batch_count % 10 == 0:
                        logger.info(f"Migrated {migrated:,}/{total:,} documents")

                except Exception as e:
                    logger.error(f"Batch insert failed: {e}")
                    progress.failed_documents += len(batch)
                    # Continue with next batch

        return migrated

    def get_schema_preview(self, es_index: str) -> dict[str, Any]:
        """
        Get a preview of schema analysis without executing migration.

        Args:
            es_index: Elasticsearch index name

        Returns:
            Schema analysis information
        """
        es_mapping = self.es_client.get_index_mapping(es_index)
        analysis = self.schema_converter.analyze_es_mapping(es_mapping)
        column_defs = self.schema_converter.get_column_definitions()

        return {
            "es_index": es_index,
            "es_mapping": es_mapping,
            "analysis": analysis,
            "ob_columns": column_defs,
            "vector_fields": self.schema_converter.get_vector_fields(),
            "total_columns": len(column_defs),
        }

    def get_data_preview(
        self,
        es_index: str,
        sample_size: int = 5,
        kb_id: str | None = None,
    ) -> list[dict[str, Any]]:
        """
        Get sample documents from ES for preview.

        Args:
            es_index: ES index name
            sample_size: Number of samples
            kb_id: Optional KB filter
        """
        query = None
        if kb_id:
            query = {"term": {"kb_id": kb_id}}
        return self.es_client.get_sample_documents(es_index, sample_size, query=query)

    def list_knowledge_bases(self, es_index: str) -> list[str]:
        """
        List all knowledge base IDs in an ES index.

        Args:
            es_index: ES index name

        Returns:
            List of kb_id values
        """
        try:
            agg_result = self.es_client.aggregate_field(es_index, "kb_id")
            return [bucket["key"] for bucket in agg_result.get("buckets", [])]
        except Exception as e:
            logger.warning(f"Failed to list knowledge bases: {e}")
            return []
```
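The `on_progress` callback in `migrate()` receives `(migrated, total)` after each successfully inserted batch, which is how external callers can surface progress without touching the rich progress bar. A minimal wiring sketch; hosts, credentials, and index/table names are placeholder assumptions, and the import path assumes the package is installed as `es_ob_migration`.

```python
# Sketch: supplying a progress callback to migrate() (all connection values are placeholders).
from es_ob_migration.es_client import ESClient
from es_ob_migration.ob_client import OBClient
from es_ob_migration.migrator import ESToOceanBaseMigrator


def report(migrated: int, total: int) -> None:
    # Called after each successfully inserted batch.
    print(f"{migrated}/{total} documents migrated")


migrator = ESToOceanBaseMigrator(
    es_client=ESClient(host="localhost", port=9200),
    ob_client=OBClient(host="localhost", port=2881, user="root@test", database="test"),
)
result = migrator.migrate(
    es_index="ragflow_example",
    ob_table="ragflow_example",
    batch_size=1000,
    resume=False,        # set True to continue from a saved progress file
    on_progress=report,
)
```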
442
tools/es-to-oceanbase-migration/src/es_ob_migration/ob_client.py
Normal file
442
tools/es-to-oceanbase-migration/src/es_ob_migration/ob_client.py
Normal file
@ -0,0 +1,442 @@
|
||||
"""
|
||||
OceanBase Client for RAGFlow data migration.
|
||||
|
||||
This client is specifically designed for RAGFlow's data structure.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pyobvector import ObVecClient, FtsIndexParam, FtsParser, VECTOR, ARRAY
|
||||
from sqlalchemy import Column, String, Integer, Float, JSON, Text, text, Double
|
||||
from sqlalchemy.dialects.mysql import LONGTEXT, TEXT as MYSQL_TEXT
|
||||
|
||||
from .schema import RAGFLOW_COLUMNS, ARRAY_COLUMNS, FTS_COLUMNS_TKS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Index naming templates (from RAGFlow ob_conn.py)
|
||||
INDEX_NAME_TEMPLATE = "ix_%s_%s"
|
||||
FULLTEXT_INDEX_NAME_TEMPLATE = "fts_idx_%s"
|
||||
VECTOR_INDEX_NAME_TEMPLATE = "%s_idx"
|
||||
|
||||
# Columns that need regular indexes
|
||||
INDEX_COLUMNS = [
|
||||
"kb_id",
|
||||
"doc_id",
|
||||
"available_int",
|
||||
"knowledge_graph_kwd",
|
||||
"entity_type_kwd",
|
||||
"removed_kwd",
|
||||
]
|
||||
|
||||
|
||||
class OBClient:
    """OceanBase client wrapper for RAGFlow migration operations."""

    def __init__(
        self,
        host: str = "localhost",
        port: int = 2881,
        user: str = "root",
        password: str = "",
        database: str = "test",
        pool_size: int = 10,
    ):
        """
        Initialize OceanBase client.

        Args:
            host: OceanBase host address
            port: OceanBase port
            user: Database user (format: user@tenant for OceanBase)
            password: Database password
            database: Database name
            pool_size: Connection pool size
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database

        # Initialize pyobvector client
        self.uri = f"{host}:{port}"
        self.client = ObVecClient(
            uri=self.uri,
            user=user,
            password=password,
            db_name=database,
            pool_pre_ping=True,
            pool_recycle=3600,
            pool_size=pool_size,
        )
        logger.info(f"Connected to OceanBase at {self.uri}, database: {database}")

    def health_check(self) -> bool:
        """Check database connectivity."""
        try:
            result = self.client.perform_raw_text_sql("SELECT 1 FROM DUAL")
            result.fetchone()
            return True
        except Exception as e:
            logger.error(f"OceanBase health check failed: {e}")
            return False

    def get_version(self) -> str | None:
        """Get OceanBase version."""
        try:
            result = self.client.perform_raw_text_sql("SELECT OB_VERSION() FROM DUAL")
            row = result.fetchone()
            return row[0] if row else None
        except Exception as e:
            logger.warning(f"Failed to get OceanBase version: {e}")
            return None

    def table_exists(self, table_name: str) -> bool:
        """Check if a table exists."""
        try:
            return self.client.check_table_exists(table_name)
        except Exception:
            return False

    def create_ragflow_table(
        self,
        table_name: str,
        vector_size: int = 768,
        create_indexes: bool = True,
        create_fts_indexes: bool = True,
    ):
        """
        Create a RAGFlow-compatible table in OceanBase.

        This creates a table with the exact schema that RAGFlow expects,
        including all columns, indexes, and vector columns.

        Args:
            table_name: Name of the table (usually the ES index name)
            vector_size: Vector dimension (e.g., 768, 1024, 1536)
            create_indexes: Whether to create regular indexes
            create_fts_indexes: Whether to create fulltext indexes
        """
        # Build column definitions
        columns = self._build_ragflow_columns()

        # Add vector column
        vector_column_name = f"q_{vector_size}_vec"
        columns.append(
            Column(vector_column_name, VECTOR(vector_size), nullable=True,
                   comment=f"vector embedding ({vector_size} dimensions)")
        )

        # Table options (from RAGFlow)
        table_options = {
            "mysql_charset": "utf8mb4",
            "mysql_collate": "utf8mb4_unicode_ci",
            "mysql_organization": "heap",
        }

        # Create table
        self.client.create_table(
            table_name=table_name,
            columns=columns,
            **table_options,
        )
        logger.info(f"Created table: {table_name}")

        # Create regular indexes
        if create_indexes:
            self._create_regular_indexes(table_name)

        # Create fulltext indexes
        if create_fts_indexes:
            self._create_fulltext_indexes(table_name)

        # Create vector index
        self._create_vector_index(table_name, vector_column_name)

        # Refresh metadata
        self.client.refresh_metadata([table_name])

    def _build_ragflow_columns(self) -> list[Column]:
        """Build SQLAlchemy Column objects for RAGFlow schema."""
        columns = []

        for col_name, col_def in RAGFLOW_COLUMNS.items():
            ob_type = col_def["ob_type"]
            nullable = col_def.get("nullable", True)
            default = col_def.get("default")
            is_primary = col_def.get("is_primary", False)
            is_array = col_def.get("is_array", False)

            # Parse type and create appropriate Column
            col = self._create_column(col_name, ob_type, nullable, default, is_primary, is_array)
            columns.append(col)

        return columns

    def _create_column(
        self,
        name: str,
        ob_type: str,
        nullable: bool,
        default: Any,
        is_primary: bool,
        is_array: bool,
    ) -> Column:
        """Create a SQLAlchemy Column object based on type string."""

        # Handle array types
        if is_array or ob_type.startswith("ARRAY"):
            # Extract inner type
            if "String" in ob_type:
                inner_type = String(256)
            elif "Integer" in ob_type:
                inner_type = Integer
            else:
                inner_type = String(256)

            # Nested array (e.g., ARRAY(ARRAY(Integer)))
            if ob_type.count("ARRAY") > 1:
                return Column(name, ARRAY(ARRAY(inner_type)), nullable=nullable)
            else:
                return Column(name, ARRAY(inner_type), nullable=nullable)

        # Handle String types with length
        if ob_type.startswith("String"):
            # Extract length: String(256) -> 256
            import re
            match = re.search(r'\((\d+)\)', ob_type)
            length = int(match.group(1)) if match else 256
            return Column(
                name, String(length),
                primary_key=is_primary,
                nullable=nullable,
                server_default=f"'{default}'" if default else None
            )

        # Map other types
        type_map = {
            "Integer": Integer,
            "Double": Double,
            "Float": Float,
            "JSON": JSON,
            "LONGTEXT": LONGTEXT,
            "TEXT": MYSQL_TEXT,
        }

        for type_name, type_class in type_map.items():
            if type_name in ob_type:
                return Column(
                    name, type_class,
                    primary_key=is_primary,
                    nullable=nullable,
                    server_default=str(default) if default is not None else None
                )

        # Default to String
        return Column(name, String(256), nullable=nullable)

    def _create_regular_indexes(self, table_name: str):
        """Create regular indexes for indexed columns."""
        for col_name in INDEX_COLUMNS:
            index_name = INDEX_NAME_TEMPLATE % (table_name, col_name)
            try:
                self.client.create_index(
                    table_name=table_name,
                    is_vec_index=False,
                    index_name=index_name,
                    column_names=[col_name],
                )
                logger.debug(f"Created index: {index_name}")
            except Exception as e:
                if "Duplicate" in str(e):
                    logger.debug(f"Index {index_name} already exists")
                else:
                    logger.warning(f"Failed to create index {index_name}: {e}")

    def _create_fulltext_indexes(self, table_name: str):
        """Create fulltext indexes for text columns."""
        for fts_column in FTS_COLUMNS_TKS:
            col_name = fts_column.split("^")[0]  # Remove weight suffix
            index_name = FULLTEXT_INDEX_NAME_TEMPLATE % col_name
            try:
                self.client.create_fts_idx_with_fts_index_param(
                    table_name=table_name,
                    fts_idx_param=FtsIndexParam(
                        index_name=index_name,
                        field_names=[col_name],
                        parser_type=FtsParser.IK,
                    ),
                )
                logger.debug(f"Created fulltext index: {index_name}")
            except Exception as e:
                if "Duplicate" in str(e):
                    logger.debug(f"Fulltext index {index_name} already exists")
                else:
                    logger.warning(f"Failed to create fulltext index {index_name}: {e}")

    def _create_vector_index(self, table_name: str, vector_column_name: str):
        """Create vector index for embedding column."""
        index_name = VECTOR_INDEX_NAME_TEMPLATE % vector_column_name
        try:
            self.client.create_index(
                table_name=table_name,
                is_vec_index=True,
                index_name=index_name,
                column_names=[vector_column_name],
                vidx_params="distance=cosine, type=hnsw, lib=vsag",
            )
            logger.info(f"Created vector index: {index_name}")
        except Exception as e:
            if "Duplicate" in str(e):
                logger.debug(f"Vector index {index_name} already exists")
            else:
                logger.warning(f"Failed to create vector index {index_name}: {e}")

    def add_vector_column(self, table_name: str, vector_size: int):
        """Add a vector column to an existing table."""
        vector_column_name = f"q_{vector_size}_vec"

        # Check if column exists
        if self._column_exists(table_name, vector_column_name):
            logger.info(f"Vector column {vector_column_name} already exists")
            return

        try:
            self.client.add_columns(
                table_name=table_name,
                columns=[Column(vector_column_name, VECTOR(vector_size), nullable=True)],
            )
            logger.info(f"Added vector column: {vector_column_name}")

            # Create index
            self._create_vector_index(table_name, vector_column_name)
        except Exception as e:
            logger.error(f"Failed to add vector column: {e}")
            raise

    def _column_exists(self, table_name: str, column_name: str) -> bool:
        """Check if a column exists in a table."""
        try:
            result = self.client.perform_raw_text_sql(
                f"SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS "
                f"WHERE TABLE_SCHEMA = '{self.database}' "
                f"AND TABLE_NAME = '{table_name}' "
                f"AND COLUMN_NAME = '{column_name}'"
            )
            count = result.fetchone()[0]
            return count > 0
        except Exception:
            return False

    def _index_exists(self, table_name: str, index_name: str) -> bool:
        """Check if an index exists."""
        try:
            result = self.client.perform_raw_text_sql(
                f"SELECT COUNT(*) FROM INFORMATION_SCHEMA.STATISTICS "
                f"WHERE TABLE_SCHEMA = '{self.database}' "
                f"AND TABLE_NAME = '{table_name}' "
                f"AND INDEX_NAME = '{index_name}'"
            )
            count = result.fetchone()[0]
            return count > 0
        except Exception:
            return False

    def insert_batch(
        self,
        table_name: str,
        documents: list[dict[str, Any]],
    ) -> int:
        """
        Insert a batch of documents using upsert.

        Args:
            table_name: Name of the table
            documents: List of documents to insert

        Returns:
            Number of documents inserted
        """
        if not documents:
            return 0

        try:
            self.client.upsert(table_name=table_name, data=documents)
            return len(documents)
        except Exception as e:
            logger.error(f"Batch insert failed: {e}")
            raise

    def count_rows(self, table_name: str, kb_id: str | None = None) -> int:
        """
        Count rows in a table.

        Args:
            table_name: Table name
            kb_id: Optional knowledge base ID filter
        """
        try:
            sql = f"SELECT COUNT(*) FROM `{table_name}`"
            if kb_id:
                sql += f" WHERE kb_id = '{kb_id}'"
            result = self.client.perform_raw_text_sql(sql)
            return result.fetchone()[0]
        except Exception:
            return 0

    def get_sample_rows(
        self,
        table_name: str,
        limit: int = 10,
        kb_id: str | None = None,
    ) -> list[dict[str, Any]]:
        """Get sample rows from a table."""
        try:
            sql = f"SELECT * FROM `{table_name}`"
            if kb_id:
                sql += f" WHERE kb_id = '{kb_id}'"
            sql += f" LIMIT {limit}"

            result = self.client.perform_raw_text_sql(sql)
            columns = result.keys()
            rows = []
            for row in result:
                rows.append(dict(zip(columns, row)))
            return rows
        except Exception as e:
            logger.error(f"Failed to get sample rows: {e}")
            return []

    def get_row_by_id(self, table_name: str, doc_id: str) -> dict[str, Any] | None:
        """Get a single row by ID."""
        try:
            result = self.client.get(table_name=table_name, ids=[doc_id])
            row = result.fetchone()
            if row:
                columns = result.keys()
                return dict(zip(columns, row))
            return None
        except Exception as e:
            logger.error(f"Failed to get row: {e}")
            return None

    def drop_table(self, table_name: str):
        """Drop a table if exists."""
        try:
            self.client.drop_table_if_exist(table_name)
            logger.info(f"Dropped table: {table_name}")
        except Exception as e:
            logger.warning(f"Failed to drop table: {e}")

    def execute_sql(self, sql: str) -> Any:
        """Execute raw SQL."""
        return self.client.perform_raw_text_sql(sql)

    def close(self):
        """Close the OB client connection."""
        self.client.engine.dispose()
        logger.info("OceanBase connection closed")
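For orientation, a minimal sketch of how this client might be driven by the migration pipeline. Host, credentials, and names below are placeholders, not project defaults, and the rows are assumed to come from `RAGFlowDataConverter` (defined in schema.py further down):

```python
from es_ob_migration.ob_client import OBClient

# Placeholders only -- connection details depend on your deployment.
ob = OBClient(host="127.0.0.1", port=2881, user="root@test", database="ragflow")

if not ob.health_check():
    raise SystemExit("OceanBase is not reachable")

table = "ragflow_tenant123"          # usually mirrors the source ES index name
if not ob.table_exists(table):
    ob.create_ragflow_table(table, vector_size=768)

rows: list[dict] = []                # output of RAGFlowDataConverter.convert_batch(...)
inserted = ob.insert_batch(table, rows)
print(inserted, ob.count_rows(table))
ob.close()
```
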
tools/es-to-oceanbase-migration/src/es_ob_migration/progress.py (new file, 221 lines)
@@ -0,0 +1,221 @@
"""
Progress tracking and resume capability for migration.
"""

import json
import logging
import os
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)


@dataclass
class MigrationProgress:
    """Migration progress state."""

    # Basic info
    es_index: str
    ob_table: str
    started_at: str = ""
    updated_at: str = ""

    # Progress counters
    total_documents: int = 0
    migrated_documents: int = 0
    failed_documents: int = 0

    # State for resume
    last_sort_values: list[Any] = field(default_factory=list)
    last_batch_ids: list[str] = field(default_factory=list)

    # Status
    status: str = "pending"  # pending, running, completed, failed, paused
    error_message: str = ""

    # Schema info
    schema_converted: bool = False
    table_created: bool = False
    indexes_created: bool = False

    def __post_init__(self):
        if not self.started_at:
            self.started_at = datetime.utcnow().isoformat()
        self.updated_at = datetime.utcnow().isoformat()


class ProgressManager:
    """Manage migration progress persistence."""

    def __init__(self, progress_dir: str = ".migration_progress"):
        """
        Initialize progress manager.

        Args:
            progress_dir: Directory to store progress files
        """
        self.progress_dir = Path(progress_dir)
        self.progress_dir.mkdir(parents=True, exist_ok=True)

    def _get_progress_file(self, es_index: str, ob_table: str) -> Path:
        """Get progress file path for a migration."""
        filename = f"{es_index}_to_{ob_table}.json"
        return self.progress_dir / filename

    def load_progress(
        self, es_index: str, ob_table: str
    ) -> MigrationProgress | None:
        """
        Load progress from file.

        Args:
            es_index: Elasticsearch index name
            ob_table: OceanBase table name

        Returns:
            MigrationProgress if exists, None otherwise
        """
        progress_file = self._get_progress_file(es_index, ob_table)

        if not progress_file.exists():
            return None

        try:
            with open(progress_file, "r") as f:
                data = json.load(f)
            progress = MigrationProgress(**data)
            logger.info(
                f"Loaded progress: {progress.migrated_documents}/{progress.total_documents} documents"
            )
            return progress
        except Exception as e:
            logger.warning(f"Failed to load progress: {e}")
            return None

    def save_progress(self, progress: MigrationProgress):
        """
        Save progress to file.

        Args:
            progress: MigrationProgress instance
        """
        progress.updated_at = datetime.utcnow().isoformat()
        progress_file = self._get_progress_file(progress.es_index, progress.ob_table)

        try:
            with open(progress_file, "w") as f:
                json.dump(asdict(progress), f, indent=2, default=str)
            logger.debug(f"Saved progress to {progress_file}")
        except Exception as e:
            logger.error(f"Failed to save progress: {e}")

    def delete_progress(self, es_index: str, ob_table: str):
        """Delete progress file."""
        progress_file = self._get_progress_file(es_index, ob_table)
        if progress_file.exists():
            progress_file.unlink()
            logger.info(f"Deleted progress file: {progress_file}")

    def create_progress(
        self,
        es_index: str,
        ob_table: str,
        total_documents: int,
    ) -> MigrationProgress:
        """
        Create new progress tracker.

        Args:
            es_index: Elasticsearch index name
            ob_table: OceanBase table name
            total_documents: Total documents to migrate

        Returns:
            New MigrationProgress instance
        """
        progress = MigrationProgress(
            es_index=es_index,
            ob_table=ob_table,
            total_documents=total_documents,
            status="running",
        )
        self.save_progress(progress)
        return progress

    def update_progress(
        self,
        progress: MigrationProgress,
        migrated_count: int,
        last_sort_values: list[Any] | None = None,
        last_batch_ids: list[str] | None = None,
    ):
        """
        Update progress after a batch.

        Args:
            progress: MigrationProgress instance
            migrated_count: Number of documents migrated in this batch
            last_sort_values: Sort values for search_after
            last_batch_ids: IDs of documents in last batch
        """
        progress.migrated_documents += migrated_count

        if last_sort_values:
            progress.last_sort_values = last_sort_values
        if last_batch_ids:
            progress.last_batch_ids = last_batch_ids

        self.save_progress(progress)

    def mark_completed(self, progress: MigrationProgress):
        """Mark migration as completed."""
        progress.status = "completed"
        progress.updated_at = datetime.utcnow().isoformat()
        self.save_progress(progress)
        logger.info(
            f"Migration completed: {progress.migrated_documents} documents"
        )

    def mark_failed(self, progress: MigrationProgress, error: str):
        """Mark migration as failed."""
        progress.status = "failed"
        progress.error_message = error
        progress.updated_at = datetime.utcnow().isoformat()
        self.save_progress(progress)
        logger.error(f"Migration failed: {error}")

    def mark_paused(self, progress: MigrationProgress):
        """Mark migration as paused (for resume later)."""
        progress.status = "paused"
        progress.updated_at = datetime.utcnow().isoformat()
        self.save_progress(progress)
        logger.info(
            f"Migration paused at {progress.migrated_documents}/{progress.total_documents}"
        )

    def can_resume(self, es_index: str, ob_table: str) -> bool:
        """Check if migration can be resumed."""
        progress = self.load_progress(es_index, ob_table)
        if not progress:
            return False
        return progress.status in ("running", "paused", "failed")

    def get_resume_info(self, es_index: str, ob_table: str) -> dict[str, Any] | None:
        """Get information needed to resume migration."""
        progress = self.load_progress(es_index, ob_table)
        if not progress:
            return None

        return {
            "migrated_documents": progress.migrated_documents,
            "total_documents": progress.total_documents,
            "last_sort_values": progress.last_sort_values,
            "last_batch_ids": progress.last_batch_ids,
            "schema_converted": progress.schema_converted,
            "table_created": progress.table_created,
            "indexes_created": progress.indexes_created,
            "status": progress.status,
        }
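A rough sketch of the checkpoint/resume flow these two classes support; index, table, counts, and sort values below are illustrative only:

```python
from es_ob_migration.progress import ProgressManager

pm = ProgressManager(progress_dir=".migration_progress")
es_index = ob_table = "ragflow_tenant123"        # illustrative names

if pm.can_resume(es_index, ob_table):
    info = pm.get_resume_info(es_index, ob_table)
    print(f"Resuming at {info['migrated_documents']}/{info['total_documents']}")
else:
    progress = pm.create_progress(es_index, ob_table, total_documents=10_000)
    # After each batch, persist the search_after cursor so a crash can resume here.
    pm.update_progress(progress, migrated_count=500, last_sort_values=[1699999999, "chunk-0500"])
    pm.mark_completed(progress)
```
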
tools/es-to-oceanbase-migration/src/es_ob_migration/schema.py (new file, 451 lines)
@@ -0,0 +1,451 @@
"""
RAGFlow-specific schema conversion from Elasticsearch to OceanBase.

This module handles the fixed RAGFlow table structure migration.
RAGFlow uses a predefined schema for both ES and OceanBase.
"""

import json
import logging
import re
from typing import Any

logger = logging.getLogger(__name__)


# RAGFlow fixed column definitions (from rag/utils/ob_conn.py)
# These are the actual columns used by RAGFlow
RAGFLOW_COLUMNS = {
    # Primary identifiers
    "id": {"ob_type": "String(256)", "nullable": False, "is_primary": True},
    "kb_id": {"ob_type": "String(256)", "nullable": False, "index": True},
    "doc_id": {"ob_type": "String(256)", "nullable": True, "index": True},

    # Document metadata
    "docnm_kwd": {"ob_type": "String(256)", "nullable": True},  # document name
    "doc_type_kwd": {"ob_type": "String(256)", "nullable": True},  # document type

    # Title fields
    "title_tks": {"ob_type": "String(256)", "nullable": True},  # title tokens
    "title_sm_tks": {"ob_type": "String(256)", "nullable": True},  # fine-grained title tokens

    # Content fields
    "content_with_weight": {"ob_type": "LONGTEXT", "nullable": True},  # original content
    "content_ltks": {"ob_type": "LONGTEXT", "nullable": True},  # long text tokens
    "content_sm_ltks": {"ob_type": "LONGTEXT", "nullable": True},  # fine-grained tokens

    # Feature fields
    "pagerank_fea": {"ob_type": "Integer", "nullable": True},  # page rank priority

    # Array fields
    "important_kwd": {"ob_type": "ARRAY(String(256))", "nullable": True, "is_array": True},  # keywords
    "important_tks": {"ob_type": "TEXT", "nullable": True},  # keyword tokens
    "question_kwd": {"ob_type": "ARRAY(String(1024))", "nullable": True, "is_array": True},  # questions
    "question_tks": {"ob_type": "TEXT", "nullable": True},  # question tokens
    "tag_kwd": {"ob_type": "ARRAY(String(256))", "nullable": True, "is_array": True},  # tags
    "tag_feas": {"ob_type": "JSON", "nullable": True, "is_json": True},  # tag features

    # Status fields
    "available_int": {"ob_type": "Integer", "nullable": False, "default": 1},

    # Time fields
    "create_time": {"ob_type": "String(19)", "nullable": True},
    "create_timestamp_flt": {"ob_type": "Double", "nullable": True},

    # Image field
    "img_id": {"ob_type": "String(128)", "nullable": True},

    # Position fields (arrays)
    "position_int": {"ob_type": "ARRAY(ARRAY(Integer))", "nullable": True, "is_array": True},
    "page_num_int": {"ob_type": "ARRAY(Integer)", "nullable": True, "is_array": True},
    "top_int": {"ob_type": "ARRAY(Integer)", "nullable": True, "is_array": True},

    # Knowledge graph fields
    "knowledge_graph_kwd": {"ob_type": "String(256)", "nullable": True, "index": True},
    "source_id": {"ob_type": "ARRAY(String(256))", "nullable": True, "is_array": True},
    "entity_kwd": {"ob_type": "String(256)", "nullable": True},
    "entity_type_kwd": {"ob_type": "String(256)", "nullable": True, "index": True},
    "from_entity_kwd": {"ob_type": "String(256)", "nullable": True},
    "to_entity_kwd": {"ob_type": "String(256)", "nullable": True},
    "weight_int": {"ob_type": "Integer", "nullable": True},
    "weight_flt": {"ob_type": "Double", "nullable": True},
    "entities_kwd": {"ob_type": "ARRAY(String(256))", "nullable": True, "is_array": True},
    "rank_flt": {"ob_type": "Double", "nullable": True},

    # Status
    "removed_kwd": {"ob_type": "String(256)", "nullable": True, "index": True, "default": "N"},

    # JSON fields
    "metadata": {"ob_type": "JSON", "nullable": True, "is_json": True},
    "extra": {"ob_type": "JSON", "nullable": True, "is_json": True},

    # New columns
    "_order_id": {"ob_type": "Integer", "nullable": True},
    "group_id": {"ob_type": "String(256)", "nullable": True},
    "mom_id": {"ob_type": "String(256)", "nullable": True},
}

# Array column names for special handling
ARRAY_COLUMNS = [
    "important_kwd", "question_kwd", "tag_kwd", "source_id",
    "entities_kwd", "position_int", "page_num_int", "top_int"
]

# JSON column names
JSON_COLUMNS = ["tag_feas", "metadata", "extra"]

# Fulltext search columns (for reference)
FTS_COLUMNS_ORIGIN = ["docnm_kwd", "content_with_weight", "important_tks", "question_tks"]
FTS_COLUMNS_TKS = ["title_tks", "title_sm_tks", "important_tks", "question_tks", "content_ltks", "content_sm_ltks"]

# Vector field pattern: q_{vector_size}_vec
VECTOR_FIELD_PATTERN = re.compile(r"q_(?P<vector_size>\d+)_vec")

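The `q_{dim}_vec` naming convention above is what drives vector-dimension auto-detection; a quick interpreter-style check of the pattern, for illustration:

```python
>>> m = VECTOR_FIELD_PATTERN.match("q_1024_vec")
>>> int(m.group("vector_size"))
1024
>>> VECTOR_FIELD_PATTERN.match("content_ltks") is None
True
```
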
class RAGFlowSchemaConverter:
    """
    Convert RAGFlow Elasticsearch documents to OceanBase format.

    RAGFlow uses a fixed schema, so this converter knows exactly
    what fields to expect and how to map them.
    """

    def __init__(self):
        self.vector_fields: list[dict[str, Any]] = []
        self.detected_vector_size: int | None = None

    def analyze_es_mapping(self, es_mapping: dict[str, Any]) -> dict[str, Any]:
        """
        Analyze ES mapping to extract vector field dimensions.

        Args:
            es_mapping: Elasticsearch index mapping

        Returns:
            Analysis result with detected fields
        """
        result = {
            "known_fields": [],
            "vector_fields": [],
            "unknown_fields": [],
        }

        properties = es_mapping.get("properties", {})

        for field_name, field_def in properties.items():
            # Check if it's a known RAGFlow field
            if field_name in RAGFLOW_COLUMNS:
                result["known_fields"].append(field_name)
            # Check if it's a vector field
            elif VECTOR_FIELD_PATTERN.match(field_name):
                match = VECTOR_FIELD_PATTERN.match(field_name)
                vec_size = int(match.group("vector_size"))
                result["vector_fields"].append({
                    "name": field_name,
                    "dimension": vec_size,
                })
                self.vector_fields.append({
                    "name": field_name,
                    "dimension": vec_size,
                })
                if self.detected_vector_size is None:
                    self.detected_vector_size = vec_size
            else:
                # Unknown field - might be custom field stored in 'extra'
                result["unknown_fields"].append(field_name)

        logger.info(
            f"Analyzed ES mapping: {len(result['known_fields'])} known fields, "
            f"{len(result['vector_fields'])} vector fields, "
            f"{len(result['unknown_fields'])} unknown fields"
        )

        return result

    def get_column_definitions(self) -> list[dict[str, Any]]:
        """
        Get RAGFlow column definitions for OceanBase table creation.

        Returns:
            List of column definitions
        """
        columns = []

        for col_name, col_def in RAGFLOW_COLUMNS.items():
            columns.append({
                "name": col_name,
                "ob_type": col_def["ob_type"],
                "nullable": col_def.get("nullable", True),
                "is_primary": col_def.get("is_primary", False),
                "index": col_def.get("index", False),
                "is_array": col_def.get("is_array", False),
                "is_json": col_def.get("is_json", False),
                "default": col_def.get("default"),
            })

        # Add detected vector fields
        for vec_field in self.vector_fields:
            columns.append({
                "name": vec_field["name"],
                "ob_type": f"VECTOR({vec_field['dimension']})",
                "nullable": True,
                "is_vector": True,
                "dimension": vec_field["dimension"],
            })

        return columns

    def get_vector_fields(self) -> list[dict[str, Any]]:
        """Get list of vector fields for index creation."""
        return self.vector_fields

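For example, feeding the converter a minimal mapping (field types trimmed for brevity; expected outputs shown as comments and derived from the code above):

```python
converter = RAGFlowSchemaConverter()
analysis = converter.analyze_es_mapping({
    "properties": {
        "kb_id": {"type": "keyword"},
        "content_ltks": {"type": "text"},
        "q_768_vec": {"type": "dense_vector", "dims": 768},
        "my_custom_field": {"type": "keyword"},
    }
})
# analysis["known_fields"]   -> ["kb_id", "content_ltks"]
# analysis["vector_fields"]  -> [{"name": "q_768_vec", "dimension": 768}]
# analysis["unknown_fields"] -> ["my_custom_field"]
# converter.detected_vector_size -> 768
```
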
class RAGFlowDataConverter:
    """
    Convert RAGFlow ES documents to OceanBase row format.

    This converter handles the specific data transformations needed
    for RAGFlow's data structure.
    """

    def __init__(self):
        """Initialize data converter."""
        self.vector_fields: set[str] = set()

    def detect_vector_fields(self, doc: dict[str, Any]) -> None:
        """Detect vector fields from a sample document."""
        for key in doc.keys():
            if VECTOR_FIELD_PATTERN.match(key):
                self.vector_fields.add(key)

    def convert_document(self, es_doc: dict[str, Any]) -> dict[str, Any]:
        """
        Convert an ES document to OceanBase row format.

        Args:
            es_doc: Elasticsearch document (with _id and _source)

        Returns:
            Dictionary ready for OceanBase insertion
        """
        # Extract _id and _source
        doc_id = es_doc.get("_id")
        source = es_doc.get("_source", es_doc)

        row = {}

        # Set document ID
        if doc_id:
            row["id"] = str(doc_id)
        elif "id" in source:
            row["id"] = str(source["id"])

        # Process each field
        for field_name, field_def in RAGFLOW_COLUMNS.items():
            if field_name == "id":
                continue  # Already handled

            value = source.get(field_name)

            if value is None:
                # Use default if available
                default = field_def.get("default")
                if default is not None:
                    row[field_name] = default
                continue

            # Convert based on field type
            row[field_name] = self._convert_field_value(
                field_name, value, field_def
            )

        # Handle vector fields
        for key, value in source.items():
            if VECTOR_FIELD_PATTERN.match(key):
                if isinstance(value, list):
                    row[key] = value
                    self.vector_fields.add(key)

        # Handle unknown fields -> store in 'extra'
        extra_fields = {}
        for key, value in source.items():
            if key not in RAGFLOW_COLUMNS and not VECTOR_FIELD_PATTERN.match(key):
                extra_fields[key] = value

        if extra_fields:
            existing_extra = row.get("extra")
            if existing_extra and isinstance(existing_extra, dict):
                existing_extra.update(extra_fields)
            else:
                row["extra"] = json.dumps(extra_fields, ensure_ascii=False)

        return row

    def _convert_field_value(
        self,
        field_name: str,
        value: Any,
        field_def: dict[str, Any]
    ) -> Any:
        """
        Convert a field value to the appropriate format for OceanBase.

        Args:
            field_name: Field name
            value: Original value from ES
            field_def: Field definition from RAGFLOW_COLUMNS

        Returns:
            Converted value
        """
        if value is None:
            return None

        ob_type = field_def.get("ob_type", "")
        is_array = field_def.get("is_array", False)
        is_json = field_def.get("is_json", False)

        # Handle array fields
        if is_array:
            return self._convert_array_value(value)

        # Handle JSON fields
        if is_json:
            return self._convert_json_value(value)

        # Handle specific types
        if "Integer" in ob_type:
            return self._convert_integer(value)

        if "Double" in ob_type or "Float" in ob_type:
            return self._convert_float(value)

        if "LONGTEXT" in ob_type or "TEXT" in ob_type:
            return self._convert_text(value)

        if "String" in ob_type:
            return self._convert_string(value, field_name)

        # Default: convert to string
        return str(value) if value is not None else None

    def _convert_array_value(self, value: Any) -> str | None:
        """Convert array value to JSON string for OceanBase."""
        if value is None:
            return None

        if isinstance(value, str):
            # Already a JSON string
            try:
                # Validate it's valid JSON
                json.loads(value)
                return value
            except json.JSONDecodeError:
                # Not valid JSON, wrap in array
                return json.dumps([value], ensure_ascii=False)

        if isinstance(value, list):
            # Clean array values
            cleaned = []
            for item in value:
                if isinstance(item, str):
                    # Clean special characters
                    cleaned_str = item.strip()
                    cleaned_str = cleaned_str.replace('\\', '\\\\')
                    cleaned_str = cleaned_str.replace('\n', '\\n')
                    cleaned_str = cleaned_str.replace('\r', '\\r')
                    cleaned_str = cleaned_str.replace('\t', '\\t')
                    cleaned.append(cleaned_str)
                else:
                    cleaned.append(item)
            return json.dumps(cleaned, ensure_ascii=False)

        # Single value - wrap in array
        return json.dumps([value], ensure_ascii=False)

    def _convert_json_value(self, value: Any) -> str | None:
        """Convert JSON value to string for OceanBase."""
        if value is None:
            return None

        if isinstance(value, str):
            # Already a string, validate JSON
            try:
                json.loads(value)
                return value
            except json.JSONDecodeError:
                # Not valid JSON, return as-is
                return value

        if isinstance(value, (dict, list)):
            return json.dumps(value, ensure_ascii=False)

        return str(value)

    def _convert_integer(self, value: Any) -> int | None:
        """Convert to integer."""
        if value is None:
            return None

        if isinstance(value, bool):
            return 1 if value else 0

        try:
            return int(value)
        except (ValueError, TypeError):
            return None

    def _convert_float(self, value: Any) -> float | None:
        """Convert to float."""
        if value is None:
            return None

        try:
            return float(value)
        except (ValueError, TypeError):
            return None

    def _convert_text(self, value: Any) -> str | None:
        """Convert to text/longtext."""
        if value is None:
            return None

        if isinstance(value, dict):
            # content_with_weight might be stored as dict
            return json.dumps(value, ensure_ascii=False)

        if isinstance(value, list):
            return json.dumps(value, ensure_ascii=False)

        return str(value)

    def _convert_string(self, value: Any, field_name: str) -> str | None:
        """Convert to string with length considerations."""
        if value is None:
            return None

        # Handle kb_id which might be a list in ES
        if field_name == "kb_id" and isinstance(value, list):
            return str(value[0]) if value else None

        if isinstance(value, (dict, list)):
            return json.dumps(value, ensure_ascii=False)

        return str(value)

    def convert_batch(self, es_docs: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """
        Convert a batch of ES documents.

        Args:
            es_docs: List of Elasticsearch documents

        Returns:
            List of dictionaries ready for OceanBase insertion
        """
        return [self.convert_document(doc) for doc in es_docs]


# Backwards compatibility aliases
SchemaConverter = RAGFlowSchemaConverter
DataConverter = RAGFlowDataConverter
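A minimal illustration of the document conversion; the sample hit below is made up (real chunks carry many more fields), and the comments describe the behaviour of the code above:

```python
converter = RAGFlowDataConverter()
row = converter.convert_document({
    "_id": "chunk-0001",
    "_source": {
        "kb_id": ["kb-42"],                  # list in ES, flattened to a single string
        "content_with_weight": "hello world",
        "important_kwd": ["alpha", "beta"],  # serialized to a JSON string for the ARRAY column
        "q_768_vec": [0.1] * 768,            # kept as a raw list for the vector column
        "my_custom_field": "x",              # unknown -> folded into row["extra"]
    },
})
# row["id"] == "chunk-0001"; row["available_int"] == 1 (default applied)
```
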
tools/es-to-oceanbase-migration/src/es_ob_migration/verify.py (new file, 349 lines)
@@ -0,0 +1,349 @@
"""
Data verification for RAGFlow migration.
"""

import json
import logging
from dataclasses import dataclass, field
from typing import Any

from .es_client import ESClient
from .ob_client import OBClient
from .schema import RAGFLOW_COLUMNS, ARRAY_COLUMNS, JSON_COLUMNS

logger = logging.getLogger(__name__)


@dataclass
class VerificationResult:
    """Migration verification result."""

    es_index: str
    ob_table: str

    # Counts
    es_count: int = 0
    ob_count: int = 0
    count_match: bool = False
    count_diff: int = 0

    # Sample verification
    sample_size: int = 0
    samples_verified: int = 0
    samples_matched: int = 0
    sample_match_rate: float = 0.0

    # Mismatches
    missing_in_ob: list[str] = field(default_factory=list)
    data_mismatches: list[dict[str, Any]] = field(default_factory=list)

    # Overall
    passed: bool = False
    message: str = ""


class MigrationVerifier:
    """Verify RAGFlow migration data consistency."""

    # Fields to compare for verification
    VERIFY_FIELDS = [
        "id", "kb_id", "doc_id", "docnm_kwd", "content_with_weight",
        "available_int", "create_time",
    ]

    def __init__(
        self,
        es_client: ESClient,
        ob_client: OBClient,
    ):
        """
        Initialize verifier.

        Args:
            es_client: Elasticsearch client
            ob_client: OceanBase client
        """
        self.es_client = es_client
        self.ob_client = ob_client

    def verify(
        self,
        es_index: str,
        ob_table: str,
        sample_size: int = 100,
        primary_key: str = "id",
        verify_fields: list[str] | None = None,
    ) -> VerificationResult:
        """
        Verify migration by comparing ES and OceanBase data.

        Args:
            es_index: Elasticsearch index name
            ob_table: OceanBase table name
            sample_size: Number of documents to sample for verification
            primary_key: Primary key column name
            verify_fields: Fields to verify (None = use defaults)

        Returns:
            VerificationResult with details
        """
        result = VerificationResult(
            es_index=es_index,
            ob_table=ob_table,
        )

        if verify_fields is None:
            verify_fields = self.VERIFY_FIELDS

        # Step 1: Verify document counts
        logger.info("Verifying document counts...")

        result.es_count = self.es_client.count_documents(es_index)
        result.ob_count = self.ob_client.count_rows(ob_table)

        result.count_diff = abs(result.es_count - result.ob_count)
        result.count_match = result.count_diff == 0

        logger.info(
            f"Document counts - ES: {result.es_count}, OB: {result.ob_count}, "
            f"Diff: {result.count_diff}"
        )

        # Step 2: Sample verification
        result.sample_size = min(sample_size, result.es_count)

        if result.sample_size > 0:
            logger.info(f"Verifying {result.sample_size} sample documents...")
            self._verify_samples(
                es_index, ob_table, result, primary_key, verify_fields
            )

        # Step 3: Determine overall result
        self._determine_result(result)

        logger.info(result.message)
        return result

    def _verify_samples(
        self,
        es_index: str,
        ob_table: str,
        result: VerificationResult,
        primary_key: str,
        verify_fields: list[str],
    ):
        """Verify sample documents."""
        # Get sample documents from ES
        es_samples = self.es_client.get_sample_documents(
            es_index, result.sample_size
        )

        for es_doc in es_samples:
            result.samples_verified += 1
            doc_id = es_doc.get("_id") or es_doc.get("id")

            if not doc_id:
                logger.warning("Document without ID found")
                continue

            # Get corresponding document from OceanBase
            ob_doc = self.ob_client.get_row_by_id(ob_table, doc_id)

            if ob_doc is None:
                result.missing_in_ob.append(doc_id)
                continue

            # Compare documents
            match, differences = self._compare_documents(
                es_doc, ob_doc, verify_fields
            )

            if match:
                result.samples_matched += 1
            else:
                result.data_mismatches.append({
                    "id": doc_id,
                    "differences": differences,
                })

        # Calculate match rate
        if result.samples_verified > 0:
            result.sample_match_rate = result.samples_matched / result.samples_verified

    def _compare_documents(
        self,
        es_doc: dict[str, Any],
        ob_doc: dict[str, Any],
        verify_fields: list[str],
    ) -> tuple[bool, list[dict[str, Any]]]:
        """
        Compare ES document with OceanBase row.

        Returns:
            Tuple of (match: bool, differences: list)
        """
        differences = []

        for field_name in verify_fields:
            es_value = es_doc.get(field_name)
            ob_value = ob_doc.get(field_name)

            # Skip if both are None/null
            if es_value is None and ob_value is None:
                continue

            # Handle special comparisons
            if not self._values_equal(field_name, es_value, ob_value):
                differences.append({
                    "field": field_name,
                    "es_value": es_value,
                    "ob_value": ob_value,
                })

        return len(differences) == 0, differences

    def _values_equal(
        self,
        field_name: str,
        es_value: Any,
        ob_value: Any
    ) -> bool:
        """Compare two values with type-aware logic."""
        if es_value is None and ob_value is None:
            return True

        if es_value is None or ob_value is None:
            # One is None, the other isn't
            # For optional fields, this might be acceptable
            return False

        # Handle array fields (stored as JSON strings in OB)
        if field_name in ARRAY_COLUMNS:
            if isinstance(ob_value, str):
                try:
                    ob_value = json.loads(ob_value)
                except json.JSONDecodeError:
                    pass
            if isinstance(es_value, list) and isinstance(ob_value, list):
                return set(str(x) for x in es_value) == set(str(x) for x in ob_value)

        # Handle JSON fields
        if field_name in JSON_COLUMNS:
            if isinstance(ob_value, str):
                try:
                    ob_value = json.loads(ob_value)
                except json.JSONDecodeError:
                    pass
            if isinstance(es_value, str):
                try:
                    es_value = json.loads(es_value)
                except json.JSONDecodeError:
                    pass
            return es_value == ob_value

        # Handle content_with_weight which might be dict or string
        if field_name == "content_with_weight":
            if isinstance(ob_value, str) and isinstance(es_value, dict):
                try:
                    ob_value = json.loads(ob_value)
                except json.JSONDecodeError:
                    pass

        # Handle kb_id which might be list in ES
        if field_name == "kb_id":
            if isinstance(es_value, list) and len(es_value) > 0:
                es_value = es_value[0]

        # Standard comparison
        return str(es_value) == str(ob_value)

    def _determine_result(self, result: VerificationResult):
        """Determine overall verification result."""
        # Allow small count differences (e.g., documents added during migration)
        count_tolerance = 0.01  # 1% tolerance
        count_ok = (
            result.count_match or
            (result.es_count > 0 and result.count_diff / result.es_count <= count_tolerance)
        )

        if count_ok and result.sample_match_rate >= 0.99:
            result.passed = True
            result.message = (
                f"Verification PASSED. "
                f"ES: {result.es_count:,}, OB: {result.ob_count:,}. "
                f"Sample match rate: {result.sample_match_rate:.2%}"
            )
        elif count_ok and result.sample_match_rate >= 0.95:
            result.passed = True
            result.message = (
                f"Verification PASSED with warnings. "
                f"ES: {result.es_count:,}, OB: {result.ob_count:,}. "
                f"Sample match rate: {result.sample_match_rate:.2%}"
            )
        else:
            result.passed = False
            issues = []
            if not count_ok:
                issues.append(
                    f"Count mismatch (ES: {result.es_count}, OB: {result.ob_count}, diff: {result.count_diff})"
                )
            if result.sample_match_rate < 0.95:
                issues.append(f"Low sample match rate: {result.sample_match_rate:.2%}")
            if result.missing_in_ob:
                issues.append(f"{len(result.missing_in_ob)} documents missing in OB")
            result.message = f"Verification FAILED: {'; '.join(issues)}"

    def generate_report(self, result: VerificationResult) -> str:
        """Generate a verification report."""
        lines = [
            "",
            "=" * 60,
            "Migration Verification Report",
            "=" * 60,
            f"ES Index: {result.es_index}",
            f"OB Table: {result.ob_table}",
        ]

        lines.extend([
            "",
            "Document Counts:",
            f"  Elasticsearch: {result.es_count:,}",
            f"  OceanBase: {result.ob_count:,}",
            f"  Difference: {result.count_diff:,}",
            f"  Match: {'Yes' if result.count_match else 'No'}",
            "",
            "Sample Verification:",
            f"  Sample Size: {result.sample_size}",
            f"  Verified: {result.samples_verified}",
            f"  Matched: {result.samples_matched}",
            f"  Match Rate: {result.sample_match_rate:.2%}",
            "",
        ])

        if result.missing_in_ob:
            lines.append(f"Missing in OceanBase ({len(result.missing_in_ob)}):")
            for doc_id in result.missing_in_ob[:5]:
                lines.append(f"  - {doc_id}")
            if len(result.missing_in_ob) > 5:
                lines.append(f"  ... and {len(result.missing_in_ob) - 5} more")
            lines.append("")

        if result.data_mismatches:
            lines.append(f"Data Mismatches ({len(result.data_mismatches)}):")
            for mismatch in result.data_mismatches[:3]:
                lines.append(f"  - ID: {mismatch['id']}")
                for diff in mismatch.get("differences", [])[:2]:
                    lines.append(f"    {diff['field']}: ES={diff['es_value']}, OB={diff['ob_value']}")
            if len(result.data_mismatches) > 3:
                lines.append(f"  ... and {len(result.data_mismatches) - 3} more")
            lines.append("")

        lines.extend([
            "=" * 60,
            f"Result: {'PASSED' if result.passed else 'FAILED'}",
            result.message,
            "=" * 60,
            "",
        ])

        return "\n".join(lines)
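To close the loop, a hedged sketch of running the verifier after a migration. The `ESClient` constructor arguments are omitted because that class is defined elsewhere in this package and its signature is not shown here; connection details for OceanBase are placeholders:

```python
from es_ob_migration.es_client import ESClient   # constructor arguments assumed/omitted
from es_ob_migration.ob_client import OBClient
from es_ob_migration.verify import MigrationVerifier

es = ESClient()                                   # connection details depend on your deployment
ob = OBClient(host="127.0.0.1", database="ragflow")

verifier = MigrationVerifier(es_client=es, ob_client=ob)
result = verifier.verify("ragflow_tenant123", "ragflow_tenant123", sample_size=100)
print(verifier.generate_report(result))
ob.close()
```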