Source code for evo_gafs.visualization.plots

"""Matplotlib visualisations for analysing a GA run.

``matplotlib`` is an optional dependency (install with ``pip install evo-gafs[viz]``)
and is imported lazily so the core package stays lightweight.
"""

from __future__ import annotations

import numpy as np

from evo_gafs.core.config import SelectionResult


def _require_matplotlib():
    try:
        import matplotlib.pyplot as plt
    except ImportError as exc:  # pragma: no cover - depends on optional extra
        raise ImportError(
            "matplotlib is required for plotting. Install it with `pip install evo-gafs[viz]`."
        ) from exc
    return plt


[docs] class GAPlotter: """Static helpers to visualise evolution curves and Pareto fronts."""
[docs] @staticmethod def plot_evolution( result: SelectionResult, figsize: tuple[float, float] = (14, 5), title_prefix: str = "", ): """Plot fitness and feature-count evolution over generations.""" plt = _require_matplotlib() history = result.history gens = [s.generation for s in history] best_fit = [s.best_fitness for s in history] mean_fit = [s.mean_fitness for s in history] std_fit = [s.std_fitness for s in history] best_feats = [s.best_n_features for s in history] mean_feats = [s.mean_n_features for s in history] fig, axes = plt.subplots(1, 2, figsize=figsize) prefix = f"{title_prefix} - " if title_prefix else "" ax1 = axes[0] ax1.plot(gens, best_fit, "b-", linewidth=2, label="Best fitness") ax1.plot(gens, mean_fit, "g--", linewidth=1.5, label="Mean fitness") ax1.fill_between( gens, [m - s for m, s in zip(mean_fit, std_fit)], [m + s for m, s in zip(mean_fit, std_fit)], alpha=0.2, color="green", label="±1 std", ) ax1.set_xlabel("Generation") ax1.set_ylabel("Fitness") ax1.set_title(f"{prefix}Fitness evolution") ax1.legend() ax1.grid(True, alpha=0.3) ax2 = axes[1] ax2.plot(gens, best_feats, "r-", linewidth=2, label="Features (best ind.)") ax2.plot(gens, mean_feats, "m--", linewidth=1.5, label="Features (mean)") ax2.axhline( y=result.n_selected, color="darkred", linestyle=":", label=f"Final selected: {result.n_selected}", ) ax2.set_xlabel("Generation") ax2.set_ylabel("Number of features") ax2.set_title(f"{prefix}Selected-feature evolution") ax2.legend() ax2.grid(True, alpha=0.3) fig.tight_layout() return fig
[docs] @staticmethod def plot_pareto_front( result: SelectionResult, figsize: tuple[float, float] = (8, 6), title: str = "Pareto front", ): """Plot the Pareto front (multi-objective mode only).""" plt = _require_matplotlib() if result.pareto_front is None: raise ValueError("No Pareto front available (single-objective mode).") front = result.pareto_front cv_scores = [p["cv_score"] for p in front] n_feats = [p["n_features"] for p in front] compressions = [p["compression"] for p in front] fig, ax = plt.subplots(figsize=figsize) scatter = ax.scatter( n_feats, cv_scores, c=compressions, cmap="RdYlGn", s=100, edgecolors="black", linewidths=0.5, zorder=5, ) fig.colorbar(scatter, ax=ax, label="Compression (1 - feature ratio)") best_idx = int(np.argmax(cv_scores)) ax.scatter( n_feats[best_idx], cv_scores[best_idx], s=250, marker="*", color="gold", edgecolors="black", linewidths=1.5, label=f"Best CV score\n({cv_scores[best_idx]:.3f}, {n_feats[best_idx]} feats)", zorder=10, ) ax.set_xlabel("Number of selected features") ax.set_ylabel("CV score") ax.set_title(title) ax.legend() ax.grid(True, alpha=0.3) fig.tight_layout() return fig
[docs] @staticmethod def plot_selected_features( result: SelectionResult, feature_names: list[str] | None = None, figsize: tuple[float, float] = (10, 6), title: str = "Selected vs removed features", ): """Show which features were selected versus removed.""" plt = _require_matplotlib() from matplotlib.patches import Patch names = ( feature_names if feature_names is not None else [f"f{i}" for i in range(len(result.selected_mask))] ) n_total = len(names) colors = ["#2ecc71" if m else "#e74c3c" for m in result.selected_mask] fig, ax = plt.subplots(figsize=figsize) ax.barh(range(n_total), [1] * n_total, color=colors, edgecolor="white", linewidth=0.5) ax.set_yticks(range(n_total)) ax.set_yticklabels(names, fontsize=9) ax.set_xticks([]) ax.set_title(title) ax.legend( handles=[ Patch(facecolor="#2ecc71", label=f"Selected ({result.n_selected})"), Patch(facecolor="#e74c3c", label=f"Removed ({n_total - result.n_selected})"), ], loc="lower right", ) ax.invert_yaxis() fig.tight_layout() return fig