# Mirror of https://github.com/haris-musa/excel-mcp-server.git
# Synced 2025-12-08 17:12:41 +08:00 (272 lines, 9.9 KiB, Python)
from typing import Any
|
|
import uuid
|
|
import logging
|
|
|
|
from openpyxl import load_workbook
|
|
from openpyxl.utils import get_column_letter
|
|
from openpyxl.worksheet.table import Table, TableStyleInfo
|
|
from openpyxl.styles import Font
|
|
|
|
from .data import read_excel_range
|
|
from .cell_utils import parse_cell_range
|
|
from .exceptions import ValidationError, PivotError
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
def create_pivot_table(
    filepath: str,
    sheet_name: str,
    data_range: str,
    rows: list[str],
    values: list[str],
    columns: list[str] | None = None,
    agg_func: str = "sum"
) -> dict[str, Any]:
    """Build a grouped summary ("pivot") sheet from a range of source data.

    The summary is written as plain cells to a sheet named
    ``"<sheet_name>_pivot"`` (an existing sheet with that name is replaced),
    registered as an Excel table for styling, and the workbook is saved back
    to ``filepath``. One output row is produced for every combination of the
    distinct, stringified values of the ``rows`` fields; a combination with
    no matching source records gets an aggregate of 0.

    Args:
        filepath: Path to Excel file. The file is loaded and overwritten on save.
        sheet_name: Name of worksheet containing source data.
        data_range: Source data range reference in ``"A1:B2"`` form.
        rows: Fields for row labels. Names are matched against the source
            headers case-insensitively; a trailing aggregation suffix such
            as ``" (sum)"`` is ignored.
        values: Fields to aggregate (matched like ``rows``).
        columns: Optional fields for column labels. They are validated
            against the source headers and echoed in the result, but do not
            affect the generated layout.
        agg_func: Aggregation function (sum, count, average, max, min),
            case-insensitive.

    Returns:
        Dictionary with a status ``message`` and a ``details`` mapping holding
        the source range, pivot sheet name, row/column/value fields and the
        aggregation name as supplied.

    Raises:
        ValidationError: Unknown sheet, malformed range, unsupported
            aggregation function, or a field that is not a source header.
        PivotError: Source data could not be read or is empty, aggregation,
            table creation or saving failed, or any other unexpected error.
    """
    try:
        wb = load_workbook(filepath)
        if sheet_name not in wb.sheetnames:
            raise ValidationError(f"Sheet '{sheet_name}' not found")

        # Parse ranges
        if ':' not in data_range:
            raise ValidationError("Data range must be in format 'A1:B2'")

        try:
            start_cell, end_cell = data_range.split(':')
            start_row, start_col, end_row, end_col = parse_cell_range(start_cell, end_cell)
        except ValueError as e:
            raise ValidationError(f"Invalid data range format: {str(e)}") from e

        if end_row is None or end_col is None:
            raise ValidationError("Invalid data range format: missing end coordinates")

        # Normalised range string, reported back to the caller
        data_range_str = f"{get_column_letter(start_col)}{start_row}:{get_column_letter(end_col)}{end_row}"

        # Read source data
        try:
            data = read_excel_range(filepath, sheet_name, start_cell, end_cell)
            if not data:
                raise PivotError("No data found in range")
        except Exception as e:
            raise PivotError(f"Failed to read source data: {str(e)}") from e

        # Validate aggregation function. The lowercased name is what gets
        # dispatched later, so "SUM"/"Average" behave like their lowercase
        # forms instead of silently falling back to a sum.
        valid_agg_funcs = ["sum", "average", "count", "min", "max"]
        agg_name = agg_func.lower()
        if agg_name not in valid_agg_funcs:
            raise ValidationError(
                f"Invalid aggregation function. Must be one of: {', '.join(valid_agg_funcs)}"
            )

        # Clean up field names by removing aggregation suffixes
        def clean_field_name(field: str) -> str:
            field = str(field).strip()
            for suffix in [" (sum)", " (average)", " (count)", " (min)", " (max)"]:
                if field.lower().endswith(suffix):
                    return field[:-len(suffix)]
            return field

        # Map each normalised header name to the actual key used in the
        # records, so that data lookups use the same case-insensitive,
        # suffix-stripped matching as validation does.
        first_row = data[0]
        header_keys: dict[str, Any] = {}
        for header in first_row:
            header_keys.setdefault(clean_field_name(str(header)).lower(), header)
        available_fields = set(header_keys)

        # Validate field names exist in data
        for field_list, field_type in [(rows, "row"), (values, "value"), (columns or [], "column")]:
            for field in field_list:
                if clean_field_name(str(field)).lower() not in available_fields:
                    raise ValidationError(
                        f"Invalid {field_type} field '{field}'. "
                        f"Available fields: {', '.join(sorted(available_fields))}"
                    )

        # Skip header row if it matches our fields
        # NOTE(review): this drops the first record whenever every source
        # header is used as a row/value field; confirm that
        # read_excel_range really returns the header row as data[0] in
        # that situation, otherwise a genuine record is lost here.
        requested_fields = {clean_field_name(str(field)).lower() for field in rows + values}
        if all(name in requested_fields for name in header_keys):
            data = data[1:]

        # Clean up row and value field names (used for display) and resolve
        # them to the real record keys (used for data access)
        cleaned_rows = [clean_field_name(field) for field in rows]
        cleaned_values = [clean_field_name(field) for field in values]
        row_keys = [header_keys[name.lower()] for name in cleaned_rows]
        value_keys = [header_keys[name.lower()] for name in cleaned_values]

        # Create pivot sheet
        pivot_sheet_name = f"{sheet_name}_pivot"
        if pivot_sheet_name in wb.sheetnames:
            wb.remove(wb[pivot_sheet_name])
        pivot_ws = wb.create_sheet(pivot_sheet_name)

        # Write headers: row fields first, then "<value> (<agg>)" columns
        header_titles = cleaned_rows + [f"{field} ({agg_func})" for field in cleaned_values]
        for col, title in enumerate(header_titles, start=1):
            cell = pivot_ws.cell(row=1, column=col, value=title)
            cell.font = Font(bold=True)

        # Get unique (stringified) values for each row field
        field_values = {
            field: sorted({str(record.get(key, '')) for record in data})
            for field, key in zip(cleaned_rows, row_keys)
        }

        # Generate all combinations of row field values
        row_combinations = _get_combinations(field_values)

        # Group the records once by the same stringified key the
        # combinations are built from. Comparing raw cell values against the
        # stringified labels would never match for numbers, dates or None.
        grouped: dict[tuple, list] = {}
        for record in data:
            group_key = tuple(str(record.get(key, '')) for key in row_keys)
            grouped.setdefault(group_key, []).append(record)

        # Calculate table dimensions for formatting
        total_rows = len(row_combinations) + 1  # +1 for header
        total_cols = len(cleaned_rows) + len(cleaned_values)

        # Write data rows
        for current_row, combo in enumerate(row_combinations, start=2):
            # Write row field values
            labels = [combo[field] for field in cleaned_rows]
            for col, label in enumerate(labels, start=1):
                pivot_ws.cell(row=current_row, column=col, value=label)

            # Records belonging to the current combination (possibly none)
            matching = grouped.get(tuple(labels), [])

            # Calculate and write aggregated values
            first_value_col = len(labels) + 1
            for col, (value_field, value_key) in enumerate(
                zip(cleaned_values, value_keys), start=first_value_col
            ):
                try:
                    value = _aggregate_values(matching, value_key, agg_name)
                    pivot_ws.cell(row=current_row, column=col, value=value)
                except Exception as e:
                    raise PivotError(
                        f"Failed to aggregate values for field '{value_field}': {str(e)}"
                    ) from e

        # Create a table for the pivot data
        try:
            pivot_range = f"A1:{get_column_letter(total_cols)}{total_rows}"
            pivot_table = Table(
                # Random suffix keeps the table name unique within the workbook
                displayName=f"PivotTable_{uuid.uuid4().hex[:8]}",
                ref=pivot_range
            )
            style = TableStyleInfo(
                name="TableStyleMedium9",
                showFirstColumn=False,
                showLastColumn=False,
                showRowStripes=True,
                showColumnStripes=True
            )
            pivot_table.tableStyleInfo = style
            pivot_ws.add_table(pivot_table)
        except Exception as e:
            raise PivotError(f"Failed to create pivot table formatting: {str(e)}") from e

        try:
            wb.save(filepath)
        except Exception as e:
            raise PivotError(f"Failed to save workbook: {str(e)}") from e

        return {
            "message": "Summary table created successfully",
            "details": {
                "source_range": data_range_str,
                "pivot_sheet": pivot_sheet_name,
                "rows": cleaned_rows,
                "columns": columns or [],
                "values": cleaned_values,
                "aggregation": agg_func
            }
        }

    except (ValidationError, PivotError) as e:
        # Domain errors are already descriptive: log and propagate unchanged
        logger.error(str(e))
        raise
    except Exception as e:
        logger.error(f"Failed to create pivot table: {e}")
        raise PivotError(str(e)) from e
|
|
|
|
|
|
def _get_combinations(field_values: dict[str, set]) -> list[dict]:
|
|
"""Get all combinations of field values."""
|
|
result = [{}]
|
|
for field, values in list(field_values.items()): # Convert to list to avoid runtime changes
|
|
new_result = []
|
|
for combo in result:
|
|
for value in sorted(values): # Sort for consistent ordering
|
|
new_combo = combo.copy()
|
|
new_combo[field] = value
|
|
new_result.append(new_combo)
|
|
result = new_result
|
|
return result
|
|
|
|
|
|
def _filter_data(data: list[dict], row_filters: dict, col_filters: dict) -> list[dict]:
|
|
"""Filter data based on row and column filters."""
|
|
result = []
|
|
for record in data:
|
|
matches = True
|
|
for field, value in row_filters.items():
|
|
if record.get(field) != value:
|
|
matches = False
|
|
break
|
|
for field, value in col_filters.items():
|
|
if record.get(field) != value:
|
|
matches = False
|
|
break
|
|
if matches:
|
|
result.append(record)
|
|
return result
|
|
|
|
|
|
def _aggregate_values(data: list[dict], field: str, agg_func: str) -> float:
|
|
"""Aggregate values using the specified function."""
|
|
values = [record[field] for record in data if field in record and isinstance(record[field], (int, float))]
|
|
if not values:
|
|
return 0
|
|
|
|
if agg_func == "sum":
|
|
return sum(values)
|
|
elif agg_func == "average":
|
|
return sum(values) / len(values)
|
|
elif agg_func == "count":
|
|
return len(values)
|
|
elif agg_func == "min":
|
|
return min(values)
|
|
elif agg_func == "max":
|
|
return max(values)
|
|
else:
|
|
return sum(values) # Default to sum
|