Skip to content

PCA Plotting module

pca

PCA and dimensionality reduction visualization functions.

This module provides utilities for visualizing principal component analysis results, including component distributions, variance explanations, and dataset separation.

correlation_matrix(corr_matrix, k_start=0.001, k_end=0.23, nbins=20, figsize=(10, 8))

Plot correlation matrix between bins.

PARAMETER DESCRIPTION
corr_matrix

Correlation matrix from analyze_bin_correlations()

TYPE: ndarray

k_start

Minimum \(k\) for binning

TYPE: float DEFAULT: 0.001

k_end

Maximum \(k\) for binning

TYPE: float DEFAULT: 0.23

nbins

Number of bins

TYPE: int DEFAULT: 20

figsize

Figure size

TYPE: Tuple DEFAULT: (10, 8)

RETURNS DESCRIPTION

Matplotlib Figure object.

Source code in src/primefeat/plots/pca.py
def correlation_matrix(
    corr_matrix: np.ndarray,
    k_start: float = 0.001,
    k_end: float = 0.23,
    nbins: int = 20,
    figsize: Tuple = (10, 8),
):
    """
    Plot correlation matrix between bins.

    The colour scale is pinned to [-1, 1] with a diverging colormap, so zero
    correlation is always drawn in the neutral colour regardless of the values
    actually present in the matrix.

    Args:
        corr_matrix: Correlation matrix from analyze_bin_correlations()
        k_start: Minimum $k$ for binning. Kept for API compatibility; the axes
            are labelled by bin index, so this value is not used.
        k_end: Maximum $k$ for binning. Kept for API compatibility (unused).
        nbins: Number of bins. Kept for API compatibility; the tick labels
            follow the actual shape of ``corr_matrix`` so they can never
            disagree with the plotted data.
        figsize: Figure size

    Returns:
        Matplotlib Figure object.

    Raises:
        ValueError: If ``corr_matrix`` is not two-dimensional.
    """
    if np.ndim(corr_matrix) != 2:
        raise ValueError(
            f"corr_matrix must be a 2-D array, got shape {np.shape(corr_matrix)}"
        )
    n_rows, n_cols = np.shape(corr_matrix)

    fig, ax = plt.subplots(figsize=figsize)

    im = ax.imshow(corr_matrix, cmap="RdBu_r", vmin=-1, vmax=1, aspect="auto")

    # One tick per matrix row/column, labelled with 1-based bin indices
    ax.set_xticks(range(n_cols))
    ax.set_yticks(range(n_rows))
    ax.set_xticklabels([f"{i + 1}" for i in range(n_cols)], fontsize=8)
    ax.set_yticklabels([f"{i + 1}" for i in range(n_rows)], fontsize=8)

    ax.set_xlabel("Bin index", fontsize=11)
    ax.set_ylabel("Bin index", fontsize=11)
    ax.set_title("Correlation Matrix Between Bins", fontsize=12)

    # Attach the colorbar to this figure explicitly rather than pyplot's "current" one
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label("Correlation coefficient", fontsize=10)

    fig.tight_layout()
    return fig

dataset_separation_PC(results, chains_dict=None, pc_x=1, pc_y=2, figsize=(10, 8), colors=None, chain_entries=None, plot_type='scatter', filled=False, contour_levels=None, legend_loc=(0.5, 1.15))

Plot samples from different datasets in principal component space.

This shows whether different datasets have distinct feature preferences (separation in PC space) or if they agree (overlap in PC space).

PARAMETER DESCRIPTION
results

PCAResults from perform_pca()

TYPE: PCAResults

chains_dict

Dictionary of chains (for labels); auto-extracted from chain_entries if provided

TYPE: Optional[Dict] DEFAULT: None

pc_x

Which PC to plot on x-axis (default: 1)

TYPE: int DEFAULT: 1

pc_y

Which PC to plot on y-axis (default: 2)

TYPE: int DEFAULT: 2

figsize

Figure size (default: (10, 8))

TYPE: Tuple DEFAULT: (10, 8)

colors

Optional list of colors for each dataset; auto-extracted from chain_entries if provided

TYPE: Optional[List] DEFAULT: None

chain_entries

ChainsCollection or list of ChainEntry objects; colors and chains auto-extracted

TYPE: Optional[List] DEFAULT: None

plot_type

"scatter" for scatter plot or "contour" for getdist contours (default: "scatter")

TYPE: str DEFAULT: 'scatter'

filled

Whether to use filled contours (only for plot_type="contour", default: False)

TYPE: bool DEFAULT: False

contour_levels

Confidence levels for contours (default: [0.68, 0.95] for \(1\sigma\) and \(2\sigma\))

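TYPE: Optional[List] DEFAULT: None

legend_loc

Anchor point (axes coordinates) for the legend in contour mode, passed as bbox_to_anchor (default: (0.5, 1.15), i.e. centred above the axes). Ignored for scatter plots.

TYPE: Optional[Tuple] DEFAULT: (0.5, 1.15)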

RETURNS DESCRIPTION

Matplotlib Figure object.

RAISES DESCRIPTION
ValueError

If neither chains_dict nor chain_entries is provided.

Source code in src/primefeat/plots/pca.py
def dataset_separation_PC(
    results: PCAResults,
    chains_dict: Optional[Dict] = None,
    pc_x: int = 1,
    pc_y: int = 2,
    figsize: Tuple = (10, 8),
    colors: Optional[List] = None,
    chain_entries: Optional[List] = None,
    plot_type: str = "scatter",
    filled: bool = False,
    contour_levels: Optional[List] = None,
    legend_loc: Optional[Tuple] = (0.5, 1.15),
):
    """
    Plot samples from different datasets in principal component space.

    This shows whether different datasets have distinct feature preferences
    (separation in PC space) or if they agree (overlap in PC space).

    Args:
        results: PCAResults from perform_pca()
        chains_dict: Dictionary of chains (for labels); auto-extracted from
            chain_entries if not given
        pc_x: Which PC to plot on x-axis, 1-based (default: 1)
        pc_y: Which PC to plot on y-axis, 1-based (default: 2)
        figsize: Figure size (default: (10, 8))
        colors: Optional list of colors for each dataset (cycled if shorter
            than the number of datasets); auto-extracted from chain_entries
            if not given
        chain_entries: ChainsCollection or list of ChainEntry objects; fills in
            whichever of chains_dict / colors was not given explicitly
        plot_type: "scatter" for scatter plot or "contour" for getdist contours (default: "scatter")
        filled: Whether to use filled contours (only for plot_type="contour", default: False)
        contour_levels: Confidence levels for contours (default: [0.68, 0.95] for $1\\sigma$ and $2\\sigma$)
        legend_loc: ``bbox_to_anchor`` for the legend in contour mode, in axes
            coordinates (default: (0.5, 1.15), i.e. centred above the axes).
            Ignored for scatter plots.

    Returns:
        Matplotlib Figure object.

    Raises:
        ValueError: If neither chains_dict nor chain_entries is provided, or if
            plot_type is not "scatter" or "contour".
    """
    if plot_type not in ("scatter", "contour"):
        raise ValueError(
            f'plot_type must be "scatter" or "contour", got {plot_type!r}'
        )

    # Fill in whatever the caller did not pass explicitly from chain_entries
    if chain_entries is not None:
        if colors is None:
            colors = [e.color for e in chain_entries]
        if chains_dict is None:
            chains_dict = {e.label: e.samples for e in chain_entries}

    if chains_dict is None:
        raise ValueError(
            "Either chains_dict or chain_entries must be provided to dataset_separation_PC()"
        )

    if colors is None:
        colors = [
            "#2E86AB",
            "#A23B72",
            "#F18F01",
            "#C73E1D",
            "#6A994E",
            "#7209B7",
            "#3A86FF",
            "#FB5607",
        ]

    if contour_levels is None:
        contour_levels = [0.68, 0.95]

    fig, ax = plt.subplots(figsize=figsize)

    labels = results.dataset_labels
    unique_labels = list(chains_dict.keys())
    # Users give 1-based PC numbers; the score matrix columns are 0-based
    ix, iy = pc_x - 1, pc_y - 1

    def _coords_for(label):
        """Return the rows of the PC score matrix belonging to ``label``."""
        mask = np.array([lab == label for lab in labels], dtype=bool)
        return results.transformed_data[mask]

    if plot_type == "contour":
        # Use getdist for contour plotting
        from getdist import MCSamples
        from matplotlib.lines import Line2D

        # Simple names for getdist; matplotlib handles the axis labels below
        names = [f"PC{pc_x}", f"PC{pc_y}"]
        legend_handles = []

        for i, label in enumerate(unique_labels):
            pc_coords = _coords_for(label)

            samples = MCSamples(
                samples=pc_coords[:, [ix, iy]],
                names=names,
                labels=names,
                label=label,
            )

            # Get 2D density and plot contours manually
            density = samples.get2DDensity(names[0], names[1])

            # getdist takes confidence levels (0-1); convert each to the
            # corresponding density threshold for matplotlib's contour()
            density_levels = [
                density.getContourLevels([level])[0] for level in contour_levels
            ]

            color = colors[i % len(colors)]

            if filled:
                ax.contourf(
                    density.x,
                    density.y,
                    density.P,
                    levels=sorted(density_levels + [density.P.max()]),
                    colors=[color],
                    alpha=0.3,
                )

            # Always plot contour lines
            ax.contour(
                density.x,
                density.y,
                density.P,
                levels=sorted(density_levels),
                colors=[color],
                linewidths=1.5,
            )

            # contour() does not register legend entries, so build a proxy
            legend_handles.append(
                Line2D([0], [0], color=color, linewidth=2, label=label)
            )

    else:  # scatter plot
        for i, label in enumerate(unique_labels):
            pc_coords = _coords_for(label)

            # Plot with transparency to show density
            ax.scatter(
                pc_coords[:, ix],
                pc_coords[:, iy],
                alpha=0.3,
                s=1,
                color=colors[i % len(colors)],
                label=label,
            )

    ax.axhline(0, ls="--", c="k", alpha=0.3)
    ax.axvline(0, ls="--", c="k", alpha=0.3)
    ax.set_xlabel(
        rf"PC{pc_x} ({1e2 * results.explained_variance_ratio[ix]:.1f}$\%$)",
        fontsize=11,
    )
    ax.set_ylabel(
        rf"PC{pc_y} ({1e2 * results.explained_variance_ratio[iy]:.1f}$\%$)",
        fontsize=11,
    )

    # Add legend (use custom handles for contour mode)
    if plot_type == "contour":
        ax.legend(
            handles=legend_handles,
            loc="center",
            bbox_to_anchor=legend_loc,
            # Two rows; at least one column, since ncols=0 makes matplotlib raise
            ncols=max(1, len(unique_labels) // 2),
            frameon=False,
        )
    else:
        ax.legend(framealpha=0.9)

    fig.tight_layout()
    return fig

rectangle_PC_scores(results, chains_dict, pcs_x=[1, 3], pcs_y=[2, 4], colors=None, filled=False, **kwargs)

Plot PC scores in a rectangle grid using getdist contours.

Creates MCSamples from PC scores and uses the rectangle() function to display a grid of 2D marginalized contours.

PARAMETER DESCRIPTION
results

PCAResults from perform_pca()

TYPE: PCAResults

chains_dict

Dictionary of chains (for labels)

TYPE: Dict

pcs_x

Which PCs for x-axis (default: [1, 3])

TYPE: List[int] DEFAULT: [1, 3]

pcs_y

Which PCs for y-axis (default: [2, 4])

TYPE: List[int] DEFAULT: [2, 4]

colors

Optional list of colors for each dataset

TYPE: Optional[List] DEFAULT: None

filled

Whether to use filled contours (default: False)

TYPE: bool DEFAULT: False

**kwargs

Additional arguments passed to rectangle()

DEFAULT: {}

RETURNS DESCRIPTION

GetDistPlotter object.

Source code in src/primefeat/plots/pca.py
def rectangle_PC_scores(
    results: PCAResults,
    chains_dict: Dict,
    pcs_x: Optional[List[int]] = None,
    pcs_y: Optional[List[int]] = None,
    colors: Optional[List] = None,
    filled: bool = False,
    **kwargs,
):
    """
    Plot PC scores in a rectangle grid using getdist contours.

    Creates MCSamples from PC scores and uses the rectangle() function
    to display a grid of 2D marginalized contours.

    Args:
        results: PCAResults from perform_pca()
        chains_dict: Dictionary of chains (for labels)
        pcs_x: Which PCs for x-axis, 1-based (default: [1, 3])
        pcs_y: Which PCs for y-axis, 1-based (default: [2, 4])
        colors: Optional list of colors for each dataset
        filled: Whether to use filled contours (default: False)
        **kwargs: Additional arguments passed to rectangle()

    Returns:
        GetDistPlotter object.

    Raises:
        ValueError: If no PCs are requested, or a requested PC number is
            below 1 or exceeds the number of PC scores in ``results``.
    """
    from getdist import MCSamples

    # None sentinels avoid shared mutable list defaults; list() also lets
    # callers pass tuples.
    pcs_x = [1, 3] if pcs_x is None else list(pcs_x)
    pcs_y = [2, 4] if pcs_y is None else list(pcs_y)

    if colors is None:
        colors = ["gray", "C0", "C1", "C2", "dodgerblue", "olive"]

    requested = pcs_x + pcs_y
    if not requested:
        raise ValueError("At least one PC must be requested in pcs_x or pcs_y")
    n_available = np.shape(results.transformed_data)[1]
    if min(requested) < 1 or max(requested) > n_available:
        raise ValueError(
            f"Requested PCs {sorted(set(requested))} are outside 1..{n_available}"
        )
    # Every dataset carries PC1..max_pc so that rectangle() can look up any requested name
    max_pc = max(requested)

    # Parameter names and display labels (with variance explained) are the
    # same for every dataset, so build them once.
    names = [f"PC{i}" for i in range(1, max_pc + 1)]
    display_labels = [
        rf"PC{i} ({100 * results.explained_variance_ratio[i - 1]:.1f}\%)"
        for i in range(1, max_pc + 1)
    ]

    # Build MCSamples for each dataset from its PC scores
    labels = results.dataset_labels
    pc_chains = {}
    for label in chains_dict:
        mask = np.array([lab == label for lab in labels], dtype=bool)
        pc_coords = results.transformed_data[mask]
        pc_chains[label] = MCSamples(
            samples=pc_coords[:, :max_pc],
            names=names,
            labels=display_labels,
            label=label,
        )

    return rectangle(
        pc_chains,
        params_X=[f"PC{i}" for i in pcs_x],
        params_Y=[f"PC{i}" for i in pcs_y],
        colors=colors,
        filled=filled,
        **kwargs,
    )

plot_principal_components(results, k_start=0.001, k_end=0.23, nbins=20, n_components=5, figsize=(10, 10), binning=None)

Visualize principal components as functions in \(k\)-space.

This shows what each PC "looks like" as a pattern in the primordial power spectrum deviations.

PARAMETER DESCRIPTION
results

PCAResults from perform_pca()

TYPE: PCAResults

k_start

Minimum \(k\) for binning (default: 0.001 Mpc\(^{-1}\))

TYPE: float DEFAULT: 0.001

k_end

Maximum \(k\) for binning (default: 0.23 Mpc\(^{-1}\))

TYPE: float DEFAULT: 0.23

nbins

Number of bins (default: 20)

TYPE: int DEFAULT: 20

n_components

Number of PCs to plot (default: 5)

TYPE: int DEFAULT: 5

figsize

Figure size (default: (10, 10))

TYPE: Tuple DEFAULT: (10, 10)

binning

optional BinningScheme object (supersedes k_start, k_end, nbins if provided)

TYPE: Optional[BinningScheme] DEFAULT: None

RETURNS DESCRIPTION

Matplotlib Figure object.

Source code in src/primefeat/plots/pca.py
def plot_principal_components(
    results: PCAResults,
    k_start: float = 0.001,
    k_end: float = 0.23,
    nbins: int = 20,
    n_components: int = 5,
    figsize: Tuple = (10, 10),
    binning: Optional[BinningScheme] = None,
):
    """
    Visualize principal components as functions in $k$-space.

    This shows what each PC "looks like" as a pattern in the primordial
    power spectrum deviations. One panel is drawn per component, stacked
    vertically with a shared logarithmic $k$ axis.

    Args:
        results: PCAResults from perform_pca()
        k_start: Minimum $k$ for binning (default: 0.001 Mpc$^{-1}$)
        k_end: Maximum $k$ for binning (default: 0.23 Mpc$^{-1}$)
        nbins: Number of bins (default: 20)
        n_components: Number of PCs to plot (default: 5); capped at
            ``results.n_components``
        figsize: Figure size (default: (10, 10))
        binning: optional BinningScheme object (supersedes k_start, k_end, nbins if provided)

    Returns:
        Matplotlib Figure object.

    Raises:
        ValueError: If there is nothing to plot, or if the number of bins does
            not match the length of the PC loadings.
    """
    b = _resolve_binning(binning, k_start, k_end, nbins)
    bin_centers = b.bin_centers

    n_to_plot = min(n_components, results.n_components)
    if n_to_plot < 1:
        raise ValueError(
            f"Nothing to plot: n_components={n_components}, "
            f"results.n_components={results.n_components}"
        )

    n_loadings = np.shape(results.components)[1]
    if n_loadings != len(bin_centers):
        raise ValueError(
            f"Binning has {len(bin_centers)} bins but the PC loadings have "
            f"{n_loadings} entries; pass the binning used for the PCA"
        )

    # NOTE(review): presumably a fitted sklearn StandardScaler, whose
    # ``scale_`` is None when fitted with with_std=False (no rescaling then).
    scale = results.scaler.scale_
    if scale is None:
        scale = 1.0

    # squeeze=False always yields a 2-D grid, so no single-panel special case
    fig, axes = plt.subplots(
        n_to_plot, 1, figsize=figsize, sharex=True, squeeze=False
    )
    axes = axes[:, 0]

    for i, ax in enumerate(axes):
        # PC loadings: how each bin contributes to this PC
        loadings = results.components[i]

        # NOTE(review): dividing by the per-bin std gives the weights that map
        # raw (unstandardised) bin values to the PC score; the deviation
        # *pattern* in original units would be ``loadings * scale`` instead.
        # Confirm which is intended.
        loadings_rescaled = loadings / scale

        ratio = results.explained_variance_ratio[i]
        ax.plot(bin_centers, loadings_rescaled, "o-", color=f"C{i}", linewidth=2)
        ax.axhline(0, ls="--", c="k", alpha=0.3)
        # "$\%$" renders as a percent sign with both mathtext and usetex;
        # a bare "%" starts a LaTeX comment when text.usetex is enabled.
        ax.set_ylabel(f"PC{i + 1}\n({100 * ratio:.1f}" + r"$\%$)")
        ax.set_xscale("log")
        ax.grid(True, alpha=0.3)

        if i == 0:
            ax.set_title(
                "Principal Components in k-space\n(Dominant patterns of variation)",
                fontsize=12,
            )

    axes[-1].set_xlabel(r"$k$ [Mpc$^{-1}$]", fontsize=11)
    fig.tight_layout()
    return fig

plot_variance_explained(results, figsize=(12, 5))

Plot variance explained by principal components.

PARAMETER DESCRIPTION
results

PCAResults from perform_pca()

TYPE: PCAResults

figsize

Figure size (default: (12, 5))

TYPE: Tuple DEFAULT: (12, 5)

RETURNS DESCRIPTION
fig

matplotlib Figure object

Source code in src/primefeat/plots/pca.py
def plot_variance_explained(results: PCAResults, figsize: Tuple = (12, 5)):
    """
    Plot variance explained by principal components.

    The left panel shows the fraction of variance explained by each
    individual PC, with a 5 percent reference line. The right panel shows the
    cumulative fraction, with a 95 percent reference line and a vertical
    marker at ``results.effective_dim``.

    Args:
        results: PCAResults from perform_pca()
        figsize: Figure size (default: (12, 5))

    Returns:
        fig: matplotlib Figure object
    """
    fig, (ax_ind, ax_cum) = plt.subplots(1, 2, figsize=figsize)

    ratios = results.explained_variance_ratio
    cumulative = results.cumulative_variance

    # Percent signs are written as "$\%$" so the labels render with both
    # mathtext and usetex (a bare "%" begins a LaTeX comment under usetex).

    # Plot 1: Individual variance per component
    ax_ind.bar(
        range(1, len(ratios) + 1),
        ratios,
        alpha=0.7,
        color="steelblue",
    )
    ax_ind.axhline(0.05, ls="--", c="red", alpha=0.5, label=r"5$\%$ threshold")
    ax_ind.set_xlabel("Principal Component")
    ax_ind.set_ylabel("Variance Explained")
    ax_ind.set_title("Variance Explained by Each PC")
    ax_ind.legend()
    ax_ind.grid(True, alpha=0.3)

    # Plot 2: Cumulative variance
    ax_cum.plot(
        range(1, len(cumulative) + 1),
        cumulative,
        "o-",
        color="steelblue",
        linewidth=2,
        markersize=6,
    )
    ax_cum.axhline(0.95, ls="--", c="red", alpha=0.5, label=r"95$\%$ threshold")
    ax_cum.axvline(
        results.effective_dim,
        ls="--",
        c="green",
        alpha=0.5,
        label=f"Effective dim = {results.effective_dim}",
    )
    ax_cum.set_xlabel("Number of Components")
    ax_cum.set_ylabel("Cumulative Variance Explained")
    ax_cum.set_title("Cumulative Variance Explained")
    ax_cum.set_ylim([0, 1.05])
    ax_cum.legend()
    ax_cum.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig

ridgeline_delta(chains, nbins=20, k_start=0.001, k_end=0.23, overlap=0.7, cmap='coolwarm', alpha=0.8, linewidth=1.0, xlim=(-0.6, 0.6), bw_method=None, fig_kw=None, binning=None)

Ridgeline plot of the marginal posterior distributions of the delta_i bin amplitudes.

Each ridge is the KDE of the posterior samples for one k-bin, normalized independently to its own peak so that the shape of every distribution is visible regardless of how much the posterior widths vary across bins. Bins are stacked bottom (k_start) to top (k_end).

PARAMETER DESCRIPTION
chains

Dict mapping dataset label to MCSamples, or a single MCSamples object. When a dict is given, each dataset gets its own panel (column).

nbins

Number of delta bins (default: 20).

DEFAULT: 20

k_start

Lower edge of the binning range in Mpc\(^{-1}\) (default: 0.001).

DEFAULT: 0.001

k_end

Upper edge of the binning range in Mpc\(^{-1}\) (default: 0.23).

DEFAULT: 0.23

overlap

Ridge height as a fraction of the inter-ridge spacing (default: 0.7). Values > 1 produce overlapping ridges.

DEFAULT: 0.7

cmap

Colormap applied across ridges to encode k value (default: "coolwarm").

DEFAULT: 'coolwarm'

alpha

Fill opacity (default: 0.8).

DEFAULT: 0.8

linewidth

Ridge outline width (default: 1.0).

DEFAULT: 1.0

xlim

x-axis limits for the delta_i axis (default: (-0.6, 0.6)).

DEFAULT: (-0.6, 0.6)

bw_method

Bandwidth method forwarded to scipy.stats.gaussian_kde (default: None uses Silverman's rule).

DEFAULT: None

fig_kw

Keyword arguments forwarded to plt.subplots() (default: None).

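DEFAULT: None

binning

Optional BinningScheme object (supersedes k_start, k_end, nbins if provided).

TYPE: Optional[BinningScheme] DEFAULT: None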

RETURNS DESCRIPTION
fig

Matplotlib figure.

Example

>>> chains = pf.get_chains("figure2")
>>> fig = pf.plot.ridgeline_delta(chains, nbins=20)

Source code in src/primefeat/plots/pca.py
def ridgeline_delta(
    chains,
    nbins=20,
    k_start=0.001,
    k_end=0.23,
    overlap=0.7,
    cmap="coolwarm",
    alpha=0.8,
    linewidth=1.0,
    xlim=(-0.6, 0.6),
    bw_method=None,
    fig_kw=None,
    binning: Optional[BinningScheme] = None,
):
    """
    Ridgeline plot of the marginal posterior distributions of the delta_i bin amplitudes.

    Each ridge is the KDE of the posterior samples for one k-bin, normalized
    independently to its own peak so that the shape of every distribution is
    visible regardless of how much the posterior widths vary across bins.
    Bins are stacked bottom (k_start) to top (k_end).

    Args:
        chains: Dict mapping dataset label to MCSamples, or a single MCSamples object.
                When a dict is given, each dataset gets its own panel (column).
        nbins: Number of delta bins (default: 20).
        k_start: Lower edge of the binning range in Mpc$^{-1}$ (default: 0.001).
        k_end: Upper edge of the binning range in Mpc$^{-1}$ (default: 0.23).
        overlap: Ridge height as a fraction of the inter-ridge spacing (default: 0.7).
                 Values > 1 produce overlapping ridges.
        cmap: Colormap (name or instance) applied across ridges to encode k value
              (default: "coolwarm").
        alpha: Fill opacity (default: 0.8).
        linewidth: Ridge outline width (default: 1.0).
        xlim: x-axis limits for the delta_i axis (default: (-0.6, 0.6)).
        bw_method: Bandwidth method forwarded to ``scipy.stats.gaussian_kde``
                   (default: None uses Scott's rule).
        fig_kw: Keyword arguments forwarded to ``plt.subplots()`` (default: None).
                ``ncols`` defaults to the number of datasets and ``figsize`` to
                ``(3 * ncols, 6)``. The caller's dict is not modified.
        binning: Optional BinningScheme object (supersedes k_start, k_end,
                 nbins if provided).

    Returns:
        fig: Matplotlib figure.

    Raises:
        ValueError: If ``chains`` is an empty dict.

    Example:
        >>> chains = pf.get_chains("figure2")
        >>> fig = pf.plot.ridgeline_delta(chains, nbins=20)
    """
    from scipy.stats import gaussian_kde

    if not isinstance(chains, dict):
        chains = {"": chains}
    if not chains:
        raise ValueError("ridgeline_delta() needs at least one chain")

    b = _resolve_binning(binning, k_start, k_end, nbins)
    # Use the resolved bin count everywhere: ``binning`` may override ``nbins``
    n_bins = b.nbins
    labels = [rf"$k={k:.3f}$" for k in b.bin_centers]
    x_eval = np.linspace(xlim[0], xlim[1], 400)

    ncols = len(chains)
    fig_kw = dict(fig_kw or {})  # copy so setdefault never mutates the caller's dict
    fig_kw.setdefault("ncols", ncols)  # one panel per dataset
    fig_kw.setdefault("figsize", (3.0 * ncols, 6))

    fig, axs = plt.subplots(**fig_kw)
    # Flatten to a 1-D sequence of Axes whatever the grid shape / squeeze setting
    axs = np.atleast_1d(axs).ravel()

    # pyplot.get_cmap still exists; matplotlib.cm.get_cmap was removed in 3.9
    cm = plt.get_cmap(cmap)
    ridge_colors = [cm(i / max(n_bins - 1, 1)) for i in range(n_bins)]

    spacing = 1.0
    height = overlap * spacing

    for ax, (label, chain) in zip(axs, chains.items()):
        # Sample weights, if the chain carries them (getdist MCSamples exposes
        # ``.weights``); ignoring multiplicity weights would bias the KDE.
        weights = getattr(chain, "weights", None)

        for i, param in enumerate(b.bin_param_names):
            samples = np.asarray(chain[param])
            y_base = i * spacing
            kde_weights = (
                weights
                if weights is not None and np.size(weights) == samples.size
                else None
            )

            try:
                kde = gaussian_kde(samples, bw_method=bw_method, weights=kde_weights)
                y_kde = kde(x_eval)
            except (np.linalg.LinAlgError, ValueError):
                # Degenerate posterior (e.g. constant or too few samples):
                # draw a flat ridge instead of failing the whole figure.
                y_kde = np.zeros_like(x_eval)
            else:
                peak = y_kde.max()
                if peak > 0:
                    y_kde = y_kde / peak  # normalize each ridge to its own peak

            y_curve = y_base + y_kde * height
            color = ridge_colors[i]

            # Lower bins get higher zorder so they paint over the ridges above them
            ax.fill_between(
                x_eval, y_base, y_curve, color=color, alpha=alpha, zorder=n_bins - i
            )
            ax.plot(x_eval, y_curve, color=color, lw=linewidth, zorder=n_bins - i)
            # white baseline to clip lower ridges
            ax.fill_between(
                x_eval, y_base - height, y_base, color="white", zorder=n_bins - i - 0.5
            )

        ax.set_yticks([i * spacing + height * 0.5 for i in range(n_bins)])
        ax.set_yticklabels(labels, fontsize=6)
        ax.set_ylim(-height * 0.1, (n_bins - 1) * spacing + height * 1.1)
        ax.set_xlim(xlim)
        ax.set_xlabel(r"$\delta_i$")
        ax.set_title(label)
        ax.axvline(0, color="k", lw=1.5, ls="-", zorder=0)
        ax.tick_params(left=False)
        for spine in ("left", "right", "top"):
            ax.spines[spine].set_visible(False)
        ax.grid(False)

    # Hide y-tick labels on all but the leftmost panel
    for ax in axs[1:]:
        ax.set_yticklabels([])

    fig.suptitle(r"Posterior distributions of $\delta_i$ bins")
    fig.tight_layout()
    return fig

RMSE_variance(pca_result, error_data, **fig_kw)

Plot RMSE and variance explained vs number of PCA components.

PARAMETER DESCRIPTION
pca_result

PCAResults object containing effective_dim

error_data

List of dicts with keys: n_components, mean_rmse, median_rmse, percentile_95, variance_explained

**fig_kw

Keyword arguments passed to plt.subplots()

DEFAULT: {}

RETURNS DESCRIPTION

Matplotlib Figure object.

Source code in src/primefeat/plots/pca.py
def RMSE_variance(pca_result, error_data, **fig_kw):
    """
    Plot RMSE and variance explained vs number of PCA components.

    The top panel shows the mean and median reconstruction RMSE on a log
    scale, shading the band between the mean and the 95th percentile. The
    bottom panel shows variance explained in percent with a 95 percent
    reference line. Both panels mark ``pca_result.effective_dim``.

    Args:
        pca_result: PCAResults object containing effective_dim
        error_data: List of dicts with keys: n_components, mean_rmse, median_rmse, percentile_95, variance_explained
        **fig_kw: Keyword arguments passed to plt.subplots(). ``sharex``
            defaults to True, and any ``gridspec_kw`` given is merged over
            ``{"hspace": 0.05}``.

    Returns:
        Matplotlib Figure object.

    Raises:
        ValueError: If ``error_data`` is empty.
    """
    if not error_data:
        raise ValueError("error_data must contain at least one entry")

    # Merge caller options with our defaults instead of passing the same
    # keyword twice, which made plt.subplots() raise TypeError.
    gridspec_kw = {"hspace": 0.05, **(fig_kw.pop("gridspec_kw", None) or {})}
    fig_kw.setdefault("sharex", True)
    fig, (ax1, ax2) = plt.subplots(2, 1, gridspec_kw=gridspec_kw, **fig_kw)

    n_comps = [d["n_components"] for d in error_data]
    mean_rmse = [d["mean_rmse"] for d in error_data]
    median_rmse = [d["median_rmse"] for d in error_data]
    p95_rmse = [d["percentile_95"] for d in error_data]
    variance_explained = [d["variance_explained"] for d in error_data]

    # One component of padding on the right, derived from the data; with 20
    # components this reproduces the previous fixed (0, 21) limit.
    x_max = max(max(n_comps), pca_result.effective_dim) + 1

    ax1.plot(
        n_comps, mean_rmse, "o-", linewidth=2.5, label="Mean RMSE", color="steelblue"
    )
    ax1.plot(
        n_comps,
        median_rmse,
        "s--",
        linewidth=2,
        label="Median RMSE",
        alpha=0.7,
        color="orange",
    )
    ax1.fill_between(
        n_comps,
        mean_rmse,
        p95_rmse,
        alpha=0.2,
        color="steelblue",
        label="95th percentile",
    )
    ax1.axvline(pca_result.effective_dim, ls="--", c="k", alpha=1)

    ax1.set_ylabel("RMSE")
    ax1.set_yscale("log")
    ax1.set_ylim(1e-4, 1)
    ax1.legend()
    ax1.set_xlim(0, x_max)

    ax2.plot(
        n_comps,
        np.array(variance_explained) * 100,
        "o-",
        linewidth=2.5,
        color="steelblue",
    )
    # "$\%$" renders as a percent sign with both mathtext and usetex, whereas
    # "\%" outside math mode shows a literal backslash unless usetex is on.
    ax2.axhline(95, ls="--", c="k", alpha=0.5, label=r"95$\%$ threshold")
    ax2.axvline(
        pca_result.effective_dim,
        ls="--",
        c="k",
        alpha=1,
        label=r"$d_\mathrm{eff}$" + f"= {pca_result.effective_dim}",
    )
    ax2.set_xlabel("Number of Principal Components")
    ax2.set_ylabel(r"Variance Explained ($\%$)")
    ax2.legend()
    ax2.set_xlim(0, x_max)
    ax2.set_ylim(0, 105)
    return fig