feat(tools): add Elasticsearch to OceanBase migration tool (#12927)

### What problem does this PR solve?

fixes https://github.com/infiniflow/ragflow/issues/12774

Add a CLI tool for migrating RAGFlow data from Elasticsearch to
OceanBase, enabling users to switch their document storage backend.

- Automatic discovery and migration of all `ragflow_*` indices
- Schema conversion with vector dimension auto-detection
- Batch processing with progress tracking and resume capability
- Data consistency validation and migration report generation
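
A minimal programmatic sketch of the new pieces (hedged: the top-level import path `es_to_oceanbase_migration` is assumed from the tool directory name, and the index/table names are placeholders; connection values mirror the CLI defaults):

```python
# Sketch only: the package name and index names below are assumptions, not shipped values.
from es_to_oceanbase_migration import ESClient, OBClient, ESToOceanBaseMigrator

es = ESClient(host="localhost", port=9200)                    # ES 8+ source
ob = OBClient(host="localhost", port=2881, user="root@test",  # user@tenant format
              password="", database="test")

migrator = ESToOceanBaseMigrator(es_client=es, ob_client=ob)

# Migrate one ragflow_* index into a table of the same name and verify afterwards.
result = migrator.migrate(es_index="ragflow_example", ob_table="ragflow_example",
                          batch_size=1000, resume=False, verify_after=True)
print(f"{result['migrated_documents']}/{result['total_documents']} documents migrated")

es.close()
ob.close()
```

The `migrate` CLI command wraps the same flow; when `--index` is omitted it auto-discovers every `ragflow_*` index.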

**Note**: Due to network issues, I was unable to pull the Docker images (Elasticsearch, OceanBase) required for full end-to-end verification. All 86 unit tests pass (output below); I will complete the e2e verification once network conditions allow and submit a follow-up PR if any fixes are needed.

```bash
============================= test session starts ==============================
platform darwin -- Python 3.13.6, pytest-9.0.2, pluggy-1.6.0
rootdir: /Users/sevenc/code/ai/oceanbase/ragflow/tools/es-to-oceanbase-migration
configfile: pyproject.toml
testpaths: tests
plugins: anyio-4.12.1, asyncio-1.3.0, cov-7.0.0
collected 86 items

tests/test_progress.py::TestMigrationProgress::test_create_basic_progress PASSED [  1%]
tests/test_progress.py::TestMigrationProgress::test_create_progress_with_counts PASSED [  2%]
tests/test_progress.py::TestMigrationProgress::test_progress_default_values PASSED [  3%]
tests/test_progress.py::TestMigrationProgress::test_progress_status_values PASSED [  4%]
tests/test_progress.py::TestProgressManager::test_create_progress_manager PASSED [  5%]
tests/test_progress.py::TestProgressManager::test_create_progress_manager_creates_dir PASSED [  6%]
tests/test_progress.py::TestProgressManager::test_create_progress PASSED [  8%]
tests/test_progress.py::TestProgressManager::test_save_and_load_progress PASSED [  9%]
tests/test_progress.py::TestProgressManager::test_load_nonexistent_progress PASSED [ 10%]
tests/test_progress.py::TestProgressManager::test_delete_progress PASSED [ 11%]
tests/test_progress.py::TestProgressManager::test_update_progress PASSED [ 12%]
tests/test_progress.py::TestProgressManager::test_update_progress_multiple_batches PASSED [ 13%]
tests/test_progress.py::TestProgressManager::test_mark_completed PASSED  [ 15%]
tests/test_progress.py::TestProgressManager::test_mark_failed PASSED     [ 16%]
tests/test_progress.py::TestProgressManager::test_mark_paused PASSED     [ 17%]
tests/test_progress.py::TestProgressManager::test_can_resume_running PASSED [ 18%]
tests/test_progress.py::TestProgressManager::test_can_resume_paused PASSED [ 19%]
tests/test_progress.py::TestProgressManager::test_can_resume_completed PASSED [ 20%]
tests/test_progress.py::TestProgressManager::test_can_resume_nonexistent PASSED [ 22%]
tests/test_progress.py::TestProgressManager::test_get_resume_info PASSED [ 23%]
tests/test_progress.py::TestProgressManager::test_get_resume_info_nonexistent PASSED [ 24%]
tests/test_progress.py::TestProgressManager::test_progress_file_path PASSED [ 25%]
tests/test_progress.py::TestProgressManager::test_progress_file_content PASSED [ 26%]
tests/test_schema.py::TestRAGFlowSchemaConverter::test_analyze_ragflow_mapping PASSED [ 27%]
tests/test_schema.py::TestRAGFlowSchemaConverter::test_detect_vector_size PASSED [ 29%]
tests/test_schema.py::TestRAGFlowSchemaConverter::test_unknown_fields PASSED [ 30%]
tests/test_schema.py::TestRAGFlowSchemaConverter::test_get_column_definitions PASSED [ 31%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_basic_document PASSED [ 32%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_with_vector PASSED [ 33%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_array_fields PASSED [ 34%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_json_fields PASSED [ 36%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_unknown_fields_to_extra PASSED [ 37%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_kb_id_list PASSED [ 38%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_content_with_weight_dict PASSED [ 39%]
tests/test_schema.py::TestRAGFlowDataConverter::test_convert_batch PASSED [ 40%]
tests/test_schema.py::TestVectorFieldPattern::test_valid_patterns PASSED [ 41%]
tests/test_schema.py::TestVectorFieldPattern::test_invalid_patterns PASSED [ 43%]
tests/test_schema.py::TestVectorFieldPattern::test_extract_dimension PASSED [ 44%]
tests/test_schema.py::TestConstants::test_array_columns PASSED           [ 45%]
tests/test_schema.py::TestConstants::test_json_columns PASSED            [ 46%]
tests/test_schema.py::TestConstants::test_ragflow_columns_completeness PASSED [ 47%]
tests/test_schema.py::TestConstants::test_fts_columns PASSED             [ 48%]
tests/test_schema.py::TestConstants::test_ragflow_columns_types PASSED   [ 50%]
tests/test_schema.py::TestRAGFlowSchemaConverterEdgeCases::test_empty_mapping PASSED [ 51%]
tests/test_schema.py::TestRAGFlowSchemaConverterEdgeCases::test_mapping_without_properties PASSED [ 52%]
tests/test_schema.py::TestRAGFlowSchemaConverterEdgeCases::test_multiple_vector_fields PASSED [ 53%]
tests/test_schema.py::TestRAGFlowSchemaConverterEdgeCases::test_get_column_definitions_without_analysis PASSED [ 54%]
tests/test_schema.py::TestRAGFlowSchemaConverterEdgeCases::test_get_vector_fields PASSED [ 55%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_empty_document PASSED [ 56%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_document_without_source PASSED [ 58%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_boolean_to_integer PASSED [ 59%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_invalid_integer PASSED [ 60%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_float_field PASSED [ 61%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_array_with_special_characters PASSED [ 62%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_already_json_array PASSED [ 63%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_single_value_to_array PASSED [ 65%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_detect_vector_fields_from_document PASSED [ 66%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_with_default_values PASSED [ 67%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_list_content PASSED [ 68%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_convert_batch_empty PASSED [ 69%]
tests/test_schema.py::TestRAGFlowDataConverterEdgeCases::test_existing_extra_field_merged PASSED [ 70%]
tests/test_verify.py::TestVerificationResult::test_create_basic_result PASSED [ 72%]
tests/test_verify.py::TestVerificationResult::test_result_default_values PASSED [ 73%]
tests/test_verify.py::TestVerificationResult::test_result_with_counts PASSED [ 74%]
tests/test_verify.py::TestMigrationVerifier::test_verify_counts_match PASSED [ 75%]
tests/test_verify.py::TestMigrationVerifier::test_verify_counts_mismatch PASSED [ 76%]
tests/test_verify.py::TestMigrationVerifier::test_verify_samples_all_match PASSED [ 77%]
tests/test_verify.py::TestMigrationVerifier::test_verify_samples_some_missing PASSED [ 79%]
tests/test_verify.py::TestMigrationVerifier::test_verify_samples_data_mismatch PASSED [ 80%]
tests/test_verify.py::TestMigrationVerifier::test_values_equal_none_values PASSED [ 81%]
tests/test_verify.py::TestMigrationVerifier::test_values_equal_array_columns PASSED [ 82%]
tests/test_verify.py::TestMigrationVerifier::test_values_equal_json_columns PASSED [ 83%]
tests/test_verify.py::TestMigrationVerifier::test_values_equal_kb_id_list PASSED [ 84%]
tests/test_verify.py::TestMigrationVerifier::test_values_equal_content_with_weight_dict PASSED [ 86%]
tests/test_verify.py::TestMigrationVerifier::test_determine_result_passed PASSED [ 87%]
tests/test_verify.py::TestMigrationVerifier::test_determine_result_failed_count PASSED [ 88%]
tests/test_verify.py::TestMigrationVerifier::test_determine_result_failed_samples PASSED [ 89%]
tests/test_verify.py::TestMigrationVerifier::test_generate_report PASSED [ 90%]
tests/test_verify.py::TestMigrationVerifier::test_generate_report_with_missing PASSED [ 91%]
tests/test_verify.py::TestMigrationVerifier::test_generate_report_with_mismatches PASSED [ 93%]
tests/test_verify.py::TestValueComparison::test_string_comparison PASSED [ 94%]
tests/test_verify.py::TestValueComparison::test_integer_comparison PASSED [ 95%]
tests/test_verify.py::TestValueComparison::test_float_comparison PASSED  [ 96%]
tests/test_verify.py::TestValueComparison::test_boolean_comparison PASSED [ 97%]
tests/test_verify.py::TestValueComparison::test_empty_array_comparison PASSED [ 98%]
tests/test_verify.py::TestValueComparison::test_nested_json_comparison PASSED [100%]

======================= 86 passed, 88 warnings in 0.66s ========================
```

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):

Commit 332b11cf96 (parent c4c3f744c0), authored by Se7en and committed by GitHub on 2026-01-31 16:11:27 +08:00.
15 changed files with 5606 additions and 0 deletions.

@@ -0,0 +1,41 @@
"""
RAGFlow ES to OceanBase Migration Tool
A CLI tool for migrating RAGFlow data from Elasticsearch 8+ to OceanBase,
supporting schema conversion, vector data mapping, batch import, and resume capability.
This tool is specifically designed for RAGFlow's data structure.
"""
__version__ = "0.1.0"
from .migrator import ESToOceanBaseMigrator
from .es_client import ESClient
from .ob_client import OBClient
from .schema import RAGFlowSchemaConverter, RAGFlowDataConverter, RAGFLOW_COLUMNS
from .verify import MigrationVerifier, VerificationResult
from .progress import ProgressManager, MigrationProgress
# Backwards compatibility aliases
SchemaConverter = RAGFlowSchemaConverter
DataConverter = RAGFlowDataConverter
__all__ = [
# Main classes
"ESToOceanBaseMigrator",
"ESClient",
"OBClient",
# Schema
"RAGFlowSchemaConverter",
"RAGFlowDataConverter",
"RAGFLOW_COLUMNS",
# Verification
"MigrationVerifier",
"VerificationResult",
# Progress
"ProgressManager",
"MigrationProgress",
# Aliases
"SchemaConverter",
"DataConverter",
]

@@ -0,0 +1,574 @@
"""
CLI entry point for RAGFlow ES to OceanBase migration tool.
"""
import json
import logging
import sys
import click
from rich.console import Console
from rich.table import Table
from rich.logging import RichHandler
from .es_client import ESClient
from .ob_client import OBClient
from .migrator import ESToOceanBaseMigrator
from .verify import MigrationVerifier
from .schema import RAGFLOW_COLUMNS
console = Console()
def setup_logging(verbose: bool = False):
"""Setup logging configuration."""
level = logging.DEBUG if verbose else logging.INFO
logging.basicConfig(
level=level,
format="%(message)s",
datefmt="[%X]",
handlers=[RichHandler(rich_tracebacks=True, console=console)],
)
@click.group()
@click.option("-v", "--verbose", is_flag=True, help="Enable verbose logging")
@click.pass_context
def main(ctx, verbose):
"""RAGFlow ES to OceanBase Migration Tool.
Migrate RAGFlow data from Elasticsearch 8+ to OceanBase with schema conversion,
vector data mapping, batch import, and resume capability.
This tool is specifically designed for RAGFlow's data structure.
"""
ctx.ensure_object(dict)
ctx.obj["verbose"] = verbose
setup_logging(verbose)
@main.command()
@click.option("--es-host", default="localhost", help="Elasticsearch host")
@click.option("--es-port", default=9200, type=int, help="Elasticsearch port")
@click.option("--es-user", default=None, help="Elasticsearch username")
@click.option("--es-password", default=None, help="Elasticsearch password")
@click.option("--es-api-key", default=None, help="Elasticsearch API key")
@click.option("--ob-host", default="localhost", help="OceanBase host")
@click.option("--ob-port", default=2881, type=int, help="OceanBase port")
@click.option("--ob-user", default="root@test", help="OceanBase user (format: user@tenant)")
@click.option("--ob-password", default="", help="OceanBase password")
@click.option("--ob-database", default="test", help="OceanBase database")
@click.option("--index", "-i", default=None, help="Source ES index name (omit to migrate all ragflow_* indices)")
@click.option("--table", "-t", default=None, help="Target OceanBase table name (omit to use same name as index)")
@click.option("--batch-size", default=1000, type=int, help="Batch size for migration")
@click.option("--resume", is_flag=True, help="Resume from previous progress")
@click.option("--verify/--no-verify", default=True, help="Verify after migration")
@click.option("--progress-dir", default=".migration_progress", help="Progress file directory")
@click.pass_context
def migrate(
ctx,
es_host,
es_port,
es_user,
es_password,
es_api_key,
ob_host,
ob_port,
ob_user,
ob_password,
ob_database,
index,
table,
batch_size,
resume,
verify,
progress_dir,
):
"""Run RAGFlow data migration from Elasticsearch to OceanBase.
If --index is omitted, all indices starting with 'ragflow_' will be migrated.
If --table is omitted, the same name as the source index will be used.
"""
console.print("[bold]RAGFlow ES to OceanBase Migration[/]")
try:
# Initialize ES client first to discover indices if needed
es_client = ESClient(
host=es_host,
port=es_port,
username=es_user,
password=es_password,
api_key=es_api_key,
)
ob_client = OBClient(
host=ob_host,
port=ob_port,
user=ob_user,
password=ob_password,
database=ob_database,
)
# Determine indices to migrate
if index:
# Single index specified
indices_to_migrate = [(index, table if table else index)]
else:
# Auto-discover all ragflow_* indices
console.print(f"\n[cyan]Discovering RAGFlow indices...[/]")
ragflow_indices = es_client.list_ragflow_indices()
if not ragflow_indices:
console.print("[yellow]No ragflow_* indices found in Elasticsearch[/]")
sys.exit(0)
# Each index maps to a table with the same name
indices_to_migrate = [(idx, idx) for idx in ragflow_indices]
console.print(f"[green]Found {len(indices_to_migrate)} RAGFlow indices:[/]")
for idx, _ in indices_to_migrate:
doc_count = es_client.count_documents(idx)
console.print(f" - {idx} ({doc_count:,} documents)")
console.print()
# Initialize migrator
migrator = ESToOceanBaseMigrator(
es_client=es_client,
ob_client=ob_client,
progress_dir=progress_dir,
)
# Track overall results
total_success = 0
total_failed = 0
results = []
# Migrate each index
for es_index, ob_table in indices_to_migrate:
console.print(f"\n[bold blue]{'='*60}[/]")
console.print(f"[bold]Migrating: {es_index} -> {ob_database}.{ob_table}[/]")
console.print(f"[bold blue]{'='*60}[/]")
result = migrator.migrate(
es_index=es_index,
ob_table=ob_table,
batch_size=batch_size,
resume=resume,
verify_after=verify,
)
results.append(result)
if result["success"]:
total_success += 1
else:
total_failed += 1
# Summary for multiple indices
if len(indices_to_migrate) > 1:
console.print(f"\n[bold]{'='*60}[/]")
console.print(f"[bold]Migration Summary[/]")
console.print(f"[bold]{'='*60}[/]")
console.print(f" Total indices: {len(indices_to_migrate)}")
console.print(f" [green]Successful: {total_success}[/]")
if total_failed > 0:
console.print(f" [red]Failed: {total_failed}[/]")
# Exit code based on results
if total_failed == 0:
console.print("\n[bold green]All migrations completed successfully![/]")
sys.exit(0)
else:
console.print(f"\n[bold red]{total_failed} migration(s) failed[/]")
sys.exit(1)
except Exception as e:
console.print(f"[bold red]Error: {e}[/]")
if ctx.obj.get("verbose"):
console.print_exception()
sys.exit(1)
finally:
# Cleanup
if "es_client" in locals():
es_client.close()
if "ob_client" in locals():
ob_client.close()
@main.command()
@click.option("--es-host", default="localhost", help="Elasticsearch host")
@click.option("--es-port", default=9200, type=int, help="Elasticsearch port")
@click.option("--es-user", default=None, help="Elasticsearch username")
@click.option("--es-password", default=None, help="Elasticsearch password")
@click.option("--index", "-i", required=True, help="ES index name")
@click.option("--output", "-o", default=None, help="Output file (JSON)")
@click.pass_context
def schema(ctx, es_host, es_port, es_user, es_password, index, output):
"""Preview RAGFlow schema analysis from ES mapping."""
try:
es_client = ESClient(
host=es_host,
port=es_port,
username=es_user,
password=es_password,
)
        # Schema preview only needs the schema converter; no OceanBase connection is required
from .schema import RAGFlowSchemaConverter
converter = RAGFlowSchemaConverter()
es_mapping = es_client.get_index_mapping(index)
analysis = converter.analyze_es_mapping(es_mapping)
column_defs = converter.get_column_definitions()
# Display analysis
console.print(f"\n[bold]ES Index Analysis: {index}[/]\n")
# Known RAGFlow fields
console.print(f"[green]Known RAGFlow fields:[/] {len(analysis['known_fields'])}")
# Vector fields
if analysis['vector_fields']:
console.print(f"\n[cyan]Vector fields detected:[/]")
for vf in analysis['vector_fields']:
console.print(f" - {vf['name']} (dimension: {vf['dimension']})")
# Unknown fields
if analysis['unknown_fields']:
console.print(f"\n[yellow]Unknown fields (will be stored in 'extra'):[/]")
for uf in analysis['unknown_fields']:
console.print(f" - {uf}")
# Display RAGFlow column schema
console.print(f"\n[bold]RAGFlow OceanBase Schema ({len(column_defs)} columns):[/]\n")
table = Table(title="Column Definitions")
table.add_column("Column Name", style="cyan")
table.add_column("OB Type", style="green")
table.add_column("Nullable", style="yellow")
table.add_column("Special", style="magenta")
for col in column_defs[:20]: # Show first 20
special = []
if col.get("is_primary"):
special.append("PK")
if col.get("index"):
special.append("IDX")
if col.get("is_array"):
special.append("ARRAY")
if col.get("is_vector"):
special.append("VECTOR")
table.add_row(
col["name"],
col["ob_type"],
"Yes" if col.get("nullable", True) else "No",
", ".join(special) if special else "-",
)
if len(column_defs) > 20:
table.add_row("...", f"({len(column_defs) - 20} more)", "", "")
console.print(table)
# Save to file if requested
if output:
preview = {
"es_index": index,
"es_mapping": es_mapping,
"analysis": analysis,
"ob_columns": column_defs,
}
with open(output, "w") as f:
json.dump(preview, f, indent=2, default=str)
console.print(f"\nSchema saved to {output}")
except Exception as e:
console.print(f"[bold red]Error: {e}[/]")
if ctx.obj.get("verbose"):
console.print_exception()
sys.exit(1)
finally:
if "es_client" in locals():
es_client.close()
@main.command()
@click.option("--es-host", default="localhost", help="Elasticsearch host")
@click.option("--es-port", default=9200, type=int, help="Elasticsearch port")
@click.option("--ob-host", default="localhost", help="OceanBase host")
@click.option("--ob-port", default=2881, type=int, help="OceanBase port")
@click.option("--ob-user", default="root@test", help="OceanBase user")
@click.option("--ob-password", default="", help="OceanBase password")
@click.option("--ob-database", default="test", help="OceanBase database")
@click.option("--index", "-i", required=True, help="Source ES index name")
@click.option("--table", "-t", required=True, help="Target OceanBase table name")
@click.option("--sample-size", default=100, type=int, help="Sample size for verification")
@click.pass_context
def verify(
ctx,
es_host,
es_port,
ob_host,
ob_port,
ob_user,
ob_password,
ob_database,
index,
table,
sample_size,
):
"""Verify migration data consistency."""
try:
es_client = ESClient(host=es_host, port=es_port)
ob_client = OBClient(
host=ob_host,
port=ob_port,
user=ob_user,
password=ob_password,
database=ob_database,
)
verifier = MigrationVerifier(es_client, ob_client)
result = verifier.verify(
index, table,
sample_size=sample_size,
)
console.print(verifier.generate_report(result))
sys.exit(0 if result.passed else 1)
except Exception as e:
console.print(f"[bold red]Error: {e}[/]")
if ctx.obj.get("verbose"):
console.print_exception()
sys.exit(1)
finally:
if "es_client" in locals():
es_client.close()
if "ob_client" in locals():
ob_client.close()
@main.command("list-indices")
@click.option("--es-host", default="localhost", help="Elasticsearch host")
@click.option("--es-port", default=9200, type=int, help="Elasticsearch port")
@click.option("--es-user", default=None, help="Elasticsearch username")
@click.option("--es-password", default=None, help="Elasticsearch password")
@click.pass_context
def list_indices(ctx, es_host, es_port, es_user, es_password):
"""List all RAGFlow indices (ragflow_*) in Elasticsearch."""
try:
es_client = ESClient(
host=es_host,
port=es_port,
username=es_user,
password=es_password,
)
console.print(f"\n[bold]RAGFlow Indices in Elasticsearch ({es_host}:{es_port})[/]\n")
indices = es_client.list_ragflow_indices()
if not indices:
console.print("[yellow]No ragflow_* indices found[/]")
return
table = Table(title="RAGFlow Indices")
table.add_column("Index Name", style="cyan")
table.add_column("Document Count", style="green", justify="right")
table.add_column("Type", style="yellow")
total_docs = 0
for idx in indices:
doc_count = es_client.count_documents(idx)
total_docs += doc_count
# Determine index type
if idx.startswith("ragflow_doc_meta_"):
idx_type = "Metadata"
elif idx.startswith("ragflow_"):
idx_type = "Document Chunks"
else:
idx_type = "Unknown"
table.add_row(idx, f"{doc_count:,}", idx_type)
table.add_row("", "", "")
table.add_row("[bold]Total[/]", f"[bold]{total_docs:,}[/]", f"[bold]{len(indices)} indices[/]")
console.print(table)
except Exception as e:
console.print(f"[bold red]Error: {e}[/]")
if ctx.obj.get("verbose"):
console.print_exception()
sys.exit(1)
finally:
if "es_client" in locals():
es_client.close()
@main.command("list-kb")
@click.option("--es-host", default="localhost", help="Elasticsearch host")
@click.option("--es-port", default=9200, type=int, help="Elasticsearch port")
@click.option("--es-user", default=None, help="Elasticsearch username")
@click.option("--es-password", default=None, help="Elasticsearch password")
@click.option("--index", "-i", required=True, help="ES index name")
@click.pass_context
def list_kb(ctx, es_host, es_port, es_user, es_password, index):
"""List all knowledge bases in an ES index."""
try:
es_client = ESClient(
host=es_host,
port=es_port,
username=es_user,
password=es_password,
)
console.print(f"\n[bold]Knowledge Bases in index: {index}[/]\n")
# Get kb_id aggregation
agg_result = es_client.aggregate_field(index, "kb_id")
buckets = agg_result.get("buckets", [])
if not buckets:
console.print("[yellow]No knowledge bases found[/]")
return
table = Table(title="Knowledge Bases")
table.add_column("KB ID", style="cyan")
table.add_column("Document Count", style="green", justify="right")
total_docs = 0
for bucket in buckets:
table.add_row(
bucket["key"],
f"{bucket['doc_count']:,}",
)
total_docs += bucket["doc_count"]
table.add_row("", "")
table.add_row("[bold]Total[/]", f"[bold]{total_docs:,}[/]")
console.print(table)
console.print(f"\nTotal knowledge bases: {len(buckets)}")
except Exception as e:
console.print(f"[bold red]Error: {e}[/]")
if ctx.obj.get("verbose"):
console.print_exception()
sys.exit(1)
finally:
if "es_client" in locals():
es_client.close()
@main.command()
@click.option("--es-host", default="localhost", help="Elasticsearch host")
@click.option("--es-port", default=9200, type=int, help="Elasticsearch port")
@click.option("--ob-host", default="localhost", help="OceanBase host")
@click.option("--ob-port", default=2881, type=int, help="OceanBase port")
@click.option("--ob-user", default="root@test", help="OceanBase user")
@click.option("--ob-password", default="", help="OceanBase password")
@click.pass_context
def status(ctx, es_host, es_port, ob_host, ob_port, ob_user, ob_password):
"""Check connection status to ES and OceanBase."""
console.print("[bold]Connection Status[/]\n")
# Check ES
try:
es_client = ESClient(host=es_host, port=es_port)
health = es_client.health_check()
info = es_client.get_cluster_info()
console.print(f"[green]Elasticsearch ({es_host}:{es_port}): Connected[/]")
console.print(f" Cluster: {health.get('cluster_name')}")
console.print(f" Status: {health.get('status')}")
console.print(f" Version: {info.get('version', {}).get('number', 'unknown')}")
# List indices
indices = es_client.list_indices("*")
console.print(f" Indices: {len(indices)}")
es_client.close()
except Exception as e:
console.print(f"[red]Elasticsearch ({es_host}:{es_port}): Failed[/]")
console.print(f" Error: {e}")
console.print()
# Check OceanBase
try:
ob_client = OBClient(
host=ob_host,
port=ob_port,
user=ob_user,
password=ob_password,
)
if ob_client.health_check():
version = ob_client.get_version()
console.print(f"[green]OceanBase ({ob_host}:{ob_port}): Connected[/]")
console.print(f" Version: {version}")
else:
console.print(f"[red]OceanBase ({ob_host}:{ob_port}): Health check failed[/]")
ob_client.close()
except Exception as e:
console.print(f"[red]OceanBase ({ob_host}:{ob_port}): Failed[/]")
console.print(f" Error: {e}")
@main.command()
@click.option("--es-host", default="localhost", help="Elasticsearch host")
@click.option("--es-port", default=9200, type=int, help="Elasticsearch port")
@click.option("--index", "-i", required=True, help="ES index name")
@click.option("--size", "-n", default=5, type=int, help="Number of samples")
@click.pass_context
def sample(ctx, es_host, es_port, index, size):
"""Show sample documents from ES index."""
try:
es_client = ESClient(host=es_host, port=es_port)
docs = es_client.get_sample_documents(index, size)
console.print(f"\n[bold]Sample documents from {index}[/]")
console.print()
for i, doc in enumerate(docs, 1):
console.print(f"[bold cyan]Document {i}[/]")
console.print(f" _id: {doc.get('_id')}")
console.print(f" kb_id: {doc.get('kb_id')}")
console.print(f" doc_id: {doc.get('doc_id')}")
console.print(f" docnm_kwd: {doc.get('docnm_kwd')}")
# Check for vector fields
vector_fields = [k for k in doc.keys() if k.startswith("q_") and k.endswith("_vec")]
if vector_fields:
for vf in vector_fields:
vec = doc.get(vf)
if vec:
console.print(f" {vf}: [{len(vec)} dimensions]")
content = doc.get("content_with_weight", "")
if content:
if isinstance(content, dict):
content = json.dumps(content, ensure_ascii=False)
preview = content[:100] + "..." if len(str(content)) > 100 else content
console.print(f" content: {preview}")
console.print()
es_client.close()
except Exception as e:
console.print(f"[bold red]Error: {e}[/]")
if ctx.obj.get("verbose"):
console.print_exception()
sys.exit(1)
if __name__ == "__main__":
main()
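
The command wiring above can be smoke-tested offline with Click's test runner (a sketch; the import path of `main` depends on the actual module layout and is assumed here):

```python
# Hypothetical import path; `main` is the click group defined in the CLI module above.
from click.testing import CliRunner
from es_to_oceanbase_migration.__main__ import main

runner = CliRunner()

# --help never touches Elasticsearch or OceanBase, so this runs without any services.
result = runner.invoke(main, ["migrate", "--help"])
assert result.exit_code == 0
print(result.output)
```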

@@ -0,0 +1,292 @@
"""
Elasticsearch 8+ Client for RAGFlow data migration.
"""
import logging
from typing import Any, Iterator
from elasticsearch import Elasticsearch
logger = logging.getLogger(__name__)
class ESClient:
"""Elasticsearch client wrapper for RAGFlow migration operations."""
def __init__(
self,
host: str = "localhost",
port: int = 9200,
username: str | None = None,
password: str | None = None,
api_key: str | None = None,
use_ssl: bool = False,
verify_certs: bool = True,
):
"""
Initialize ES client.
Args:
host: ES host address
port: ES port
username: Basic auth username
password: Basic auth password
api_key: API key for authentication
use_ssl: Whether to use SSL
verify_certs: Whether to verify SSL certificates
"""
self.host = host
self.port = port
# Build connection URL
scheme = "https" if use_ssl else "http"
url = f"{scheme}://{host}:{port}"
# Build connection arguments
conn_args: dict[str, Any] = {
"hosts": [url],
"verify_certs": verify_certs,
}
if api_key:
conn_args["api_key"] = api_key
elif username and password:
conn_args["basic_auth"] = (username, password)
self.client = Elasticsearch(**conn_args)
logger.info(f"Connected to Elasticsearch at {url}")
def health_check(self) -> dict[str, Any]:
"""Check cluster health."""
return self.client.cluster.health().body
def get_cluster_info(self) -> dict[str, Any]:
"""Get cluster information."""
return self.client.info().body
def list_indices(self, pattern: str = "*") -> list[str]:
"""List all indices matching pattern."""
response = self.client.indices.get(index=pattern)
return list(response.keys())
def list_ragflow_indices(self) -> list[str]:
"""
List all RAGFlow-related indices.
Returns indices matching patterns:
- ragflow_* (document chunks)
- ragflow_doc_meta_* (document metadata)
Returns:
List of RAGFlow index names
"""
try:
# Get all ragflow_* indices
ragflow_indices = self.list_indices("ragflow_*")
return sorted(ragflow_indices)
except Exception:
# If no indices match, return empty list
return []
def get_index_mapping(self, index_name: str) -> dict[str, Any]:
"""
Get index mapping.
Args:
index_name: Name of the index
Returns:
Index mapping dictionary
"""
response = self.client.indices.get_mapping(index=index_name)
return response[index_name]["mappings"]
def get_index_settings(self, index_name: str) -> dict[str, Any]:
"""Get index settings."""
response = self.client.indices.get_settings(index=index_name)
return response[index_name]["settings"]
def count_documents(self, index_name: str) -> int:
"""Count documents in an index."""
response = self.client.count(index=index_name)
return response["count"]
def count_documents_with_filter(
self,
index_name: str,
filters: dict[str, Any]
) -> int:
"""
Count documents with filter conditions.
Args:
index_name: Index name
filters: Filter conditions (e.g., {"kb_id": "xxx"})
Returns:
Document count
"""
# Build bool query with filters
must_clauses = []
for field, value in filters.items():
if isinstance(value, list):
must_clauses.append({"terms": {field: value}})
else:
must_clauses.append({"term": {field: value}})
query = {
"bool": {
"must": must_clauses
}
} if must_clauses else {"match_all": {}}
response = self.client.count(index=index_name, query=query)
return response["count"]
def aggregate_field(
self,
index_name: str,
field: str,
size: int = 10000,
) -> dict[str, Any]:
"""
Aggregate field values (like getting all unique kb_ids).
Args:
index_name: Index name
field: Field to aggregate
size: Max number of buckets
Returns:
Aggregation result with buckets
"""
response = self.client.search(
index=index_name,
size=0,
aggs={
"field_values": {
"terms": {
"field": field,
"size": size,
}
}
}
)
return response["aggregations"]["field_values"]
def scroll_documents(
self,
index_name: str,
batch_size: int = 1000,
query: dict[str, Any] | None = None,
sort_field: str = "_doc",
) -> Iterator[list[dict[str, Any]]]:
"""
Scroll through all documents in an index using search_after (ES 8+).
This is the recommended approach for ES 8+ instead of scroll API.
Uses search_after for efficient deep pagination.
Args:
index_name: Name of the index
batch_size: Number of documents per batch
query: Optional query filter
sort_field: Field to sort by (default: _doc for efficiency)
Yields:
Batches of documents
"""
search_body: dict[str, Any] = {
"size": batch_size,
"sort": [{sort_field: "asc"}, {"_id": "asc"}],
}
if query:
search_body["query"] = query
else:
search_body["query"] = {"match_all": {}}
# Initial search
response = self.client.search(index=index_name, body=search_body)
hits = response["hits"]["hits"]
while hits:
# Extract documents with _id
documents = []
for hit in hits:
doc = hit["_source"].copy()
doc["_id"] = hit["_id"]
if "_score" in hit:
doc["_score"] = hit["_score"]
documents.append(doc)
yield documents
# Check if there are more results
if len(hits) < batch_size:
break
# Get search_after value from last hit
search_after = hits[-1]["sort"]
search_body["search_after"] = search_after
response = self.client.search(index=index_name, body=search_body)
hits = response["hits"]["hits"]
def get_document(self, index_name: str, doc_id: str) -> dict[str, Any] | None:
"""Get a single document by ID."""
try:
response = self.client.get(index=index_name, id=doc_id)
doc = response["_source"].copy()
doc["_id"] = response["_id"]
return doc
except Exception:
return None
def get_sample_documents(
self,
index_name: str,
size: int = 10,
query: dict[str, Any] | None = None,
) -> list[dict[str, Any]]:
"""
Get sample documents from an index.
Args:
index_name: Index name
size: Number of samples
query: Optional query filter
"""
search_body = {
"query": query if query else {"match_all": {}},
"size": size
}
response = self.client.search(index=index_name, body=search_body)
documents = []
for hit in response["hits"]["hits"]:
doc = hit["_source"].copy()
doc["_id"] = hit["_id"]
documents.append(doc)
return documents
def get_document_ids(
self,
index_name: str,
size: int = 1000,
query: dict[str, Any] | None = None,
) -> list[str]:
"""Get list of document IDs."""
search_body = {
"query": query if query else {"match_all": {}},
"size": size,
"_source": False,
}
response = self.client.search(index=index_name, body=search_body)
return [hit["_id"] for hit in response["hits"]["hits"]]
def close(self):
"""Close the ES client connection."""
self.client.close()
logger.info("Elasticsearch connection closed")

@@ -0,0 +1,370 @@
"""
RAGFlow-specific migration orchestrator from Elasticsearch to OceanBase.
"""
import logging
import time
from typing import Any, Callable
from rich.console import Console
from rich.progress import (
Progress,
SpinnerColumn,
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
)
from .es_client import ESClient
from .ob_client import OBClient
from .schema import RAGFlowSchemaConverter, RAGFlowDataConverter, VECTOR_FIELD_PATTERN
from .progress import ProgressManager, MigrationProgress
from .verify import MigrationVerifier
logger = logging.getLogger(__name__)
console = Console()
class ESToOceanBaseMigrator:
"""
RAGFlow-specific migration orchestrator.
This migrator is designed specifically for RAGFlow's data structure,
handling the fixed schema and vector embeddings correctly.
"""
def __init__(
self,
es_client: ESClient,
ob_client: OBClient,
progress_dir: str = ".migration_progress",
):
"""
Initialize migrator.
Args:
es_client: Elasticsearch client
ob_client: OceanBase client
progress_dir: Directory for progress files
"""
self.es_client = es_client
self.ob_client = ob_client
self.progress_manager = ProgressManager(progress_dir)
self.schema_converter = RAGFlowSchemaConverter()
def migrate(
self,
es_index: str,
ob_table: str,
batch_size: int = 1000,
resume: bool = False,
verify_after: bool = True,
on_progress: Callable[[int, int], None] | None = None,
) -> dict[str, Any]:
"""
Execute full migration from ES to OceanBase for RAGFlow data.
Args:
es_index: Source Elasticsearch index
ob_table: Target OceanBase table
batch_size: Documents per batch
resume: Resume from previous progress
verify_after: Run verification after migration
on_progress: Progress callback (migrated, total)
Returns:
Migration result dictionary
"""
start_time = time.time()
result = {
"success": False,
"es_index": es_index,
"ob_table": ob_table,
"total_documents": 0,
"migrated_documents": 0,
"failed_documents": 0,
"duration_seconds": 0,
"verification": None,
"error": None,
}
progress: MigrationProgress | None = None
try:
# Step 1: Check connections
console.print("[bold blue]Step 1: Checking connections...[/]")
self._check_connections()
# Step 2: Analyze ES index
console.print("\n[bold blue]Step 2: Analyzing ES index...[/]")
analysis = self._analyze_es_index(es_index)
# Auto-detect vector size from ES mapping
vector_size = 768 # Default fallback
if analysis["vector_fields"]:
vector_size = analysis["vector_fields"][0]["dimension"]
console.print(f" [green]Auto-detected vector dimension: {vector_size}[/]")
else:
console.print(f" [yellow]No vector fields found, using default: {vector_size}[/]")
console.print(f" Known RAGFlow fields: {len(analysis['known_fields'])}")
if analysis["unknown_fields"]:
console.print(f" [yellow]Unknown fields (will be stored in 'extra'): {analysis['unknown_fields']}[/]")
# Step 3: Get total document count
total_docs = self.es_client.count_documents(es_index)
console.print(f" Total documents: {total_docs:,}")
result["total_documents"] = total_docs
if total_docs == 0:
console.print("[yellow]No documents to migrate[/]")
result["success"] = True
return result
# Step 4: Handle resume or fresh start
if resume and self.progress_manager.can_resume(es_index, ob_table):
console.print("\n[bold yellow]Resuming from previous progress...[/]")
progress = self.progress_manager.load_progress(es_index, ob_table)
console.print(
f" Previously migrated: {progress.migrated_documents:,} documents"
)
else:
# Fresh start - check if table already exists
if self.ob_client.table_exists(ob_table):
raise RuntimeError(
f"Table '{ob_table}' already exists in OceanBase. "
f"Migration aborted to prevent data conflicts. "
f"Please drop the table manually or use a different table name."
)
progress = self.progress_manager.create_progress(
es_index, ob_table, total_docs
)
# Step 5: Create table if needed
if not progress.table_created:
console.print("\n[bold blue]Step 3: Creating OceanBase table...[/]")
if not self.ob_client.table_exists(ob_table):
self.ob_client.create_ragflow_table(
table_name=ob_table,
vector_size=vector_size,
create_indexes=True,
create_fts_indexes=True,
)
console.print(f" Created table '{ob_table}' with RAGFlow schema")
else:
console.print(f" Table '{ob_table}' already exists")
# Check and add vector column if needed
self.ob_client.add_vector_column(ob_table, vector_size)
progress.table_created = True
progress.indexes_created = True
progress.schema_converted = True
self.progress_manager.save_progress(progress)
# Step 6: Migrate data
console.print("\n[bold blue]Step 4: Migrating data...[/]")
data_converter = RAGFlowDataConverter()
migrated = self._migrate_data(
es_index=es_index,
ob_table=ob_table,
data_converter=data_converter,
progress=progress,
batch_size=batch_size,
on_progress=on_progress,
)
result["migrated_documents"] = migrated
result["failed_documents"] = progress.failed_documents
# Step 7: Mark completed
self.progress_manager.mark_completed(progress)
# Step 8: Verify (optional)
if verify_after:
console.print("\n[bold blue]Step 5: Verifying migration...[/]")
verifier = MigrationVerifier(self.es_client, self.ob_client)
verification = verifier.verify(
es_index, ob_table,
primary_key="id"
)
result["verification"] = {
"passed": verification.passed,
"message": verification.message,
"es_count": verification.es_count,
"ob_count": verification.ob_count,
"sample_match_rate": verification.sample_match_rate,
}
console.print(verifier.generate_report(verification))
result["success"] = True
result["duration_seconds"] = time.time() - start_time
console.print(
f"\n[bold green]Migration completed successfully![/]"
f"\n Total: {result['total_documents']:,} documents"
f"\n Migrated: {result['migrated_documents']:,} documents"
f"\n Failed: {result['failed_documents']:,} documents"
f"\n Duration: {result['duration_seconds']:.1f} seconds"
)
except KeyboardInterrupt:
console.print("\n[bold yellow]Migration interrupted by user[/]")
if progress:
self.progress_manager.mark_paused(progress)
result["error"] = "Interrupted by user"
except Exception as e:
logger.exception("Migration failed")
if progress:
self.progress_manager.mark_failed(progress, str(e))
result["error"] = str(e)
console.print(f"\n[bold red]Migration failed: {e}[/]")
return result
def _check_connections(self):
"""Verify connections to both databases."""
# Check ES
es_health = self.es_client.health_check()
if es_health.get("status") not in ("green", "yellow"):
raise RuntimeError(f"ES cluster unhealthy: {es_health}")
console.print(f" ES cluster status: {es_health.get('status')}")
# Check OceanBase
if not self.ob_client.health_check():
raise RuntimeError("OceanBase connection failed")
ob_version = self.ob_client.get_version()
console.print(f" OceanBase connection: OK (version: {ob_version})")
def _analyze_es_index(self, es_index: str) -> dict[str, Any]:
"""Analyze ES index structure for RAGFlow compatibility."""
es_mapping = self.es_client.get_index_mapping(es_index)
return self.schema_converter.analyze_es_mapping(es_mapping)
def _migrate_data(
self,
es_index: str,
ob_table: str,
data_converter: RAGFlowDataConverter,
progress: MigrationProgress,
batch_size: int,
on_progress: Callable[[int, int], None] | None,
) -> int:
"""Migrate data in batches."""
total = progress.total_documents
migrated = progress.migrated_documents
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
TimeRemainingColumn(),
console=console,
) as pbar:
task = pbar.add_task(
"Migrating...",
total=total,
completed=migrated,
)
batch_count = 0
for batch in self.es_client.scroll_documents(es_index, batch_size):
batch_count += 1
# Convert batch to OceanBase format
ob_rows = data_converter.convert_batch(batch)
# Insert batch
try:
inserted = self.ob_client.insert_batch(ob_table, ob_rows)
migrated += inserted
# Update progress
last_ids = [doc.get("_id", doc.get("id", "")) for doc in batch]
self.progress_manager.update_progress(
progress,
migrated_count=inserted,
last_batch_ids=last_ids,
)
# Update progress bar
pbar.update(task, completed=migrated)
# Callback
if on_progress:
on_progress(migrated, total)
# Log periodically
if batch_count % 10 == 0:
logger.info(f"Migrated {migrated:,}/{total:,} documents")
except Exception as e:
logger.error(f"Batch insert failed: {e}")
progress.failed_documents += len(batch)
# Continue with next batch
return migrated
def get_schema_preview(self, es_index: str) -> dict[str, Any]:
"""
Get a preview of schema analysis without executing migration.
Args:
es_index: Elasticsearch index name
Returns:
Schema analysis information
"""
es_mapping = self.es_client.get_index_mapping(es_index)
analysis = self.schema_converter.analyze_es_mapping(es_mapping)
column_defs = self.schema_converter.get_column_definitions()
return {
"es_index": es_index,
"es_mapping": es_mapping,
"analysis": analysis,
"ob_columns": column_defs,
"vector_fields": self.schema_converter.get_vector_fields(),
"total_columns": len(column_defs),
}
def get_data_preview(
self,
es_index: str,
sample_size: int = 5,
kb_id: str | None = None,
) -> list[dict[str, Any]]:
"""
Get sample documents from ES for preview.
Args:
es_index: ES index name
sample_size: Number of samples
kb_id: Optional KB filter
"""
query = None
if kb_id:
query = {"term": {"kb_id": kb_id}}
return self.es_client.get_sample_documents(es_index, sample_size, query=query)
def list_knowledge_bases(self, es_index: str) -> list[str]:
"""
List all knowledge base IDs in an ES index.
Args:
es_index: ES index name
Returns:
List of kb_id values
"""
try:
agg_result = self.es_client.aggregate_field(es_index, "kb_id")
return [bucket["key"] for bucket in agg_result.get("buckets", [])]
except Exception as e:
logger.warning(f"Failed to list knowledge bases: {e}")
return []
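
Besides `migrate`, the migrator exposes read-only helpers that are useful before committing to a run (sketch: same hypothetical import and placeholder index as above; live ES and OceanBase assumed):

```python
from es_to_oceanbase_migration import ESClient, OBClient, ESToOceanBaseMigrator

es = ESClient()
ob = OBClient(user="root@test", database="test")
migrator = ESToOceanBaseMigrator(es, ob)

# Inspect the target schema without creating anything in OceanBase.
preview = migrator.get_schema_preview("ragflow_example")
print(preview["total_columns"], "columns,", len(preview["vector_fields"]), "vector field(s)")

# List knowledge bases present in the source index.
print(migrator.list_knowledge_bases("ragflow_example"))

# With resume=True, an interrupted run reuses its saved progress file
# instead of aborting because the target table already exists.
migrator.migrate("ragflow_example", "ragflow_example", resume=True)
```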

@@ -0,0 +1,442 @@
"""
OceanBase Client for RAGFlow data migration.
This client is specifically designed for RAGFlow's data structure.
"""
import json
import logging
from typing import Any
from pyobvector import ObVecClient, FtsIndexParam, FtsParser, VECTOR, ARRAY
from sqlalchemy import Column, String, Integer, Float, JSON, Text, text, Double
from sqlalchemy.dialects.mysql import LONGTEXT, TEXT as MYSQL_TEXT
from .schema import RAGFLOW_COLUMNS, ARRAY_COLUMNS, FTS_COLUMNS_TKS
logger = logging.getLogger(__name__)
# Index naming templates (from RAGFlow ob_conn.py)
INDEX_NAME_TEMPLATE = "ix_%s_%s"
FULLTEXT_INDEX_NAME_TEMPLATE = "fts_idx_%s"
VECTOR_INDEX_NAME_TEMPLATE = "%s_idx"
# Columns that need regular indexes
INDEX_COLUMNS = [
"kb_id",
"doc_id",
"available_int",
"knowledge_graph_kwd",
"entity_type_kwd",
"removed_kwd",
]
class OBClient:
"""OceanBase client wrapper for RAGFlow migration operations."""
def __init__(
self,
host: str = "localhost",
port: int = 2881,
user: str = "root",
password: str = "",
database: str = "test",
pool_size: int = 10,
):
"""
Initialize OceanBase client.
Args:
host: OceanBase host address
port: OceanBase port
user: Database user (format: user@tenant for OceanBase)
password: Database password
database: Database name
pool_size: Connection pool size
"""
self.host = host
self.port = port
self.user = user
self.password = password
self.database = database
# Initialize pyobvector client
self.uri = f"{host}:{port}"
self.client = ObVecClient(
uri=self.uri,
user=user,
password=password,
db_name=database,
pool_pre_ping=True,
pool_recycle=3600,
pool_size=pool_size,
)
logger.info(f"Connected to OceanBase at {self.uri}, database: {database}")
def health_check(self) -> bool:
"""Check database connectivity."""
try:
result = self.client.perform_raw_text_sql("SELECT 1 FROM DUAL")
result.fetchone()
return True
except Exception as e:
logger.error(f"OceanBase health check failed: {e}")
return False
def get_version(self) -> str | None:
"""Get OceanBase version."""
try:
result = self.client.perform_raw_text_sql("SELECT OB_VERSION() FROM DUAL")
row = result.fetchone()
return row[0] if row else None
except Exception as e:
logger.warning(f"Failed to get OceanBase version: {e}")
return None
def table_exists(self, table_name: str) -> bool:
"""Check if a table exists."""
try:
return self.client.check_table_exists(table_name)
except Exception:
return False
def create_ragflow_table(
self,
table_name: str,
vector_size: int = 768,
create_indexes: bool = True,
create_fts_indexes: bool = True,
):
"""
Create a RAGFlow-compatible table in OceanBase.
This creates a table with the exact schema that RAGFlow expects,
including all columns, indexes, and vector columns.
Args:
table_name: Name of the table (usually the ES index name)
vector_size: Vector dimension (e.g., 768, 1024, 1536)
create_indexes: Whether to create regular indexes
create_fts_indexes: Whether to create fulltext indexes
"""
# Build column definitions
columns = self._build_ragflow_columns()
# Add vector column
vector_column_name = f"q_{vector_size}_vec"
columns.append(
Column(vector_column_name, VECTOR(vector_size), nullable=True,
comment=f"vector embedding ({vector_size} dimensions)")
)
# Table options (from RAGFlow)
table_options = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
"mysql_organization": "heap",
}
# Create table
self.client.create_table(
table_name=table_name,
columns=columns,
**table_options,
)
logger.info(f"Created table: {table_name}")
# Create regular indexes
if create_indexes:
self._create_regular_indexes(table_name)
# Create fulltext indexes
if create_fts_indexes:
self._create_fulltext_indexes(table_name)
# Create vector index
self._create_vector_index(table_name, vector_column_name)
# Refresh metadata
self.client.refresh_metadata([table_name])
def _build_ragflow_columns(self) -> list[Column]:
"""Build SQLAlchemy Column objects for RAGFlow schema."""
columns = []
for col_name, col_def in RAGFLOW_COLUMNS.items():
ob_type = col_def["ob_type"]
nullable = col_def.get("nullable", True)
default = col_def.get("default")
is_primary = col_def.get("is_primary", False)
is_array = col_def.get("is_array", False)
# Parse type and create appropriate Column
col = self._create_column(col_name, ob_type, nullable, default, is_primary, is_array)
columns.append(col)
return columns
def _create_column(
self,
name: str,
ob_type: str,
nullable: bool,
default: Any,
is_primary: bool,
is_array: bool,
) -> Column:
"""Create a SQLAlchemy Column object based on type string."""
# Handle array types
if is_array or ob_type.startswith("ARRAY"):
# Extract inner type
if "String" in ob_type:
inner_type = String(256)
elif "Integer" in ob_type:
inner_type = Integer
else:
inner_type = String(256)
# Nested array (e.g., ARRAY(ARRAY(Integer)))
if ob_type.count("ARRAY") > 1:
return Column(name, ARRAY(ARRAY(inner_type)), nullable=nullable)
else:
return Column(name, ARRAY(inner_type), nullable=nullable)
# Handle String types with length
if ob_type.startswith("String"):
# Extract length: String(256) -> 256
import re
match = re.search(r'\((\d+)\)', ob_type)
length = int(match.group(1)) if match else 256
return Column(
name, String(length),
primary_key=is_primary,
nullable=nullable,
server_default=f"'{default}'" if default else None
)
# Map other types
type_map = {
"Integer": Integer,
"Double": Double,
"Float": Float,
"JSON": JSON,
"LONGTEXT": LONGTEXT,
"TEXT": MYSQL_TEXT,
}
for type_name, type_class in type_map.items():
if type_name in ob_type:
return Column(
name, type_class,
primary_key=is_primary,
nullable=nullable,
server_default=str(default) if default is not None else None
)
# Default to String
return Column(name, String(256), nullable=nullable)
def _create_regular_indexes(self, table_name: str):
"""Create regular indexes for indexed columns."""
for col_name in INDEX_COLUMNS:
index_name = INDEX_NAME_TEMPLATE % (table_name, col_name)
try:
self.client.create_index(
table_name=table_name,
is_vec_index=False,
index_name=index_name,
column_names=[col_name],
)
logger.debug(f"Created index: {index_name}")
except Exception as e:
if "Duplicate" in str(e):
logger.debug(f"Index {index_name} already exists")
else:
logger.warning(f"Failed to create index {index_name}: {e}")
def _create_fulltext_indexes(self, table_name: str):
"""Create fulltext indexes for text columns."""
for fts_column in FTS_COLUMNS_TKS:
col_name = fts_column.split("^")[0] # Remove weight suffix
index_name = FULLTEXT_INDEX_NAME_TEMPLATE % col_name
try:
self.client.create_fts_idx_with_fts_index_param(
table_name=table_name,
fts_idx_param=FtsIndexParam(
index_name=index_name,
field_names=[col_name],
parser_type=FtsParser.IK,
),
)
logger.debug(f"Created fulltext index: {index_name}")
except Exception as e:
if "Duplicate" in str(e):
logger.debug(f"Fulltext index {index_name} already exists")
else:
logger.warning(f"Failed to create fulltext index {index_name}: {e}")
def _create_vector_index(self, table_name: str, vector_column_name: str):
"""Create vector index for embedding column."""
index_name = VECTOR_INDEX_NAME_TEMPLATE % vector_column_name
try:
self.client.create_index(
table_name=table_name,
is_vec_index=True,
index_name=index_name,
column_names=[vector_column_name],
vidx_params="distance=cosine, type=hnsw, lib=vsag",
)
logger.info(f"Created vector index: {index_name}")
except Exception as e:
if "Duplicate" in str(e):
logger.debug(f"Vector index {index_name} already exists")
else:
logger.warning(f"Failed to create vector index {index_name}: {e}")
def add_vector_column(self, table_name: str, vector_size: int):
"""Add a vector column to an existing table."""
vector_column_name = f"q_{vector_size}_vec"
# Check if column exists
if self._column_exists(table_name, vector_column_name):
logger.info(f"Vector column {vector_column_name} already exists")
return
try:
self.client.add_columns(
table_name=table_name,
columns=[Column(vector_column_name, VECTOR(vector_size), nullable=True)],
)
logger.info(f"Added vector column: {vector_column_name}")
# Create index
self._create_vector_index(table_name, vector_column_name)
except Exception as e:
logger.error(f"Failed to add vector column: {e}")
raise
def _column_exists(self, table_name: str, column_name: str) -> bool:
"""Check if a column exists in a table."""
try:
result = self.client.perform_raw_text_sql(
f"SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS "
f"WHERE TABLE_SCHEMA = '{self.database}' "
f"AND TABLE_NAME = '{table_name}' "
f"AND COLUMN_NAME = '{column_name}'"
)
count = result.fetchone()[0]
return count > 0
except Exception:
return False
def _index_exists(self, table_name: str, index_name: str) -> bool:
"""Check if an index exists."""
try:
result = self.client.perform_raw_text_sql(
f"SELECT COUNT(*) FROM INFORMATION_SCHEMA.STATISTICS "
f"WHERE TABLE_SCHEMA = '{self.database}' "
f"AND TABLE_NAME = '{table_name}' "
f"AND INDEX_NAME = '{index_name}'"
)
count = result.fetchone()[0]
return count > 0
except Exception:
return False
def insert_batch(
self,
table_name: str,
documents: list[dict[str, Any]],
) -> int:
"""
Insert a batch of documents using upsert.
Args:
table_name: Name of the table
documents: List of documents to insert
Returns:
Number of documents inserted
"""
if not documents:
return 0
try:
self.client.upsert(table_name=table_name, data=documents)
return len(documents)
except Exception as e:
logger.error(f"Batch insert failed: {e}")
raise
def count_rows(self, table_name: str, kb_id: str | None = None) -> int:
"""
Count rows in a table.
Args:
table_name: Table name
kb_id: Optional knowledge base ID filter
"""
try:
sql = f"SELECT COUNT(*) FROM `{table_name}`"
if kb_id:
sql += f" WHERE kb_id = '{kb_id}'"
result = self.client.perform_raw_text_sql(sql)
return result.fetchone()[0]
except Exception:
return 0
def get_sample_rows(
self,
table_name: str,
limit: int = 10,
kb_id: str | None = None,
) -> list[dict[str, Any]]:
"""Get sample rows from a table."""
try:
sql = f"SELECT * FROM `{table_name}`"
if kb_id:
sql += f" WHERE kb_id = '{kb_id}'"
sql += f" LIMIT {limit}"
result = self.client.perform_raw_text_sql(sql)
columns = result.keys()
rows = []
for row in result:
rows.append(dict(zip(columns, row)))
return rows
except Exception as e:
logger.error(f"Failed to get sample rows: {e}")
return []
def get_row_by_id(self, table_name: str, doc_id: str) -> dict[str, Any] | None:
"""Get a single row by ID."""
try:
result = self.client.get(table_name=table_name, ids=[doc_id])
row = result.fetchone()
if row:
columns = result.keys()
return dict(zip(columns, row))
return None
except Exception as e:
logger.error(f"Failed to get row: {e}")
return None
def drop_table(self, table_name: str):
"""Drop a table if exists."""
try:
self.client.drop_table_if_exist(table_name)
logger.info(f"Dropped table: {table_name}")
except Exception as e:
logger.warning(f"Failed to drop table: {e}")
def execute_sql(self, sql: str) -> Any:
"""Execute raw SQL."""
return self.client.perform_raw_text_sql(sql)
def close(self):
"""Close the OB client connection."""
self.client.engine.dispose()
logger.info("OceanBase connection closed")

@@ -0,0 +1,221 @@
"""
Progress tracking and resume capability for migration.
"""
import json
import logging
import os
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class MigrationProgress:
"""Migration progress state."""
# Basic info
es_index: str
ob_table: str
started_at: str = ""
updated_at: str = ""
# Progress counters
total_documents: int = 0
migrated_documents: int = 0
failed_documents: int = 0
# State for resume
last_sort_values: list[Any] = field(default_factory=list)
last_batch_ids: list[str] = field(default_factory=list)
# Status
status: str = "pending" # pending, running, completed, failed, paused
error_message: str = ""
# Schema info
schema_converted: bool = False
table_created: bool = False
indexes_created: bool = False
def __post_init__(self):
if not self.started_at:
self.started_at = datetime.utcnow().isoformat()
self.updated_at = datetime.utcnow().isoformat()
class ProgressManager:
"""Manage migration progress persistence."""
def __init__(self, progress_dir: str = ".migration_progress"):
"""
Initialize progress manager.
Args:
progress_dir: Directory to store progress files
"""
self.progress_dir = Path(progress_dir)
self.progress_dir.mkdir(parents=True, exist_ok=True)
def _get_progress_file(self, es_index: str, ob_table: str) -> Path:
"""Get progress file path for a migration."""
filename = f"{es_index}_to_{ob_table}.json"
return self.progress_dir / filename
def load_progress(
self, es_index: str, ob_table: str
) -> MigrationProgress | None:
"""
Load progress from file.
Args:
es_index: Elasticsearch index name
ob_table: OceanBase table name
Returns:
MigrationProgress if exists, None otherwise
"""
progress_file = self._get_progress_file(es_index, ob_table)
if not progress_file.exists():
return None
try:
with open(progress_file, "r") as f:
data = json.load(f)
progress = MigrationProgress(**data)
logger.info(
f"Loaded progress: {progress.migrated_documents}/{progress.total_documents} documents"
)
return progress
except Exception as e:
logger.warning(f"Failed to load progress: {e}")
return None
def save_progress(self, progress: MigrationProgress):
"""
Save progress to file.
Args:
progress: MigrationProgress instance
"""
progress.updated_at = datetime.utcnow().isoformat()
progress_file = self._get_progress_file(progress.es_index, progress.ob_table)
try:
with open(progress_file, "w") as f:
json.dump(asdict(progress), f, indent=2, default=str)
logger.debug(f"Saved progress to {progress_file}")
except Exception as e:
logger.error(f"Failed to save progress: {e}")
def delete_progress(self, es_index: str, ob_table: str):
"""Delete progress file."""
progress_file = self._get_progress_file(es_index, ob_table)
if progress_file.exists():
progress_file.unlink()
logger.info(f"Deleted progress file: {progress_file}")
def create_progress(
self,
es_index: str,
ob_table: str,
total_documents: int,
) -> MigrationProgress:
"""
Create new progress tracker.
Args:
es_index: Elasticsearch index name
ob_table: OceanBase table name
total_documents: Total documents to migrate
Returns:
New MigrationProgress instance
"""
progress = MigrationProgress(
es_index=es_index,
ob_table=ob_table,
total_documents=total_documents,
status="running",
)
self.save_progress(progress)
return progress
def update_progress(
self,
progress: MigrationProgress,
migrated_count: int,
last_sort_values: list[Any] | None = None,
last_batch_ids: list[str] | None = None,
):
"""
Update progress after a batch.
Args:
progress: MigrationProgress instance
migrated_count: Number of documents migrated in this batch
last_sort_values: Sort values for search_after
last_batch_ids: IDs of documents in last batch
"""
progress.migrated_documents += migrated_count
if last_sort_values:
progress.last_sort_values = last_sort_values
if last_batch_ids:
progress.last_batch_ids = last_batch_ids
self.save_progress(progress)
def mark_completed(self, progress: MigrationProgress):
"""Mark migration as completed."""
progress.status = "completed"
progress.updated_at = datetime.utcnow().isoformat()
self.save_progress(progress)
logger.info(
f"Migration completed: {progress.migrated_documents} documents"
)
def mark_failed(self, progress: MigrationProgress, error: str):
"""Mark migration as failed."""
progress.status = "failed"
progress.error_message = error
progress.updated_at = datetime.utcnow().isoformat()
self.save_progress(progress)
logger.error(f"Migration failed: {error}")
def mark_paused(self, progress: MigrationProgress):
"""Mark migration as paused (for resume later)."""
progress.status = "paused"
progress.updated_at = datetime.utcnow().isoformat()
self.save_progress(progress)
logger.info(
f"Migration paused at {progress.migrated_documents}/{progress.total_documents}"
)
def can_resume(self, es_index: str, ob_table: str) -> bool:
"""Check if migration can be resumed."""
progress = self.load_progress(es_index, ob_table)
if not progress:
return False
return progress.status in ("running", "paused", "failed")
def get_resume_info(self, es_index: str, ob_table: str) -> dict[str, Any] | None:
"""Get information needed to resume migration."""
progress = self.load_progress(es_index, ob_table)
if not progress:
return None
return {
"migrated_documents": progress.migrated_documents,
"total_documents": progress.total_documents,
"last_sort_values": progress.last_sort_values,
"last_batch_ids": progress.last_batch_ids,
"schema_converted": progress.schema_converted,
"table_created": progress.table_created,
"indexes_created": progress.indexes_created,
"status": progress.status,
}
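
For reference, a minimal sketch of the intended flow around `ProgressManager`, restricted to the methods defined above (the flat import path is again assumed):

```python
from progress import ProgressManager  # import path assumed

pm = ProgressManager(progress_dir=".migration_progress")

if pm.can_resume("ragflow_tenant1", "ragflow_tenant1"):
    info = pm.get_resume_info("ragflow_tenant1", "ragflow_tenant1")
    print(f"Resuming at {info['migrated_documents']}/{info['total_documents']} documents")
else:
    progress = pm.create_progress("ragflow_tenant1", "ragflow_tenant1", total_documents=1000)
    # After each migrated batch:
    pm.update_progress(
        progress,
        migrated_count=500,
        last_sort_values=[500],      # search_after cursor from ES
        last_batch_ids=["doc-500"],
    )
    pm.mark_completed(progress)
```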

View File

@@ -0,0 +1,451 @@
"""
RAGFlow-specific schema conversion from Elasticsearch to OceanBase.
This module handles migration of RAGFlow's fixed table structure.
RAGFlow uses a predefined schema for both ES and OceanBase.
"""
import json
import logging
import re
from typing import Any
logger = logging.getLogger(__name__)
# RAGFlow fixed column definitions (from rag/utils/ob_conn.py)
# These are the actual columns used by RAGFlow
RAGFLOW_COLUMNS = {
# Primary identifiers
"id": {"ob_type": "String(256)", "nullable": False, "is_primary": True},
"kb_id": {"ob_type": "String(256)", "nullable": False, "index": True},
"doc_id": {"ob_type": "String(256)", "nullable": True, "index": True},
# Document metadata
"docnm_kwd": {"ob_type": "String(256)", "nullable": True}, # document name
"doc_type_kwd": {"ob_type": "String(256)", "nullable": True}, # document type
# Title fields
"title_tks": {"ob_type": "String(256)", "nullable": True}, # title tokens
"title_sm_tks": {"ob_type": "String(256)", "nullable": True}, # fine-grained title tokens
# Content fields
"content_with_weight": {"ob_type": "LONGTEXT", "nullable": True}, # original content
"content_ltks": {"ob_type": "LONGTEXT", "nullable": True}, # long text tokens
"content_sm_ltks": {"ob_type": "LONGTEXT", "nullable": True}, # fine-grained tokens
# Feature fields
"pagerank_fea": {"ob_type": "Integer", "nullable": True}, # page rank priority
# Array fields
"important_kwd": {"ob_type": "ARRAY(String(256))", "nullable": True, "is_array": True}, # keywords
"important_tks": {"ob_type": "TEXT", "nullable": True}, # keyword tokens
"question_kwd": {"ob_type": "ARRAY(String(1024))", "nullable": True, "is_array": True}, # questions
"question_tks": {"ob_type": "TEXT", "nullable": True}, # question tokens
"tag_kwd": {"ob_type": "ARRAY(String(256))", "nullable": True, "is_array": True}, # tags
"tag_feas": {"ob_type": "JSON", "nullable": True, "is_json": True}, # tag features
# Status fields
"available_int": {"ob_type": "Integer", "nullable": False, "default": 1},
# Time fields
"create_time": {"ob_type": "String(19)", "nullable": True},
"create_timestamp_flt": {"ob_type": "Double", "nullable": True},
# Image field
"img_id": {"ob_type": "String(128)", "nullable": True},
# Position fields (arrays)
"position_int": {"ob_type": "ARRAY(ARRAY(Integer))", "nullable": True, "is_array": True},
"page_num_int": {"ob_type": "ARRAY(Integer)", "nullable": True, "is_array": True},
"top_int": {"ob_type": "ARRAY(Integer)", "nullable": True, "is_array": True},
# Knowledge graph fields
"knowledge_graph_kwd": {"ob_type": "String(256)", "nullable": True, "index": True},
"source_id": {"ob_type": "ARRAY(String(256))", "nullable": True, "is_array": True},
"entity_kwd": {"ob_type": "String(256)", "nullable": True},
"entity_type_kwd": {"ob_type": "String(256)", "nullable": True, "index": True},
"from_entity_kwd": {"ob_type": "String(256)", "nullable": True},
"to_entity_kwd": {"ob_type": "String(256)", "nullable": True},
"weight_int": {"ob_type": "Integer", "nullable": True},
"weight_flt": {"ob_type": "Double", "nullable": True},
"entities_kwd": {"ob_type": "ARRAY(String(256))", "nullable": True, "is_array": True},
"rank_flt": {"ob_type": "Double", "nullable": True},
# Status
"removed_kwd": {"ob_type": "String(256)", "nullable": True, "index": True, "default": "N"},
# JSON fields
"metadata": {"ob_type": "JSON", "nullable": True, "is_json": True},
"extra": {"ob_type": "JSON", "nullable": True, "is_json": True},
# New columns
"_order_id": {"ob_type": "Integer", "nullable": True},
"group_id": {"ob_type": "String(256)", "nullable": True},
"mom_id": {"ob_type": "String(256)", "nullable": True},
}
# Array column names for special handling
ARRAY_COLUMNS = [
"important_kwd", "question_kwd", "tag_kwd", "source_id",
"entities_kwd", "position_int", "page_num_int", "top_int"
]
# JSON column names
JSON_COLUMNS = ["tag_feas", "metadata", "extra"]
# Fulltext search columns (for reference)
FTS_COLUMNS_ORIGIN = ["docnm_kwd", "content_with_weight", "important_tks", "question_tks"]
FTS_COLUMNS_TKS = ["title_tks", "title_sm_tks", "important_tks", "question_tks", "content_ltks", "content_sm_ltks"]
# Vector field pattern: q_{vector_size}_vec
VECTOR_FIELD_PATTERN = re.compile(r"q_(?P<vector_size>\d+)_vec")
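
The vector column name itself carries the embedding dimension, so detection is a pure regex match. A self-contained check of the same pattern:

```python
import re

# Same pattern as VECTOR_FIELD_PATTERN above: the column name encodes the dimension.
pattern = re.compile(r"q_(?P<vector_size>\d+)_vec")

for name in ("q_1024_vec", "q_768_vec", "content_ltks"):
    m = pattern.match(name)
    print(name, "->", int(m.group("vector_size")) if m else "not a vector field")
```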
class RAGFlowSchemaConverter:
"""
Convert RAGFlow Elasticsearch documents to OceanBase format.
RAGFlow uses a fixed schema, so this converter knows exactly
what fields to expect and how to map them.
"""
def __init__(self):
self.vector_fields: list[dict[str, Any]] = []
self.detected_vector_size: int | None = None
def analyze_es_mapping(self, es_mapping: dict[str, Any]) -> dict[str, Any]:
"""
Analyze ES mapping to extract vector field dimensions.
Args:
es_mapping: Elasticsearch index mapping
Returns:
Analysis result with detected fields
"""
result = {
"known_fields": [],
"vector_fields": [],
"unknown_fields": [],
}
properties = es_mapping.get("properties", {})
for field_name, field_def in properties.items():
# Check if it's a known RAGFlow field
if field_name in RAGFLOW_COLUMNS:
result["known_fields"].append(field_name)
# Check if it's a vector field
elif VECTOR_FIELD_PATTERN.match(field_name):
match = VECTOR_FIELD_PATTERN.match(field_name)
vec_size = int(match.group("vector_size"))
result["vector_fields"].append({
"name": field_name,
"dimension": vec_size,
})
self.vector_fields.append({
"name": field_name,
"dimension": vec_size,
})
if self.detected_vector_size is None:
self.detected_vector_size = vec_size
else:
                # Unknown field - likely a custom field that will be stored in 'extra'
result["unknown_fields"].append(field_name)
logger.info(
f"Analyzed ES mapping: {len(result['known_fields'])} known fields, "
f"{len(result['vector_fields'])} vector fields, "
f"{len(result['unknown_fields'])} unknown fields"
)
return result
def get_column_definitions(self) -> list[dict[str, Any]]:
"""
Get RAGFlow column definitions for OceanBase table creation.
Returns:
List of column definitions
"""
columns = []
for col_name, col_def in RAGFLOW_COLUMNS.items():
columns.append({
"name": col_name,
"ob_type": col_def["ob_type"],
"nullable": col_def.get("nullable", True),
"is_primary": col_def.get("is_primary", False),
"index": col_def.get("index", False),
"is_array": col_def.get("is_array", False),
"is_json": col_def.get("is_json", False),
"default": col_def.get("default"),
})
# Add detected vector fields
for vec_field in self.vector_fields:
columns.append({
"name": vec_field["name"],
"ob_type": f"VECTOR({vec_field['dimension']})",
"nullable": True,
"is_vector": True,
"dimension": vec_field["dimension"],
})
return columns
def get_vector_fields(self) -> list[dict[str, Any]]:
"""Get list of vector fields for index creation."""
return self.vector_fields
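
A hedged sketch of how the schema converter is meant to be driven: pass it the mapping fetched from ES and read back the detected vector columns (the flat `schema` import path is an assumption; inside the package it is imported as `.schema`):

```python
from schema import RAGFlowSchemaConverter  # import path assumed

converter = RAGFlowSchemaConverter()
analysis = converter.analyze_es_mapping({
    "properties": {
        "content_with_weight": {"type": "text"},
        "q_1024_vec": {"type": "dense_vector", "dims": 1024},
        "my_custom_field": {"type": "keyword"},
    }
})
print(analysis["vector_fields"])    # [{'name': 'q_1024_vec', 'dimension': 1024}]
print(analysis["unknown_fields"])   # ['my_custom_field'] -> later routed into 'extra'

# Column definitions now include every fixed RAGFlow column plus one VECTOR(1024) column.
columns = converter.get_column_definitions()
print(any(c.get("is_vector") for c in columns))  # True
```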
class RAGFlowDataConverter:
"""
Convert RAGFlow ES documents to OceanBase row format.
This converter handles the specific data transformations needed
for RAGFlow's data structure.
"""
def __init__(self):
"""Initialize data converter."""
self.vector_fields: set[str] = set()
def detect_vector_fields(self, doc: dict[str, Any]) -> None:
"""Detect vector fields from a sample document."""
for key in doc.keys():
if VECTOR_FIELD_PATTERN.match(key):
self.vector_fields.add(key)
def convert_document(self, es_doc: dict[str, Any]) -> dict[str, Any]:
"""
Convert an ES document to OceanBase row format.
Args:
es_doc: Elasticsearch document (with _id and _source)
Returns:
Dictionary ready for OceanBase insertion
"""
# Extract _id and _source
doc_id = es_doc.get("_id")
source = es_doc.get("_source", es_doc)
row = {}
# Set document ID
if doc_id:
row["id"] = str(doc_id)
elif "id" in source:
row["id"] = str(source["id"])
# Process each field
for field_name, field_def in RAGFLOW_COLUMNS.items():
if field_name == "id":
continue # Already handled
value = source.get(field_name)
if value is None:
# Use default if available
default = field_def.get("default")
if default is not None:
row[field_name] = default
continue
# Convert based on field type
row[field_name] = self._convert_field_value(
field_name, value, field_def
)
# Handle vector fields
for key, value in source.items():
if VECTOR_FIELD_PATTERN.match(key):
if isinstance(value, list):
row[key] = value
self.vector_fields.add(key)
        # Handle unknown fields -> merge into 'extra'
        extra_fields = {}
        for key, value in source.items():
            if key not in RAGFLOW_COLUMNS and not VECTOR_FIELD_PATTERN.match(key):
                extra_fields[key] = value
        if extra_fields:
            # 'extra' has already been serialized to a JSON string by
            # _convert_field_value, so parse it, merge, and re-serialize
            # instead of silently overwriting it.
            existing_extra = row.get("extra")
            if isinstance(existing_extra, str):
                try:
                    existing_extra = json.loads(existing_extra)
                except json.JSONDecodeError:
                    existing_extra = None
            if isinstance(existing_extra, dict):
                existing_extra.update(extra_fields)
                row["extra"] = json.dumps(existing_extra, ensure_ascii=False)
            else:
                row["extra"] = json.dumps(extra_fields, ensure_ascii=False)
return row
def _convert_field_value(
self,
field_name: str,
value: Any,
field_def: dict[str, Any]
) -> Any:
"""
Convert a field value to the appropriate format for OceanBase.
Args:
field_name: Field name
value: Original value from ES
field_def: Field definition from RAGFLOW_COLUMNS
Returns:
Converted value
"""
if value is None:
return None
ob_type = field_def.get("ob_type", "")
is_array = field_def.get("is_array", False)
is_json = field_def.get("is_json", False)
# Handle array fields
if is_array:
return self._convert_array_value(value)
# Handle JSON fields
if is_json:
return self._convert_json_value(value)
# Handle specific types
if "Integer" in ob_type:
return self._convert_integer(value)
if "Double" in ob_type or "Float" in ob_type:
return self._convert_float(value)
if "LONGTEXT" in ob_type or "TEXT" in ob_type:
return self._convert_text(value)
if "String" in ob_type:
return self._convert_string(value, field_name)
        # Default: convert to string (value is guaranteed non-None at this point)
        return str(value)
def _convert_array_value(self, value: Any) -> str | None:
"""Convert array value to JSON string for OceanBase."""
if value is None:
return None
if isinstance(value, str):
# Already a JSON string
try:
# Validate it's valid JSON
json.loads(value)
return value
except json.JSONDecodeError:
# Not valid JSON, wrap in array
return json.dumps([value], ensure_ascii=False)
if isinstance(value, list):
# Clean array values
cleaned = []
for item in value:
if isinstance(item, str):
# Clean special characters
cleaned_str = item.strip()
cleaned_str = cleaned_str.replace('\\', '\\\\')
cleaned_str = cleaned_str.replace('\n', '\\n')
cleaned_str = cleaned_str.replace('\r', '\\r')
cleaned_str = cleaned_str.replace('\t', '\\t')
cleaned.append(cleaned_str)
else:
cleaned.append(item)
return json.dumps(cleaned, ensure_ascii=False)
# Single value - wrap in array
return json.dumps([value], ensure_ascii=False)
def _convert_json_value(self, value: Any) -> str | None:
"""Convert JSON value to string for OceanBase."""
if value is None:
return None
if isinstance(value, str):
# Already a string, validate JSON
try:
json.loads(value)
return value
except json.JSONDecodeError:
# Not valid JSON, return as-is
return value
if isinstance(value, (dict, list)):
return json.dumps(value, ensure_ascii=False)
return str(value)
def _convert_integer(self, value: Any) -> int | None:
"""Convert to integer."""
if value is None:
return None
if isinstance(value, bool):
return 1 if value else 0
try:
return int(value)
except (ValueError, TypeError):
return None
def _convert_float(self, value: Any) -> float | None:
"""Convert to float."""
if value is None:
return None
try:
return float(value)
except (ValueError, TypeError):
return None
def _convert_text(self, value: Any) -> str | None:
"""Convert to text/longtext."""
if value is None:
return None
if isinstance(value, dict):
# content_with_weight might be stored as dict
return json.dumps(value, ensure_ascii=False)
if isinstance(value, list):
return json.dumps(value, ensure_ascii=False)
return str(value)
def _convert_string(self, value: Any, field_name: str) -> str | None:
"""Convert to string with length considerations."""
if value is None:
return None
# Handle kb_id which might be a list in ES
if field_name == "kb_id" and isinstance(value, list):
return str(value[0]) if value else None
if isinstance(value, (dict, list)):
return json.dumps(value, ensure_ascii=False)
return str(value)
def convert_batch(self, es_docs: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert a batch of ES documents.
Args:
es_docs: List of Elasticsearch documents
Returns:
List of dictionaries ready for OceanBase insertion
"""
return [self.convert_document(doc) for doc in es_docs]
# Backwards compatibility aliases
SchemaConverter = RAGFlowSchemaConverter
DataConverter = RAGFlowDataConverter
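
And a minimal sketch of the per-document conversion path, using only behavior defined in `RAGFlowDataConverter` above (import path assumed):

```python
from schema import RAGFlowDataConverter  # import path assumed

converter = RAGFlowDataConverter()
row = converter.convert_document({
    "_id": "chunk-001",
    "_source": {
        "kb_id": ["kb-42"],                  # list in ES -> first element as string
        "content_with_weight": "Hello world",
        "important_kwd": ["alpha", "beta"],  # array field -> JSON string
        "my_custom_field": 7,                # unknown field -> merged into 'extra'
    },
})
print(row["id"], row["kb_id"])   # chunk-001 kb-42
print(row["important_kwd"])      # ["alpha", "beta"]  (serialized JSON)
print(row["extra"])              # {"my_custom_field": 7}
print(row["available_int"])      # 1 (default applied when the field is absent)
```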

View File

@@ -0,0 +1,349 @@
"""
Data verification for RAGFlow migration.
"""
import json
import logging
from dataclasses import dataclass, field
from typing import Any
from .es_client import ESClient
from .ob_client import OBClient
from .schema import RAGFLOW_COLUMNS, ARRAY_COLUMNS, JSON_COLUMNS
logger = logging.getLogger(__name__)
@dataclass
class VerificationResult:
"""Migration verification result."""
es_index: str
ob_table: str
# Counts
es_count: int = 0
ob_count: int = 0
count_match: bool = False
count_diff: int = 0
# Sample verification
sample_size: int = 0
samples_verified: int = 0
samples_matched: int = 0
sample_match_rate: float = 0.0
# Mismatches
missing_in_ob: list[str] = field(default_factory=list)
data_mismatches: list[dict[str, Any]] = field(default_factory=list)
# Overall
passed: bool = False
message: str = ""
class MigrationVerifier:
"""Verify RAGFlow migration data consistency."""
# Fields to compare for verification
VERIFY_FIELDS = [
"id", "kb_id", "doc_id", "docnm_kwd", "content_with_weight",
"available_int", "create_time",
]
def __init__(
self,
es_client: ESClient,
ob_client: OBClient,
):
"""
Initialize verifier.
Args:
es_client: Elasticsearch client
ob_client: OceanBase client
"""
self.es_client = es_client
self.ob_client = ob_client
def verify(
self,
es_index: str,
ob_table: str,
sample_size: int = 100,
primary_key: str = "id",
verify_fields: list[str] | None = None,
) -> VerificationResult:
"""
Verify migration by comparing ES and OceanBase data.
Args:
es_index: Elasticsearch index name
ob_table: OceanBase table name
sample_size: Number of documents to sample for verification
primary_key: Primary key column name
verify_fields: Fields to verify (None = use defaults)
Returns:
VerificationResult with details
"""
result = VerificationResult(
es_index=es_index,
ob_table=ob_table,
)
if verify_fields is None:
verify_fields = self.VERIFY_FIELDS
# Step 1: Verify document counts
logger.info("Verifying document counts...")
result.es_count = self.es_client.count_documents(es_index)
result.ob_count = self.ob_client.count_rows(ob_table)
result.count_diff = abs(result.es_count - result.ob_count)
result.count_match = result.count_diff == 0
logger.info(
f"Document counts - ES: {result.es_count}, OB: {result.ob_count}, "
f"Diff: {result.count_diff}"
)
# Step 2: Sample verification
result.sample_size = min(sample_size, result.es_count)
if result.sample_size > 0:
logger.info(f"Verifying {result.sample_size} sample documents...")
self._verify_samples(
es_index, ob_table, result, primary_key, verify_fields
)
# Step 3: Determine overall result
self._determine_result(result)
logger.info(result.message)
return result
def _verify_samples(
self,
es_index: str,
ob_table: str,
result: VerificationResult,
primary_key: str,
verify_fields: list[str],
):
"""Verify sample documents."""
# Get sample documents from ES
es_samples = self.es_client.get_sample_documents(
es_index, result.sample_size
)
for es_doc in es_samples:
result.samples_verified += 1
doc_id = es_doc.get("_id") or es_doc.get("id")
if not doc_id:
logger.warning("Document without ID found")
continue
# Get corresponding document from OceanBase
ob_doc = self.ob_client.get_row_by_id(ob_table, doc_id)
if ob_doc is None:
result.missing_in_ob.append(doc_id)
continue
# Compare documents
match, differences = self._compare_documents(
es_doc, ob_doc, verify_fields
)
if match:
result.samples_matched += 1
else:
result.data_mismatches.append({
"id": doc_id,
"differences": differences,
})
# Calculate match rate
if result.samples_verified > 0:
result.sample_match_rate = result.samples_matched / result.samples_verified
def _compare_documents(
self,
es_doc: dict[str, Any],
ob_doc: dict[str, Any],
verify_fields: list[str],
) -> tuple[bool, list[dict[str, Any]]]:
"""
Compare ES document with OceanBase row.
Returns:
Tuple of (match: bool, differences: list)
"""
differences = []
for field_name in verify_fields:
es_value = es_doc.get(field_name)
ob_value = ob_doc.get(field_name)
# Skip if both are None/null
if es_value is None and ob_value is None:
continue
# Handle special comparisons
if not self._values_equal(field_name, es_value, ob_value):
differences.append({
"field": field_name,
"es_value": es_value,
"ob_value": ob_value,
})
return len(differences) == 0, differences
def _values_equal(
self,
field_name: str,
es_value: Any,
ob_value: Any
) -> bool:
"""Compare two values with type-aware logic."""
if es_value is None and ob_value is None:
return True
        if es_value is None or ob_value is None:
            # Only one side is None; treat it as a mismatch
            return False
# Handle array fields (stored as JSON strings in OB)
if field_name in ARRAY_COLUMNS:
if isinstance(ob_value, str):
try:
ob_value = json.loads(ob_value)
except json.JSONDecodeError:
pass
if isinstance(es_value, list) and isinstance(ob_value, list):
return set(str(x) for x in es_value) == set(str(x) for x in ob_value)
# Handle JSON fields
if field_name in JSON_COLUMNS:
if isinstance(ob_value, str):
try:
ob_value = json.loads(ob_value)
except json.JSONDecodeError:
pass
if isinstance(es_value, str):
try:
es_value = json.loads(es_value)
except json.JSONDecodeError:
pass
return es_value == ob_value
        # Handle content_with_weight which might be dict or string
        if field_name == "content_with_weight":
            if isinstance(ob_value, str) and isinstance(es_value, dict):
                try:
                    ob_value = json.loads(ob_value)
                except json.JSONDecodeError:
                    pass
            if isinstance(es_value, dict) and isinstance(ob_value, dict):
                # Compare as dicts so key order cannot cause a false mismatch
                return es_value == ob_value
# Handle kb_id which might be list in ES
if field_name == "kb_id":
if isinstance(es_value, list) and len(es_value) > 0:
es_value = es_value[0]
# Standard comparison
return str(es_value) == str(ob_value)
def _determine_result(self, result: VerificationResult):
"""Determine overall verification result."""
# Allow small count differences (e.g., documents added during migration)
count_tolerance = 0.01 # 1% tolerance
count_ok = (
result.count_match or
(result.es_count > 0 and result.count_diff / result.es_count <= count_tolerance)
)
if count_ok and result.sample_match_rate >= 0.99:
result.passed = True
result.message = (
f"Verification PASSED. "
f"ES: {result.es_count:,}, OB: {result.ob_count:,}. "
f"Sample match rate: {result.sample_match_rate:.2%}"
)
elif count_ok and result.sample_match_rate >= 0.95:
result.passed = True
result.message = (
f"Verification PASSED with warnings. "
f"ES: {result.es_count:,}, OB: {result.ob_count:,}. "
f"Sample match rate: {result.sample_match_rate:.2%}"
)
else:
result.passed = False
issues = []
if not count_ok:
issues.append(
f"Count mismatch (ES: {result.es_count}, OB: {result.ob_count}, diff: {result.count_diff})"
)
if result.sample_match_rate < 0.95:
issues.append(f"Low sample match rate: {result.sample_match_rate:.2%}")
if result.missing_in_ob:
issues.append(f"{len(result.missing_in_ob)} documents missing in OB")
result.message = f"Verification FAILED: {'; '.join(issues)}"
def generate_report(self, result: VerificationResult) -> str:
"""Generate a verification report."""
lines = [
"",
"=" * 60,
"Migration Verification Report",
"=" * 60,
f"ES Index: {result.es_index}",
f"OB Table: {result.ob_table}",
]
lines.extend([
"",
"Document Counts:",
f" Elasticsearch: {result.es_count:,}",
f" OceanBase: {result.ob_count:,}",
f" Difference: {result.count_diff:,}",
f" Match: {'Yes' if result.count_match else 'No'}",
"",
"Sample Verification:",
f" Sample Size: {result.sample_size}",
f" Verified: {result.samples_verified}",
f" Matched: {result.samples_matched}",
f" Match Rate: {result.sample_match_rate:.2%}",
"",
])
if result.missing_in_ob:
lines.append(f"Missing in OceanBase ({len(result.missing_in_ob)}):")
for doc_id in result.missing_in_ob[:5]:
lines.append(f" - {doc_id}")
if len(result.missing_in_ob) > 5:
lines.append(f" ... and {len(result.missing_in_ob) - 5} more")
lines.append("")
if result.data_mismatches:
lines.append(f"Data Mismatches ({len(result.data_mismatches)}):")
for mismatch in result.data_mismatches[:3]:
lines.append(f" - ID: {mismatch['id']}")
for diff in mismatch.get("differences", [])[:2]:
lines.append(f" {diff['field']}: ES={diff['es_value']}, OB={diff['ob_value']}")
if len(result.data_mismatches) > 3:
lines.append(f" ... and {len(result.data_mismatches) - 3} more")
lines.append("")
lines.extend([
"=" * 60,
f"Result: {'PASSED' if result.passed else 'FAILED'}",
result.message,
"=" * 60,
"",
])
return "\n".join(lines)