diff --git a/src/excel_mcp/chart.py b/src/excel_mcp/chart.py index 20b4c66..e92fd6f 100644 --- a/src/excel_mcp/chart.py +++ b/src/excel_mcp/chart.py @@ -66,6 +66,12 @@ def create_chart_in_sheet( style: Optional[Dict] = None ) -> dict[str, Any]: """Create chart in sheet with enhanced styling options""" + # Ensure style dict exists and defaults to showing data labels + if style is None: + style = {"show_data_labels": True} + else: + # If caller omitted the flag, default to True + style.setdefault("show_data_labels", True) try: wb = load_workbook(filepath) if sheet_name not in wb.sheetnames: @@ -165,22 +171,36 @@ def create_chart_in_sheet( # Apply style if provided try: - if style: - if style.get("show_legend", True): - chart.legend = Legend() - chart.legend.position = style.get("legend_position", "r") - else: - chart.legend = None + if style.get("show_legend", True): + chart.legend = Legend() + chart.legend.position = style.get("legend_position", "r") + else: + chart.legend = None - if style.get("show_data_labels", False): - chart.dataLabels = DataLabelList() - chart.dataLabels.showVal = True + if style.get("show_data_labels", False): + data_labels = DataLabelList() + # Gather optional overrides + dlo = style.get("data_label_options", {}) if isinstance(style.get("data_label_options", {}), dict) else {} - if style.get("grid_lines", False): - if hasattr(chart, "x_axis"): - chart.x_axis.majorGridlines = ChartLines() - if hasattr(chart, "y_axis"): - chart.y_axis.majorGridlines = ChartLines() + # Helper to read bool with fallback + def _opt(name: str, default: bool) -> bool: + return bool(dlo.get(name, default)) + + # Apply options – Excel will concatenate any that are set to True + data_labels.showVal = _opt("show_val", True) + data_labels.showCatName = _opt("show_cat_name", False) + data_labels.showSerName = _opt("show_ser_name", False) + data_labels.showLegendKey = _opt("show_legend_key", False) + data_labels.showPercent = _opt("show_percent", False) + data_labels.showBubbleSize = _opt("show_bubble_size", False) + + chart.dataLabels = data_labels + + if style.get("grid_lines", False): + if hasattr(chart, "x_axis"): + chart.x_axis.majorGridlines = ChartLines() + if hasattr(chart, "y_axis"): + chart.y_axis.majorGridlines = ChartLines() except Exception as e: logger.error(f"Failed to apply chart style: {e}") raise ChartError(f"Failed to apply chart style: {str(e)}")