GP Plotting module

gp

Visualization functions for Gaussian Process analysis.

This module provides plotting utilities specifically for GP-based significance testing, hyperparameter analysis, and LML landscape visualization.

Original functions from gp_plots.py are consolidated here.

plot_lml_landscape(landscape, figsize=(12, 5), levels=20, show_confidence=True, confidence_levels=[0.68, 0.95], cmap='viridis', vmin=-10.0)

Visualize log-marginal likelihood landscape in (σ, ℓ) hyperparameter space.

Creates a two-panel figure:
  • Left panel: 2D contour plot showing LML across (σ, ℓ) space
      • Optimal point marked with red star
      • Confidence regions (68%, 95%) based on χ² approximation
      • Color scale shows Δ LML from maximum
  • Right panel: 1D marginal likelihood profiles
      • Profile over σ (maximized over ℓ)
      • Profile over ℓ (maximized over σ)
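
The two profiles in the right panel are obtained by maximizing the LML grid along one axis at a time. The following is a minimal sketch on a toy grid, assuming (as in the source listed on this page) that rows of `lml_grid` index σ and columns index ℓ; the grid values are hypothetical.

```python
import numpy as np

# Toy grids (hypothetical values, not taken from any real analysis).
sigma_grid = np.linspace(0.0, 1.0, 5)          # rows of lml_grid
length_scale_grid = np.linspace(0.1, 2.0, 4)   # columns of lml_grid

# Toy LML surface peaked at sigma_grid[2] and length_scale_grid[1].
S, L = np.meshgrid(sigma_grid, length_scale_grid, indexing="ij")
lml_grid = -((S - sigma_grid[2]) ** 2) / 0.02 - ((L - length_scale_grid[1]) ** 2) / 0.5

max_lml = lml_grid.max()

# Profile over sigma: for each sigma, keep the best LML over all length scales.
delta_lml_sigma = lml_grid.max(axis=1) - max_lml
# Profile over ell: for each length scale, keep the best LML over all sigmas.
delta_lml_length = lml_grid.max(axis=0) - max_lml
```

Each profile has one entry per grid value of the parameter it is plotted against, and both reach 0 at the global optimum.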

Physical Interpretation:
  • Narrow peak: Hyperparameters well-constrained by data
  • Ridge structure: σ-ℓ degeneracy (multiple models fit equally well)
  • Broad, flat region: Data uninformative about hyperparameters
  • σ ≈ 0 at maximum: No evidence for signal (null hypothesis)
  • σ > 0, small ℓ: Sharp, localized features detected
  • σ > 0, large ℓ: Smooth, broad features detected

Confidence Regions:

Based on Wilks' theorem, -2Δ LML ~ χ²(k) for k parameters. For 2 parameters (σ, ℓ):
  • 68% CI: Δ LML ≥ -1.14 (χ²(2, 0.68) / 2; the often-quoted -1.15 corresponds to 68.3%)
  • 95% CI: Δ LML ≥ -3.00 (χ²(2, 0.95) / 2)
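
For two parameters the threshold has a closed form, which makes the numbers easy to check. The sketch below uses only the standard library; `delta_lml_threshold` is a hypothetical helper name, not part of this module. It relies on the fact that the χ² distribution with 2 degrees of freedom has CDF 1 - exp(-x/2), so the value matches `-0.5 * scipy.stats.chi2.ppf(p, df=2)` as used in the source.

```python
import math


def delta_lml_threshold(conf_level: float) -> float:
    """Delta-LML contour level enclosing `conf_level` for two parameters.

    For k = 2 the chi-squared quantile is -2 ln(1 - p), so the threshold
    -quantile / 2 reduces to ln(1 - p).
    """
    if not 0.0 < conf_level < 1.0:
        raise ValueError("conf_level must lie strictly between 0 and 1")
    return math.log(1.0 - conf_level)


for p in (0.68, 0.95):
    print(f"{p:.0%} region: delta LML >= {delta_lml_threshold(p):.2f}")
```

A level of exactly 0.68 gives about -1.14; the value -1.15 corresponds to the 1σ level of 68.3%.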

PARAMETER DESCRIPTION
landscape

Output from primefeat.gp.compute_lml_landscape()

TYPE: Dict

figsize

Figure size (width, height) in inches

TYPE: Tuple[float, float] DEFAULT: (12, 5)

levels

Number of contour levels

TYPE: int DEFAULT: 20

show_confidence

Whether to show confidence region contours

TYPE: bool DEFAULT: True

confidence_levels

List of confidence levels (e.g., [0.68, 0.95])

TYPE: List[float] DEFAULT: [0.68, 0.95]

cmap

Matplotlib colormap name

TYPE: str DEFAULT: 'viridis'

vmin

Minimum Δ LML to display (clips very low values)

TYPE: float DEFAULT: -10.0

RETURNS DESCRIPTION
fig

Matplotlib Figure object

TYPE: Figure

axes

Array of Axes objects [ax_2d, ax_profiles]

TYPE: ndarray

Example

from primefeat.gp import compute_lml_landscape
landscape = compute_lml_landscape(delta_mean, log_k, nbins=20, ...)
fig, axes = plot_lml_landscape(landscape)
plt.savefig('lml_landscape.pdf', bbox_inches='tight', dpi=150)
plt.show()
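
For experimenting with the plot without running a full analysis, a `landscape` dictionary can be assembled by hand. The sketch below includes only the keys that `plot_lml_landscape` reads in the source listed on this page; the real output of `compute_lml_landscape()` may contain more. All numbers are toy values, and the `bayes_factor` entry is a placeholder.

```python
import numpy as np

sigma_grid = np.linspace(0.0, 1.0, 30)
length_scale_grid = np.linspace(0.1, 3.0, 40)

# lml_grid[i, j] is the LML at (sigma_grid[i], length_scale_grid[j]).
S, L = np.meshgrid(sigma_grid, length_scale_grid, indexing="ij")
lml_grid = -50.0 * (S - 0.4) ** 2 - 2.0 * (L - 1.2) ** 2  # toy surface

# Locate the grid maximum.
i_opt, j_opt = np.unravel_index(np.argmax(lml_grid), lml_grid.shape)

landscape = {
    "sigma_grid": sigma_grid,
    "length_scale_grid": length_scale_grid,
    "lml_grid": lml_grid,
    "optimal_sigma": sigma_grid[i_opt],
    "optimal_length_scale": length_scale_grid[j_opt],
    "max_lml": lml_grid[i_opt, j_opt],
    "bayes_factor": 1.0,  # placeholder; the real value comes from the analysis
}
```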

Source code in src/primefeat/plots/gp.py
def plot_lml_landscape(
    landscape: Dict,
    figsize: Tuple[float, float] = (12, 5),
    levels: int = 20,
    show_confidence: bool = True,
    confidence_levels: List[float] = [0.68, 0.95],
    cmap: str = "viridis",
    vmin: float = -10.0,
) -> Tuple[Figure, np.ndarray]:
    """
    Visualize log-marginal likelihood landscape in (σ, ℓ) hyperparameter space.

    Creates a two-panel figure:
    - **Left panel**: 2D contour plot showing LML across (σ, ℓ) space
      - Optimal point marked with red star
      - Confidence regions (68%, 95%) based on χ² approximation
      - Color scale shows Δ LML from maximum
    - **Right panel**: 1D marginal likelihood profiles
      - Profile over σ (maximized over ℓ)
      - Profile over ℓ (maximized over σ)

    Physical Interpretation:
    ------------------------
    - **Narrow peak**: Hyperparameters well-constrained by data
    - **Ridge structure**: σ-ℓ degeneracy (multiple models fit equally well)
    - **Broad, flat region**: Data uninformative about hyperparameters
    - **σ ≈ 0 at maximum**: No evidence for signal (null hypothesis)
    - **σ > 0, small ℓ**: Sharp, localized features detected
    - **σ > 0, large ℓ**: Smooth, broad features detected

    Confidence Regions:
    ------------------
    Based on Wilks' theorem, -2Δ LML ~ χ²(k) for k parameters.
    For 2 parameters (σ, ℓ):
    - 68% CI: Δ LML ≥ -1.14 (χ²(2, 0.68) / 2; -1.15 corresponds to 68.3%)
    - 95% CI: Δ LML ≥ -3.00 (χ²(2, 0.95) / 2)

    Args:
        landscape: Output from primefeat.gp.compute_lml_landscape()
        figsize: Figure size (width, height) in inches
        levels: Number of contour levels
        show_confidence: Whether to show confidence region contours
        confidence_levels: List of confidence levels (e.g., [0.68, 0.95])
        cmap: Matplotlib colormap name
        vmin: Minimum Δ LML to display (clips very low values)

    Returns:
        fig: Matplotlib Figure object
        axes: Array of Axes objects [ax_2d, ax_profiles]

    Example:
        >>> from primefeat.gp import compute_lml_landscape
        >>> landscape = compute_lml_landscape(delta_mean, log_k, nbins=20, ...)
        >>> fig, axes = plot_lml_landscape(landscape)
        >>> plt.savefig('lml_landscape.pdf', bbox_inches='tight', dpi=150)
        >>> plt.show()
    """
    sigma_grid = landscape["sigma_grid"]
    length_scale_grid = landscape["length_scale_grid"]
    lml_grid = landscape["lml_grid"]

    optimal_sigma = landscape["optimal_sigma"]
    optimal_length_scale = landscape["optimal_length_scale"]
    max_lml = landscape["max_lml"]
    bayes_factor = landscape["bayes_factor"]

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    ax = axes[0]
    delta_lml = lml_grid - max_lml
    L, S = np.meshgrid(length_scale_grid, sigma_grid)

    contour = ax.contourf(
        L,
        S,
        delta_lml,
        levels=levels,
        cmap=cmap,
        vmin=vmin,
        vmax=0,
        extend="min",
    )

    cs_lines = ax.contour(
        L,
        S,
        delta_lml,
        levels=10,
        colors="white",
        alpha=0.3,
        linewidths=0.5,
    )

    # Label the optimum so that the ax.legend() call below has an entry to show.
    ax.plot(
        optimal_length_scale,
        optimal_sigma,
        "r*",
        markersize=20,
        markeredgecolor="white",
        markeredgewidth=1.5,
        label="Optimum",
        zorder=10,
    )

    if show_confidence:
        from scipy.stats import chi2

        for i, conf_level in enumerate(confidence_levels):
            chi2_threshold = chi2.ppf(conf_level, df=2)
            delta_lml_threshold = -0.5 * chi2_threshold

            try:
                cs_conf = ax.contour(
                    L,
                    S,
                    delta_lml,
                    levels=[delta_lml_threshold],
                    colors=["red", "orange"][i % 2],
                    linewidths=2.5,
                    alpha=0.8,
                    linestyles=["-", "--"][i % 2],
                )
                fmt_str = f"{int(conf_level * 100)}% CI"
                ax.clabel(cs_conf, inline=True, fontsize=9, fmt=fmt_str, manual=False)
            except (ValueError, IndexError):
                pass

    ax.set_xlabel(r"Length Scale $\ell$ (in log k space)", fontsize=11)
    ax.set_ylabel(r"Signal Std $\sigma$", fontsize=11)
    ax.set_title("Log-Marginal Likelihood Landscape", fontsize=12, fontweight="bold")
    ax.grid(True, alpha=0.2, linestyle=":", linewidth=0.5)
    ax.legend(loc="best", fontsize=9, framealpha=0.9)

    cbar = plt.colorbar(contour, ax=ax, label=r"$\Delta$ log(L) from maximum")
    cbar.ax.tick_params(labelsize=9)

    ax.text(
        0.02,
        0.98,
        f"BF = {bayes_factor:.1e}",
        transform=ax.transAxes,
        fontsize=10,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
    )

    ax = axes[1]
    lml_sigma_profile = lml_grid.max(axis=1)
    delta_lml_sigma = lml_sigma_profile - max_lml

    lml_length_profile = lml_grid.max(axis=0)
    delta_lml_length = lml_length_profile - max_lml

    ax.plot(
        sigma_grid,
        delta_lml_sigma,
        "b-",
        linewidth=2.5,
        label=r"Profile over $\sigma$" + "\n(max over " + r"$\ell$)",
        alpha=0.8,
    )
    ax.axvline(
        optimal_sigma,
        color="b",
        linestyle="--",
        alpha=0.5,
        linewidth=1.5,
    )

    ax.plot(
        length_scale_grid,
        delta_lml_length,
        "r-",
        linewidth=2.5,
        label=r"Profile over $\ell$" + "\n(max over " + r"$\sigma$)",
        alpha=0.8,
    )
    ax.axvline(
        optimal_length_scale,
        color="r",
        linestyle="--",
        alpha=0.5,
        linewidth=1.5,
    )

    ax.axhline(0, color="black", linestyle="-", alpha=0.3, linewidth=1)
    ax.axhline(
        -2,
        color="gray",
        linestyle=":",
        alpha=0.5,
        linewidth=1,
        label=r"$\Delta$ LML = -2",
    )

    ax.set_xlabel("Hyperparameter Value", fontsize=11)
    ax.set_ylabel(r"$\Delta$ log(L) from maximum", fontsize=11)
    ax.set_title("Marginal Likelihood Profiles", fontsize=12, fontweight="bold")
    ax.legend(loc="best", fontsize=9, framealpha=0.9)
    ax.grid(True, alpha=0.3, linestyle=":", linewidth=0.5)
    ax.set_ylim([max(vmin, delta_lml_sigma.min(), delta_lml_length.min()), 0.5])

    plt.tight_layout()

    return fig, axes

plot_gp_posterior_predictive(landscape, n_samples=100, figsize=(10, 5), show_data=True, k_start=None, k_end=None)

Plot GP posterior predictive distribution given optimal hyperparameters.

Shows the inferred smooth GP function that best explains the data, along with uncertainty bands. Properly accounts for correlations in the posterior covariance of delta values.
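
Correlated noise is commonly handled by whitening the data with a Cholesky factor of the noise covariance. The sketch below illustrates that general technique with NumPy only; it is not the function's exact code path, and the covariance and data are randomly generated toy values.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy correlated noise covariance (hypothetical; symmetric positive definite).
A = rng.normal(size=(5, 5))
K_noise = A @ A.T + 5.0 * np.eye(5)
delta = rng.normal(size=5)

# Lower Cholesky factor: K_noise = L @ L.T
L = np.linalg.cholesky(K_noise)

# Whitened data: solve L x = delta, i.e. x = L^{-1} delta.
delta_white = np.linalg.solve(L, delta)

# In the whitened basis the noise covariance L^{-1} K_noise L^{-T} is the identity.
K_white = np.linalg.solve(L, np.linalg.solve(L, K_noise).T)
```

Multiplying the whitened vector by `L` recovers the original data, so no information is lost by the transformation.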

PARAMETER DESCRIPTION
landscape

Output from compute_lml_landscape()

TYPE: Dict

n_samples

Number of function samples to draw (accepted, but not used by the source shown below)

TYPE: int DEFAULT: 100

figsize

Figure size

TYPE: Tuple[float, float] DEFAULT: (10, 5)

show_data

Whether to plot observed data points

TYPE: bool DEFAULT: True

k_start, k_end

For x-axis labeling (optional; not used by the source shown below)

RETURNS DESCRIPTION
Tuple[Figure, Axes]

fig, ax: Matplotlib Figure and Axes objects

Example

fig, ax = plot_gp_posterior_predictive(landscape, n_samples=50)

Source code in src/primefeat/plots/gp.py
def plot_gp_posterior_predictive(
    landscape: Dict,
    n_samples: int = 100,
    figsize: Tuple[float, float] = (10, 5),
    show_data: bool = True,
    k_start: Optional[float] = None,
    k_end: Optional[float] = None,
) -> Tuple[Figure, Axes]:
    """
    Plot GP posterior predictive distribution given optimal hyperparameters.

    Shows the inferred smooth GP function that best explains the data,
    along with uncertainty bands. Properly accounts for correlations in
    the posterior covariance of delta values.

    Args:
        landscape: Output from compute_lml_landscape()
        n_samples: Number of function samples to draw
        figsize: Figure size
        show_data: Whether to plot observed data points
        k_start, k_end: For x-axis labeling (optional)

    Returns:
        fig, ax: Matplotlib Figure and Axes objects

    Example:
        >>> fig, ax = plot_gp_posterior_predictive(landscape, n_samples=50)
    """
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.gaussian_process.kernels import RBF, ConstantKernel, WhiteKernel
    from scipy.linalg import cho_factor, cho_solve

    log_k = landscape["log_k"]
    delta_values = landscape["delta_values"]
    optimal_sigma = landscape["optimal_sigma"]
    optimal_length_scale = landscape["optimal_length_scale"]
    K_noise = landscape.get("K_noise")
    noise_level = landscape.get("noise_level", 1e-3)

    fig, ax = plt.subplots(figsize=figsize)

    if K_noise is not None:
        try:
            L_factor = cho_factor(K_noise)
            delta_white = cho_solve(L_factor, delta_values)
            effective_noise = 1e-8
        except np.linalg.LinAlgError:
            delta_white = delta_values.copy()
            effective_noise = noise_level**2
    else:
        delta_white = delta_values.copy()
        effective_noise = noise_level**2

    kernel = ConstantKernel(optimal_sigma**2) * RBF(optimal_length_scale) + WhiteKernel(
        effective_noise
    )

    gp = GaussianProcessRegressor(kernel=kernel, alpha=1e-10, normalize_y=False)
    gp.fit(log_k, delta_white)

    log_k_fine = np.linspace(log_k.min(), log_k.max(), 200).reshape(-1, 1)
    mean_pred, std_pred = gp.predict(log_k_fine, return_std=True)

    if K_noise is not None:
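        # cho_factor() leaves arbitrary values in the unused triangle of its
        # output, so build a clean upper factor (K_noise = U.T @ U) instead.
        U_noise = np.linalg.cholesky(K_noise).T
        delta_recon = U_noise @ delta_white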

        gp_recon = GaussianProcessRegressor(
            kernel=ConstantKernel(optimal_sigma**2) * RBF(optimal_length_scale)
            + WhiteKernel(1e-8),
            alpha=1e-10,
            normalize_y=False,
        )
        gp_recon.fit(log_k, delta_recon)
        mean_pred, std_pred = gp_recon.predict(log_k_fine, return_std=True)

        delta_plot = delta_recon
        label_text = r"Smoothed $\delta$ (decorrelated)"
    else:
        delta_plot = delta_values
        label_text = r"Observed $\delta$"

    ax.plot(
        log_k_fine, mean_pred, "b-", linewidth=2, label="GP posterior mean", zorder=5
    )

    ax.fill_between(
        log_k_fine.ravel(),
        mean_pred - 2 * std_pred,
        mean_pred + 2 * std_pred,
        alpha=0.3,
        color="b",
        label="95% credible region",
        zorder=3,
    )

    if show_data:
        ax.plot(
            log_k,
            delta_plot,
            "ro",
            markersize=6,
            label=label_text,
            zorder=10,
            alpha=0.7,
        )

    ax.axhline(0, color="gray", linestyle="--", alpha=0.5, linewidth=1)

    ax.set_xlabel(r"$\log(k)$", fontsize=11)
    ax.set_ylabel(r"$\delta$", fontsize=11)
    ax.set_title(
        rf"GP Posterior Predictive ($\ell$={optimal_length_scale:.3f}, $\sigma$={optimal_sigma:.3f})",
        fontsize=12,
        fontweight="bold",
    )
    ax.legend(loc="best", fontsize=10)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()

    return fig, ax

plot_lml_slice(landscape, fix_param='length_scale', fix_value=None, figsize=(8, 5))

Plot 1D slice through LML landscape at fixed hyperparameter value.

Useful for understanding the σ-ℓ trade-off and parameter sensitivity.
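
The requested `fix_value` is snapped to the nearest grid point before the slice is taken, so the slice may sit at a slightly different value than the one asked for. A minimal sketch of that lookup, with hypothetical toy grids:

```python
import numpy as np

length_scale_grid = np.array([0.1, 0.3, 0.5, 1.0, 2.0])
sigma_grid = np.array([0.0, 0.5, 1.0])
# Toy LML values; rows index sigma, columns index length scale.
lml_grid = np.arange(15, dtype=float).reshape(3, 5)

fix_value = 0.45  # requested length scale, not on the grid

# Index of the grid point closest to the requested value.
idx = int(np.argmin(np.abs(length_scale_grid - fix_value)))
actual_value = length_scale_grid[idx]

# Slice along sigma at the snapped length scale.
lml_slice = lml_grid[:, idx]
```

Here the slice is taken at ℓ = 0.5, the grid point nearest to 0.45, and has one entry per value in `sigma_grid`.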

PARAMETER DESCRIPTION
landscape

Output from compute_lml_landscape()

TYPE: Dict

fix_param

Which parameter to fix ('length_scale' or 'sigma')

TYPE: str DEFAULT: 'length_scale'

fix_value

Value to fix parameter at (uses optimal if None)

TYPE: Optional[float] DEFAULT: None

figsize

Figure size

TYPE: Tuple[float, float] DEFAULT: (8, 5)

RETURNS DESCRIPTION
Tuple[Figure, Axes]

fig, ax: Matplotlib Figure and Axes objects

Example

Fix length scale at optimal, vary sigma

fig, ax = plot_lml_slice(landscape, fix_param='length_scale')

Fix length scale at small value, vary sigma

fig, ax = plot_lml_slice(landscape, fix_param='length_scale', fix_value=0.2)

Source code in src/primefeat/plots/gp.py
def plot_lml_slice(
    landscape: Dict,
    fix_param: str = "length_scale",
    fix_value: Optional[float] = None,
    figsize: Tuple[float, float] = (8, 5),
) -> Tuple[Figure, Axes]:
    """
    Plot 1D slice through LML landscape at fixed hyperparameter value.

    Useful for understanding the σ-ℓ trade-off and parameter sensitivity.

    Args:
        landscape: Output from compute_lml_landscape()
        fix_param: Which parameter to fix ('length_scale' or 'sigma')
        fix_value: Value to fix parameter at (uses optimal if None)
        figsize: Figure size

    Returns:
        fig, ax: Matplotlib Figure and Axes objects

    Example:
        >>> # Fix length scale at optimal, vary sigma
        >>> fig, ax = plot_lml_slice(landscape, fix_param='length_scale')
        >>>
        >>> # Fix length scale at small value, vary sigma
        >>> fig, ax = plot_lml_slice(landscape, fix_param='length_scale', fix_value=0.2)
    """
    sigma_grid = landscape["sigma_grid"]
    length_scale_grid = landscape["length_scale_grid"]
    lml_grid = landscape["lml_grid"]
    max_lml = landscape["max_lml"]

    fig, ax = plt.subplots(figsize=figsize)

    if fix_param == "length_scale":
        if fix_value is None:
            fix_value = landscape["optimal_length_scale"]

        idx = np.argmin(np.abs(length_scale_grid - fix_value))
        actual_value = length_scale_grid[idx]

        lml_slice = lml_grid[:, idx]
        x_values = sigma_grid
        x_label = r"Signal Std $\sigma$"

    elif fix_param == "sigma":
        if fix_value is None:
            fix_value = landscape["optimal_sigma"]

        idx = np.argmin(np.abs(sigma_grid - fix_value))
        actual_value = sigma_grid[idx]

        lml_slice = lml_grid[idx, :]
        x_values = length_scale_grid
        x_label = r"Length Scale $\ell$"

    else:
        raise ValueError(
            f"fix_param must be 'length_scale' or 'sigma', got {fix_param}"
        )

    delta_lml = lml_slice - max_lml
    ax.plot(x_values, delta_lml, "b-", linewidth=2.5)

    max_idx = np.argmax(lml_slice)
    ax.plot(
        x_values[max_idx],
        delta_lml[max_idx],
        "r*",
        markersize=15,
        label=f"Max at {x_values[max_idx]:.3f}",
    )

    ax.axhline(0, color="black", linestyle="-", alpha=0.3)
    ax.axhline(-2, color="gray", linestyle=":", alpha=0.5, label=r"$\Delta$ LML = -2")

    ax.set_xlabel(x_label, fontsize=11)
    ax.set_ylabel(r"$\Delta$ log(L) from global maximum", fontsize=11)
    ax.legend(loc="best", fontsize=10)
    ax.grid(True, alpha=0.3)

    return fig, ax

plot_bin_lml_contributions(landscape, figsize=(9, 4), color_positive='steelblue', color_negative='tomato', show_delta=True)

Plot per-bin decomposition of the data-fit part of the Bayes factor.

The Bayes factor decomposes as:

ln B = data_fit_term + complexity_penalty

where the data-fit term can be attributed to individual bins:

c_i = 0.5 * delta_i * [(K_null^{-1} - K_opt^{-1}) delta]_i

with sum(c_i) = data_fit_term exactly, and the complexity penalty (Occam factor) is a scalar property of the covariance matrices.

Positive bars indicate bins where H1 fits better than H0 (evidence for features at that scale). Negative bars indicate bins where H0 is preferred locally.
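
The decomposition can be checked numerically. The sketch below uses toy covariances: a diagonal `K_null` and an RBF signal term added to form `K_opt` are assumptions made for illustration. The expression for the complexity penalty, 0.5 (ln|K_null| - ln|K_opt|), is the standard Gaussian log-marginal-likelihood form and is assumed here rather than taken from the library.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 6
delta = rng.normal(size=n)

# Hypothetical covariances: K_null is noise only, K_opt adds an RBF signal term.
x = np.linspace(0.0, 1.0, n)
K_null = 0.5 * np.eye(n)
K_signal = 0.3**2 * np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.4**2)
K_opt = K_null + K_signal

# c_i = 0.5 * delta_i * [(K_null^{-1} - K_opt^{-1}) delta]_i
null_solve = np.linalg.solve(K_null, delta)
opt_solve = np.linalg.solve(K_opt, delta)
c = 0.5 * delta * (null_solve - opt_solve)

# The same quantity written as a difference of quadratic forms.
data_fit_term = 0.5 * (delta @ null_solve - delta @ opt_solve)

# Assumed form of the Occam term for Gaussian marginal likelihoods.
complexity_penalty = 0.5 * (np.linalg.slogdet(K_null)[1] - np.linalg.slogdet(K_opt)[1])
log_bayes_factor = data_fit_term + complexity_penalty
```

Individual `c` entries can be negative, but their sum is non-negative whenever `K_opt` equals `K_null` plus a positive semi-definite signal term, while the penalty under the same assumption is non-positive.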

PARAMETER DESCRIPTION
landscape

Output dict from gp_significance_test with method='lml'. Must contain 'bin_lml_contributions', 'complexity_penalty', 'log_bayes_factor', 'log_k', 'delta_values'.

TYPE: Dict

figsize

Figure size (width, height) in inches.

TYPE: Tuple[float, float] DEFAULT: (9, 4)

color_positive

Bar color for bins with positive contribution.

TYPE: str DEFAULT: 'steelblue'

color_negative

Bar color for bins with negative contribution.

TYPE: str DEFAULT: 'tomato'

show_delta

If True, add a lower panel showing the posterior mean δ per bin as a step plot.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
fig

Matplotlib Figure object.

TYPE: Figure

ax

Axes object (or array [ax_main, ax_delta] if show_delta=True).

TYPE: Axes

Example

result = gp_significance_test(chain, method='lml', n_bootstrap=200) fig, ax = plot_bin_lml_contributions(result.lml_landscape) plt.savefig('bin_lml_contributions.pdf', bbox_inches='tight')

Source code in src/primefeat/plots/gp.py
def plot_bin_lml_contributions(
    landscape: Dict,
    figsize: Tuple[float, float] = (9, 4),
    color_positive: str = "steelblue",
    color_negative: str = "tomato",
    show_delta: bool = True,
) -> Tuple[Figure, Axes]:
    """
    Plot per-bin decomposition of the data-fit part of the Bayes factor.

    The Bayes factor decomposes as:

        ln B = data_fit_term + complexity_penalty

    where the data-fit term can be attributed to individual bins:

        c_i = 0.5 * delta_i * [(K_null^{-1} - K_opt^{-1}) delta]_i

    with sum(c_i) = data_fit_term exactly, and the complexity penalty
    (Occam factor) is a scalar property of the covariance matrices.

    Positive bars indicate bins where H1 fits better than H0 (evidence
    for features at that scale). Negative bars indicate bins where H0
    is preferred locally.

    Args:
        landscape: Output dict from gp_significance_test with method='lml'.
                   Must contain 'bin_lml_contributions', 'complexity_penalty',
                   'log_bayes_factor', 'log_k', 'delta_values'.
        figsize: Figure size (width, height) in inches.
        color_positive: Bar color for bins with positive contribution.
        color_negative: Bar color for bins with negative contribution.
        show_delta: If True, add a lower panel showing the posterior mean δ per bin as a step plot.

    Returns:
        fig: Matplotlib Figure object.
        ax: Axes object (or array [ax_main, ax_delta] if show_delta=True).

    Example:
        >>> result = gp_significance_test(chain, method='lml', n_bootstrap=200)
        >>> fig, ax = plot_bin_lml_contributions(result.lml_landscape)
        >>> plt.savefig('bin_lml_contributions.pdf', bbox_inches='tight')
    """
    contributions = landscape["bin_lml_contributions"]
    complexity_penalty = landscape["complexity_penalty"]
    log_BF = landscape["log_bayes_factor"]
    log_k = landscape["log_k"].ravel()
    delta_values = landscape["delta_values"]
    k_centers = np.exp(log_k)

    data_fit_total = np.sum(contributions)

    if show_delta:
        fig, (ax, ax_d) = plt.subplots(
            2,
            1,
            figsize=figsize,
            sharex=True,
            gridspec_kw={"height_ratios": [3, 1], "hspace": 0.05},
        )
    else:
        fig, ax = plt.subplots(1, 1, figsize=figsize)

    colors = [color_positive if c >= 0 else color_negative for c in contributions]

    ax.bar(
        range(len(contributions)),
        contributions,
        color=colors,
        alpha=0.8,
        edgecolor="white",
        linewidth=0.5,
        label=r"$c_i^{\rm fit}$ (per-bin data-fit)",
    )

    ax.axhline(
        complexity_penalty,
        color="gray",
        linestyle="--",
        linewidth=1.5,
        label=rf"Complexity penalty = {complexity_penalty:.2f}",
    )

    ax.axhline(0, color="black", linewidth=0.8, alpha=0.5)

    ax.text(
        0.02,
        0.97,
        rf"$\ln\mathcal{{B}} = {log_BF:.2f}$"
        + "\n"
        + rf"Data-fit = {data_fit_total:.2f}"
        + "\n"
        + rf"Occam = {complexity_penalty:.2f}",
        transform=ax.transAxes,
        fontsize=9,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.85, edgecolor="gray"),
    )

    tick_step = max(1, len(k_centers) // 10)
    tick_indices = range(0, len(k_centers), tick_step)
    ax.set_xticks(list(tick_indices))
    ax.set_xticklabels(
        [rf"$10^{{{np.log10(k_centers[i]):.1f}}}$" for i in tick_indices],
        fontsize=8,
    )

    ax.set_ylabel(r"$c_i^{\rm fit}$", fontsize=11)
    ax.legend(fontsize=9, loc="upper right")
    ax.grid(True, alpha=0.25, axis="y")

    if show_delta:
        ax_d.step(
            range(len(delta_values)),
            delta_values,
            where="mid",
            color="black",
            linewidth=1.2,
        )
        ax_d.axhline(0, color="gray", linewidth=0.8, alpha=0.5)
        ax_d.set_ylabel(r"$\bar{\delta}_i$", fontsize=10)
        ax_d.set_xlabel(r"$k\ [\mathrm{Mpc}^{-1}]$", fontsize=11)
        ax_d.grid(True, alpha=0.25, axis="y")
        return fig, np.array([ax, ax_d])

    ax.set_xlabel(r"$k\ [\mathrm{Mpc}^{-1}]$", fontsize=11)
    return fig, ax

GP_prediction(gp_result, chain, binning=None, ax=None, color='C0', label=None, show_data=True, show_samples=False, n_samples=50, alpha_band=0.3, sigma_levels=2, figname=None, fig_kw=None)

Plot GP prediction from optimized hyperparameters.

Given the model

\(\delta = f + \epsilon\), where \(f \sim \mathrm{GP}(0, \sigma^2 K(\ell))\) and \(\epsilon \sim N(0, \Sigma_{\mathrm{post}})\)

\(K_{\mathrm{signal}} = \sigma^2 K(\ell)\) (signal kernel)

\(K_{\mathrm{total}} = K_{\mathrm{signal}} + \Sigma_{\mathrm{post}}\)

The posterior for \(f\) given observed mean \(\bar{\delta}\) is: \(f | \bar{\delta} \sim N(\mu_f, \Sigma_f)\)

where

\(\mu_f = K_{\mathrm{signal}} K_{\mathrm{total}}^{-1} \bar{\delta}\)

\(\Sigma_f = K_{\mathrm{signal}} - K_{\mathrm{signal}} K_{\mathrm{total}}^{-1} K_{\mathrm{signal}}\)
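
The posterior mean and covariance above can be evaluated directly with linear solves. The following is a minimal NumPy sketch; the RBF form of \(K(\ell)\), the diagonal \(\Sigma_{\mathrm{post}}\), the hyperparameter values, and the data are all assumptions chosen for illustration.

```python
import numpy as np

n = 8
x = np.linspace(0.0, 1.0, n)
sigma, ell = 0.5, 0.3  # hypothetical hyperparameters

# Signal kernel K_signal = sigma^2 K(ell), with K assumed to be an RBF kernel.
K_signal = sigma**2 * np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / ell**2)
Sigma_post = 0.1**2 * np.eye(n)  # toy diagonal noise covariance
K_total = K_signal + Sigma_post

delta_bar = np.sin(2 * np.pi * x)  # toy observed means

# mu_f = K_signal K_total^{-1} delta_bar
mu_f = K_signal @ np.linalg.solve(K_total, delta_bar)
# Sigma_f = K_signal - K_signal K_total^{-1} K_signal
Sigma_f = K_signal - K_signal @ np.linalg.solve(K_total, K_signal)

# Clip tiny negative diagonal entries caused by round-off before the sqrt.
std_f = np.sqrt(np.maximum(np.diag(Sigma_f), 0.0))
```

The posterior standard deviation never exceeds the prior amplitude σ, which gives a quick sanity check on the result.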

PARAMETER DESCRIPTION
gp_result

GPSignificanceResult from gp_significance_test with method='lml' or 'null+GP'

chain

MCMC chain with delta_i parameters

binning

BinningScheme for bin centers and parameter names. If None, uses LogBinningScheme defaults.

TYPE: Optional[BinningScheme] DEFAULT: None

ax

Optional matplotlib axes to plot on

TYPE: Optional[Axes] DEFAULT: None

color

Color for the plot

TYPE: str DEFAULT: 'C0'

label

Label for legend

TYPE: Optional[str] DEFAULT: None

show_data

If True, show observed \(\delta\) means as scatter points

TYPE: bool DEFAULT: True

show_samples

If True, draw samples from the GP posterior

TYPE: bool DEFAULT: False

n_samples

Number of samples to draw if show_samples=True

TYPE: int DEFAULT: 50

alpha_band

Transparency for confidence bands

TYPE: float DEFAULT: 0.3

sigma_levels

Number of \(\sigma\) levels for confidence bands (1 or 2)

TYPE: int DEFAULT: 2

figname

If provided, save figure to this path

TYPE: Optional[str] DEFAULT: None

fig_kw

Additional kwargs for figure creation

TYPE: Optional[Dict] DEFAULT: None

RETURNS DESCRIPTION
Figure

Matplotlib Figure object.

Examples:

>>> result = pf.significance.gp_significance_test(chain, method='lml')
>>> fig = pf.plot.GP_prediction(result, chain)
Source code in src/primefeat/plots/gp.py
def GP_prediction(
    gp_result,
    chain,
    binning: Optional[BinningScheme] = None,
    ax: Optional[plt.Axes] = None,
    color: str = "C0",
    label: Optional[str] = None,
    show_data: bool = True,
    show_samples: bool = False,
    n_samples: int = 50,
    alpha_band: float = 0.3,
    sigma_levels: int = 2,
    figname: Optional[str] = None,
    fig_kw: Optional[Dict] = None,
) -> plt.Figure:
    """
    Plot GP prediction from optimized hyperparameters.

    Given the model:
        $\\delta = f + \\epsilon$, where $f \\sim \\mathrm{GP}(0, \\sigma^2 K(\\ell))$ and $\\epsilon \\sim N(0, \\Sigma_{\\mathrm{post}})$
        $K_{\\mathrm{signal}} = \\sigma^2 K(\\ell)$ (signal kernel)
        $K_{\\mathrm{total}} = K_{\\mathrm{signal}} + \\Sigma_{\\mathrm{post}}$

    The posterior for $f$ given observed mean $\\bar{\\delta}$ is:
        $f | \\bar{\\delta} \\sim N(\\mu_f, \\Sigma_f)$

    where:
        $\\mu_f = K_{\\mathrm{signal}} K_{\\mathrm{total}}^{-1} \\bar{\\delta}$
        $\\Sigma_f = K_{\\mathrm{signal}} - K_{\\mathrm{signal}} K_{\\mathrm{total}}^{-1} K_{\\mathrm{signal}}$

    Args:
        gp_result: GPSignificanceResult from gp_significance_test with method='lml' or 'null+GP'
        chain: MCMC chain with delta_i parameters
        binning: BinningScheme for bin centers and parameter names. If None, uses LogBinningScheme defaults.
        ax: Optional matplotlib axes to plot on
        color: Color for the plot
        label: Label for legend
        show_data: If True, show observed $\\delta$ means as scatter points
        show_samples: If True, draw samples from the GP posterior
        n_samples: Number of samples to draw if show_samples=True
        alpha_band: Transparency for confidence bands
        sigma_levels: Number of $\\sigma$ levels for confidence bands (1 or 2)
        figname: If provided, save figure to this path
        fig_kw: Additional kwargs for figure creation

    Returns:
        Matplotlib Figure object.

    Examples:
        >>> result = pf.significance.gp_significance_test(chain, method='lml')
        >>> fig = pf.plot.GP_prediction(result, chain)
    """
    from scipy.linalg import cho_factor, cho_solve

    if binning is None:
        binning = LogBinningScheme()

    if gp_result.lml_landscape is None:
        raise ValueError(
            "GP_prediction requires a GPSignificanceResult with lml_landscape. "
            "Use method='lml' or 'null+GP' in gp_significance_test()."
        )

    landscape = gp_result.lml_landscape
    bin_centers = binning.bin_centers
    nbins = binning.nbins

    delta_mean = compute_delta_mean(chain, binning)
    delta_std = compute_delta_std(chain, binning)

    K_signal = landscape["K_signal"]
    K_total = landscape["K"]

    try:
        K_total_factor = cho_factor(K_total)
        K_inv_delta = cho_solve(K_total_factor, delta_mean)
        K_inv_K_signal = cho_solve(K_total_factor, K_signal)
    except np.linalg.LinAlgError:
        K_total_reg = K_total + 1e-8 * np.eye(nbins)
        K_total_factor = cho_factor(K_total_reg)
        K_inv_delta = cho_solve(K_total_factor, delta_mean)
        K_inv_K_signal = cho_solve(K_total_factor, K_signal)

    mu_f = K_signal @ K_inv_delta
    Sigma_f = K_signal - K_signal @ K_inv_K_signal
    std_f = np.sqrt(np.maximum(np.diag(Sigma_f), 0))

    if ax is None:
        # Copy so that setdefault() below does not mutate the caller's dict.
        fig_kw = dict(fig_kw or {})
        fig_kw.setdefault("figsize", (8, 5))
        fig, ax = plt.subplots(**fig_kw)
    else:
        fig = ax.figure

    for n_sigma in range(sigma_levels, 0, -1):
        alpha = alpha_band / n_sigma
        ax.fill_between(
            bin_centers,
            mu_f - n_sigma * std_f,
            mu_f + n_sigma * std_f,
            alpha=alpha,
            color=color,
            label=f"${n_sigma}\\sigma$ band" if n_sigma == sigma_levels else None,
        )

    ax.plot(bin_centers, mu_f, color=color, linewidth=2, label=label or "GP mean")
    ax.semilogx()

    if show_samples:
        try:
            samples = np.random.multivariate_normal(mu_f, Sigma_f, size=n_samples)
            for sample in samples:
                ax.plot(bin_centers, sample, color=color, alpha=0.1, linewidth=0.5)
        except np.linalg.LinAlgError:
            console.print(
                "[yellow]Warning:[/yellow] Could not draw GP samples (covariance not positive definite)"
            )

    if show_data:
        ax.errorbar(
            bin_centers,
            delta_mean,
            yerr=delta_std,
            fmt="o",
            color="black",
            markersize=5,
            capsize=3,
            label="Observed $\\bar{\\delta}$",
            zorder=10,
        )

    ax.axhline(0, color="gray", linestyle="--", linewidth=0.8, alpha=0.5)

    ax.set_xlabel(r"$k$ [Mpc$^{-1}$]", fontsize=12)
    ax.set_ylabel(r"$\delta_i$", fontsize=12)

    sigma_opt = landscape["optimal_sigma"]
    ell_opt = landscape["optimal_length_scale"]
    ax.set_title(
        f"GP Prediction ($\\sigma^* = {sigma_opt:.3f}$, $\\ell^* = {ell_opt:.2f}$)",
        fontsize=12,
    )

    ax.legend(loc="best", fontsize=10)

    # Act on the figure that owns `ax`, which may differ from pyplot's current figure.
    fig.tight_layout()
    if figname:
        fig.savefig(figname, dpi=300, bbox_inches="tight")

    return fig

GP_prediction_comparison(gp_results, chains, binning=None, colors=None, show_data=True, alpha_band=0.2, sigma_levels=1, figname=None, fig_kw=None)

Plot GP predictions from multiple datasets on the same axes.

PARAMETER DESCRIPTION
gp_results

Dictionary mapping dataset names to GPSignificanceResult objects

TYPE: Dict

chains

Dictionary mapping dataset names to MCMC chains

TYPE: Dict

binning

BinningScheme for bin centers and parameter names. If None, uses LogBinningScheme defaults.

TYPE: Optional[BinningScheme] DEFAULT: None

colors

Optional dictionary mapping dataset names to colors

TYPE: Optional[Dict[str, str]] DEFAULT: None

show_data

If True, show observed \(\delta\) means as scatter points

TYPE: bool DEFAULT: True

alpha_band

Transparency for confidence bands

TYPE: float DEFAULT: 0.2

sigma_levels

Number of \(\sigma\) levels for confidence bands

TYPE: int DEFAULT: 1

figname

If provided, save figure to this path

TYPE: Optional[str] DEFAULT: None

fig_kw

Additional kwargs for figure creation

TYPE: Optional[Dict] DEFAULT: None

RETURNS DESCRIPTION
Figure

Matplotlib Figure object.

Examples:

>>> gp_results = {
...     'Dataset A': pf.significance.gp_significance_test(chain_a, method='lml'),
...     'Dataset B': pf.significance.gp_significance_test(chain_b, method='lml'),
... }
>>> chains = {'Dataset A': chain_a, 'Dataset B': chain_b}
>>> fig = pf.plot.GP_prediction_comparison(gp_results, chains)
Source code in src/primefeat/plots/gp.py
def GP_prediction_comparison(
    gp_results: Dict,
    chains: Dict,
    binning: Optional[BinningScheme] = None,
    colors: Optional[Dict[str, str]] = None,
    show_data: bool = True,
    alpha_band: float = 0.2,
    sigma_levels: int = 1,
    figname: Optional[str] = None,
    fig_kw: Optional[Dict] = None,
) -> plt.Figure:
    """
    Plot GP predictions from multiple datasets on the same axes.

    Args:
        gp_results: Dictionary mapping dataset names to GPSignificanceResult objects
        chains: Dictionary mapping dataset names to MCMC chains
        binning: BinningScheme for bin centers and parameter names. If None, uses LogBinningScheme defaults.
        colors: Optional dictionary mapping dataset names to colors
        show_data: If True, show observed $\\delta$ means as scatter points
        alpha_band: Transparency for confidence bands
        sigma_levels: Number of $\\sigma$ levels for confidence bands
        figname: If provided, save figure to this path
        fig_kw: Additional kwargs for figure creation

    Returns:
        Matplotlib Figure object.

    Examples:
        >>> gp_results = {
        ...     'Dataset A': pf.significance.gp_significance_test(chain_a, method='lml'),
        ...     'Dataset B': pf.significance.gp_significance_test(chain_b, method='lml'),
        ... }
        >>> chains = {'Dataset A': chain_a, 'Dataset B': chain_b}
        >>> fig = pf.plots.GP_prediction_comparison(gp_results, chains)
    """
    from scipy.linalg import cho_factor, cho_solve

    if binning is None:
        binning = LogBinningScheme()

    if set(gp_results.keys()) != set(chains.keys()):
        raise ValueError("gp_results and chains must have the same keys")

    for name, result in gp_results.items():
        if result.lml_landscape is None:
            raise ValueError(
                f"GP_prediction_comparison requires lml_landscape for '{name}'. "
                "Use method='lml' or method='null+GP' to compute it."
            )

    bin_centers = binning.bin_centers
    nbins = binning.nbins

    if colors is None:
        default_colors = [f"C{i}" for i in range(10)]
        colors = {
            name: default_colors[i % 10] for i, name in enumerate(gp_results.keys())
        }

    # Copy so that setdefault() does not mutate the caller's dictionary
    fig_kw = dict(fig_kw or {})
    fig_kw.setdefault("figsize", (10, 6))
    fig, ax = plt.subplots(**fig_kw)

    for name in gp_results.keys():
        result = gp_results[name]
        chain = chains[name]
        color = colors[name]
        landscape = result.lml_landscape

        delta_mean = compute_delta_mean(chain, binning)
        delta_std = compute_delta_std(chain, binning)

        K_signal = landscape["K_signal"]
        K_total = landscape["K"]

        try:
            K_total_factor = cho_factor(K_total)
            K_inv_delta = cho_solve(K_total_factor, delta_mean)
            K_inv_K_signal = cho_solve(K_total_factor, K_signal)
        except np.linalg.LinAlgError:
            K_total_reg = K_total + 1e-8 * np.eye(nbins)
            K_total_factor = cho_factor(K_total_reg)
            K_inv_delta = cho_solve(K_total_factor, delta_mean)
            K_inv_K_signal = cho_solve(K_total_factor, K_signal)

        mu_f = K_signal @ K_inv_delta
        Sigma_f = K_signal - K_signal @ K_inv_K_signal
        std_f = np.sqrt(np.maximum(np.diag(Sigma_f), 0))

        sigma_opt = landscape["optimal_sigma"]
        ell_opt = landscape["optimal_length_scale"]

        for n_sigma in range(sigma_levels, 0, -1):
            alpha = alpha_band / n_sigma
            ax.fill_between(
                bin_centers,
                mu_f - n_sigma * std_f,
                mu_f + n_sigma * std_f,
                alpha=alpha,
                color=color,
            )

        label_str = f"{name} ($\\sigma^*={sigma_opt:.3f}$, $\\ell^*={ell_opt:.2f}$)"
        ax.plot(bin_centers, mu_f, color=color, linewidth=2, label=label_str)
        ax.set_xscale("log")

        if show_data:
            ax.errorbar(
                bin_centers,
                delta_mean,
                yerr=delta_std,
                color=color,
                fmt="o",  # markers only, no connecting line between bins
                alpha=0.7,
                zorder=10,
            )

    ax.axhline(0, color="gray", linestyle="--", linewidth=0.8, alpha=0.5)

    ax.set_xlabel(r"$k$ [Mpc$^{-1}$]", fontsize=12)
    ax.set_ylabel(r"$\delta_i$", fontsize=12)
    ax.set_title("GP Prediction Comparison", fontsize=14)
    ax.legend(loc="best", fontsize=9)
    ax.grid(alpha=0.3)

    plt.tight_layout()
    if figname:
        plt.savefig(figname, dpi=300, bbox_inches="tight")

    return fig
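
The posterior computed inside the loop above is the standard GP conditioning step: with signal covariance K_signal and total covariance K, the posterior mean is K_signal K⁻¹ δ and the posterior covariance is K_signal − K_signal K⁻¹ K_signal. The following is a minimal NumPy-only sketch of that step. It uses np.linalg.cholesky with triangular solves in place of SciPy's cho_factor/cho_solve, and the squared-exponential kernel, the diagonal noise, and all numerical inputs are hypothetical placeholders rather than the library's actual model.

```python
import numpy as np


def gp_posterior(K_signal, K_total, delta_mean, jitter=1e-8):
    """Posterior mean and std of the latent signal, given K_total = K_signal + noise."""
    n = K_total.shape[0]
    try:
        L = np.linalg.cholesky(K_total)
    except np.linalg.LinAlgError:
        # Same fallback idea as in the function above: add a small diagonal jitter.
        L = np.linalg.cholesky(K_total + jitter * np.eye(n))

    def solve(b):
        # Solve K x = b through the factorization K = L L^T.
        return np.linalg.solve(L.T, np.linalg.solve(L, b))

    mu_f = K_signal @ solve(delta_mean)
    Sigma_f = K_signal - K_signal @ solve(K_signal)
    # Clip tiny negative diagonal entries caused by round-off before the sqrt.
    std_f = np.sqrt(np.maximum(np.diag(Sigma_f), 0.0))
    return mu_f, std_f


# Hypothetical inputs: squared-exponential kernel in log10(k), diagonal noise.
log_k = np.linspace(-4.0, 0.0, 8)
sigma, ell = 0.1, 1.0
K_signal = sigma**2 * np.exp(-0.5 * (log_k[:, None] - log_k[None, :]) ** 2 / ell**2)
K_total = K_signal + np.diag(np.full(8, 0.05**2))
delta_mean = 0.05 * np.sin(log_k)

mu_f, std_f = gp_posterior(K_signal, K_total, delta_mean)
```

Because K_signal K⁻¹ K_signal is positive semi-definite, the posterior standard deviation never exceeds the prior amplitude σ, which is a quick sanity check on the plotted bands.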

delta_CMB(chains, binning=None, offset=None, ax=None, colors=None, auto_offset_scale=0.05, fig_kw=None, chain_entries=None)

Plot binned primordial power spectrum deviations with measurements from multiple datasets.

Uses symmetric log-space offsets to prevent overlap when displaying multiple measurements at the same bin center. Offsets are distributed symmetrically around the true bin center and applied proportionally (each measurement scaled by 10^offset in log-space), ensuring equal relative separation across the entire k range.
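
The offset scheme described above can be sketched in a few lines of NumPy. The helper name `symmetric_log_offsets` and the bin centers are hypothetical and serve only as illustration; the exponent formula mirrors the one used in the source below.

```python
import numpy as np


def symmetric_log_offsets(bin_centers, n_datasets, offset):
    """Return an (n_datasets, n_bins) array of shifted k positions."""
    # Exponents are centered on zero, e.g. [-1, 0, 1] * offset for three datasets.
    exponents = offset * (np.arange(n_datasets) - (n_datasets - 1) / 2)
    # Multiplying by 10**exponent is a constant shift on a log axis.
    return bin_centers[None, :] * 10.0 ** exponents[:, None]


bin_centers = np.logspace(-4, 0, 5)  # hypothetical bin centers in Mpc^-1
shifted = symmetric_log_offsets(bin_centers, n_datasets=3, offset=0.05)

# Adjacent datasets differ by the same multiplicative factor at every k.
ratios = shifted[1:] / shifted[:-1]
```

With an odd number of datasets the middle one sits exactly on the true bin center; with an even number, the two central datasets straddle it.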

Creates a dual-axis plot with \(k\) and \(\ell\) scales using create_pk_canvas().

PARAMETER DESCRIPTION
chains

Dictionary mapping dataset labels to MCMC chain objects

binning

BinningScheme for bin centers and parameter names. If None, uses LogBinningScheme defaults.

TYPE: Optional[BinningScheme] DEFAULT: None

offset

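Log-space offset step (in dex) between adjacent datasets. If None (default), automatically computed as auto_offset_scale * (log10 spacing between the first two bins). Dataset i is plotted at shifted_k = bin_centers * 10**(offset * (i - (n_datasets - 1) / 2)), so the offsets are symmetric about the true bin center.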

TYPE: Optional[float] DEFAULT: None

ax

Matplotlib axes object. If None, creates new figure via create_pk_canvas().

TYPE: Optional[Axes] DEFAULT: None

colors

List of colors for each dataset. If None, uses default palette or extracts from chain_entries.

TYPE: Optional[List[str]] DEFAULT: None

auto_offset_scale

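Fraction of the log10 bin spacing used as the offset step when offset is None (default: 0.05). For bins spaced one decade apart this gives 10^(0.05) ≈ 1.122, i.e. ~12.2% separation between adjacent datasets; narrower bins give proportionally less. Increase for larger spacing, decrease for tighter clustering.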

TYPE: float DEFAULT: 0.05

fig_kw

Keyword arguments passed to plt.subplots() when creating new figure (default: None).

TYPE: Optional[Dict] DEFAULT: None

chain_entries

Optional ChainsCollection or list of ChainEntry objects for auto color extraction.

TYPE: Optional[List] DEFAULT: None

RETURNS DESCRIPTION
fig

Matplotlib figure object with dual k/ell axes and the plot

TYPE: Figure

Example

>>> chains = {'Planck': chain1, 'ACT': chain2}
>>> fig = pf.plots.delta_CMB(chains)

Source code in src/primefeat/plots/gp.py
def delta_CMB(
    chains,
    binning: Optional[BinningScheme] = None,
    offset: Optional[float] = None,
    ax: Optional[plt.Axes] = None,
    colors: Optional[List[str]] = None,
    auto_offset_scale: float = 0.05,
    fig_kw: Optional[Dict] = None,
    chain_entries: Optional[List] = None,
) -> plt.Figure:
    """
    Plot binned primordial power spectrum deviations with measurements from multiple datasets.

    Uses symmetric log-space offsets to prevent overlap when displaying multiple measurements
    at the same bin center. Offsets are distributed symmetrically around the true bin center
    and applied proportionally (each measurement scaled by 10^offset in log-space), ensuring
    equal relative separation across the entire k range.

    Creates a dual-axis plot with $k$ and $\\ell$ scales using ``create_pk_canvas()``.

    Args:
        chains: Dictionary mapping dataset labels to MCMC chain objects
        binning: BinningScheme for bin centers and parameter names. If None, uses LogBinningScheme defaults.
        offset: Log-space offset step (in dex) between adjacent datasets. If None (default),
                automatically computed as auto_offset_scale * (log10 spacing between the first
                two bins). Dataset i is plotted at
                shifted_k = bin_centers * 10**(offset * (i - (n_datasets - 1) / 2)).
        ax: Matplotlib axes object. If None, creates new figure via ``create_pk_canvas()``.
        colors: List of colors for each dataset. If None, uses default palette or extracts from chain_entries.
        auto_offset_scale: Fraction of the log10 bin spacing used as the offset step when
                          ``offset`` is None (default: 0.05). For bins one decade apart this gives
                          10^(0.05) ≈ 1.122, i.e. ~12.2% separation between adjacent datasets;
                          narrower bins give proportionally less.
                          Increase for larger spacing, decrease for tighter clustering.
        fig_kw: Keyword arguments passed to plt.subplots() when creating new figure (default: None).
        chain_entries: Optional ChainsCollection or list of ChainEntry objects for auto color extraction.

    Returns:
        fig: Matplotlib figure object with dual k/ell axes and the plot

    Example:
        >>> chains = {'Planck': chain1, 'ACT': chain2}
        >>> fig = pf.plots.delta_CMB(chains)
    """
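    # Extract colors from chain_entries only when no explicit colors were given,
    # matching the documented behavior of the ``colors`` argument
    if colors is None and chain_entries is not None:
        colors = [entry.color for entry in chain_entries]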

    if binning is None:
        binning = LogBinningScheme()

    bin_centers = binning.bin_centers

    delta_mean = {
        lbl: np.array([chain[param].mean() for param in binning.bin_param_names])
        for lbl, chain in chains.items()
    }
    delta_std = {
        lbl: np.array([chain[param].std() for param in binning.bin_param_names])
        for lbl, chain in chains.items()
    }

    if ax is None:
        if fig_kw is None:
            fig_kw = {"figsize": (8, 5)}
        fig, ax, _ = create_pk_canvas(
            ax=None,
            fig_kw=fig_kw,
            primary_axis="k",
            ylabel=r"$\delta_i$",
            xlim=(binning.k_start, binning.k_end),
            xscale="log",
        )
    else:
        fig = ax.figure

    if colors is None:
        colors = ["steelblue", "orange", "green", "red", "purple"]

    n_datasets = len(chains)
    if offset is None:
        log_spacing = np.log10(bin_centers[1]) - np.log10(bin_centers[0])
        offset = auto_offset_scale * log_spacing

    offsets = offset * (np.arange(n_datasets) - (n_datasets - 1) / 2)

    for i, lbl in enumerate(chains.keys()):
        shifted_k = bin_centers * 10 ** (offsets[i])
        # Cycle through the palette so more datasets than colors does not raise IndexError
        color = colors[i % len(colors)]

        ax.errorbar(
            shifted_k,
            delta_mean[lbl],
            yerr=delta_std[lbl],
            fmt="o",
            color=color,
            ecolor=color,
            markersize=3,
            capsize=0,
            label=lbl,
            zorder=10,
        )

    ax.axhline(0, color="gray", linestyle="-", linewidth=0.8, alpha=0.5)
    ax.legend(
        ncols=max(1, len(chains) // 2),  # at least one column, also for a single dataset
        bbox_to_anchor=(0.5, 1.15),
        loc="upper center",
        frameon=False,
    )
    plt.tight_layout()
    return fig

landscape_LML(results, *, figsize=None, cmap='RdBu', levels=50, vmin=None, vmax=None, show_optimal=True, optimal_marker='*', optimal_markersize=10, optimal_color='black', show_contours=True, contour_levels=None, ax=None, colorbar_label='$\\log\\mathcal{B}$', fontsize=10, title=None, label_loc='upper_left')

Plot log(Bayes Factor) landscape from GPSignificanceResult objects.

Visualizes the \(\log(\mathrm{BF}) = \mathrm{LML} - \mathrm{LML}_{\mathrm{null}}\) landscape across \((\sigma, \ell)\) hyperparameter space, where positive values indicate evidence for correlated features.

For multiple results, creates a grid layout with a shared colorbar showing consistent scale across all panels.
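
The shared color scale amounts to computing each log(BF) grid and taking the global minimum and maximum across datasets. Below is a minimal sketch of that step. The dictionary keys `lml_grid` and `null_lml` follow the source further down, while the helper name `shared_log_bf_limits` and the toy grids are hypothetical.

```python
import numpy as np


def shared_log_bf_limits(landscapes):
    """Compute per-dataset log(BF) grids and a common (vmin, vmax)."""
    # log(BF) = LML - LML_null, evaluated on every (sigma, ell) grid point.
    grids = {name: ls["lml_grid"] - ls["null_lml"] for name, ls in landscapes.items()}
    # nanmin/nanmax ignore grid points where the LML could not be evaluated.
    vmin = min(np.nanmin(g) for g in grids.values())
    vmax = max(np.nanmax(g) for g in grids.values())
    return grids, float(vmin), float(vmax)


# Hypothetical 2x2 landscapes for two datasets.
landscapes = {
    "A": {"lml_grid": np.array([[10.0, 12.0], [11.0, np.nan]]), "null_lml": 11.0},
    "B": {"lml_grid": np.array([[5.0, 9.0], [7.0, 8.0]]), "null_lml": 6.0},
}
grids, vmin, vmax = shared_log_bf_limits(landscapes)
```

Passing the same vmin and vmax to every panel is what makes colors comparable between datasets.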

PARAMETER DESCRIPTION
results

Single GPSignificanceResult or dict mapping labels to results. Must have lml_landscape populated (method='lml' or 'null+GP').

TYPE: Union[GPSignificanceResult, Dict[str, GPSignificanceResult]]

figsize

Figure size (width, height). Auto-computed if None.

TYPE: Optional[Tuple[float, float]] DEFAULT: None

cmap

Colormap name (default: 'RdBu').

TYPE: str DEFAULT: 'RdBu'

levels

Number of contour levels.

TYPE: int DEFAULT: 50

vmin

Minimum color scale limit. Computed from data if None.

TYPE: Optional[float] DEFAULT: None

vmax

Maximum color scale limit. Computed from data if None.

TYPE: Optional[float] DEFAULT: None

show_optimal

Mark optimal \((\sigma^*, \ell^*)\) point.

TYPE: bool DEFAULT: True

optimal_marker

Marker style for optimal point.

TYPE: str DEFAULT: '*'

optimal_markersize

Size of optimal point marker.

TYPE: float DEFAULT: 10

optimal_color

Color of optimal point marker.

TYPE: str DEFAULT: 'black'

show_contours

Draw confidence contours based on Wilks' theorem.

TYPE: bool DEFAULT: True

contour_levels

Confidence levels for contours (default: [0.68, 0.95]).

TYPE: Optional[List[float]] DEFAULT: None

ax

Existing axes to plot on. For multiple results, provide a list of axes.

TYPE: Optional[Union[Axes, List[Axes]]] DEFAULT: None

colorbar_label

Label for colorbar.

TYPE: str DEFAULT: '$\\log\\mathcal{B}$'

fontsize

Font size for labels (default: 10).

TYPE: int DEFAULT: 10

title

Optional title (single result) or ignored (multiple results).

TYPE: Optional[str] DEFAULT: None

label_loc

Label location inside axes (default: 'upper_left').

TYPE: str DEFAULT: 'upper_left'

RETURNS DESCRIPTION
Figure

Matplotlib Figure object.

Examples:

>>> # Single result
>>> result = pf.significance.gp_significance_test(chain, method='lml')
>>> fig = pf.plots.landscape_LML(result)
>>> # Multiple results with shared colorbar
>>> results = {'PR3': result_pr3, 'PR4': result_pr4, 'SPA': result_spa}
>>> fig = pf.plots.landscape_LML(results)
Source code in src/primefeat/plots/gp.py
def landscape_LML(
    results: Union["GPSignificanceResult", Dict[str, "GPSignificanceResult"]],
    *,
    figsize: Optional[Tuple[float, float]] = None,
    cmap: str = "RdBu",
    levels: int = 50,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    show_optimal: bool = True,
    optimal_marker: str = "*",
    optimal_markersize: float = 10,
    optimal_color: str = "black",
    show_contours: bool = True,
    contour_levels: Optional[List[float]] = None,
    ax: Optional[Union[plt.Axes, List[plt.Axes]]] = None,
    colorbar_label: str = r"$\log\mathcal{B}$",
    fontsize: int = 10,
    title: Optional[str] = None,
    label_loc: str = "upper_left",
) -> plt.Figure:
    """
    Plot log(Bayes Factor) landscape from GPSignificanceResult objects.

    Visualizes the $\\log(\\mathrm{BF}) = \\mathrm{LML} - \\mathrm{LML}_{\\mathrm{null}}$ landscape across
    $(\\sigma, \\ell)$ hyperparameter space, where positive values indicate evidence
    for correlated features.

    For multiple results, creates a grid layout with a shared colorbar showing
    consistent scale across all panels.

    Args:
        results: Single GPSignificanceResult or dict mapping labels to results.
            Must have `lml_landscape` populated (method='lml' or 'null+GP').
        figsize: Figure size (width, height). Auto-computed if None.
        cmap: Colormap name (default: 'RdBu').
        levels: Number of contour levels.
        vmin: Minimum color scale limit. Computed from data if None.
        vmax: Maximum color scale limit. Computed from data if None.
        show_optimal: Mark optimal $(\\sigma^*, \\ell^*)$ point.
        optimal_marker: Marker style for optimal point.
        optimal_markersize: Size of optimal point marker.
        optimal_color: Color of optimal point marker.
        show_contours: Draw confidence contours based on Wilks' theorem.
        contour_levels: Confidence levels for contours (default: [0.68, 0.95]).
        ax: Existing axes to plot on. For multiple results, provide a list of axes.
        colorbar_label: Label for colorbar.
        fontsize: Font size for labels (default: 10).
        title: Optional title (single result) or ignored (multiple results).
        label_loc: Label location inside axes (default: 'upper_left').

    Returns:
        Matplotlib Figure object.

    Examples:
        >>> # Single result
        >>> result = pf.significance.gp_significance_test(chain, method='lml')
        >>> fig = pf.plots.landscape_LML(result)

        >>> # Multiple results with shared colorbar
        >>> results = {'PR3': result_pr3, 'PR4': result_pr4, 'SPA': result_spa}
        >>> fig = pf.plots.landscape_LML(results)
    """
    from ..significance import GPSignificanceResult

    if isinstance(results, GPSignificanceResult):
        results_dict = {"": results}
    else:
        results_dict = results

    for name, result in results_dict.items():
        if result.lml_landscape is None:
            raise ValueError(
                f"Result '{name}' has no lml_landscape. "
                "Use method='lml' or method='null+GP' to compute it."
            )

    n_results = len(results_dict)

    if n_results == 1:
        ncols, nrows = 1, 1
    else:
        ncols = min(3, n_results)
        nrows = (n_results + ncols - 1) // ncols

    if figsize is None:
        figsize = (4 * ncols + 1, 3 * nrows)

    if ax is None:
        fig, axes_arr = plt.subplots(
            nrows,
            ncols,
            figsize=figsize,
            squeeze=False,
            sharex=(n_results > 1),
            sharey=(n_results > 1),
            constrained_layout=True,
        )
        if n_results > 1:
            fig.get_layout_engine().set(hspace=0, wspace=0)
        axes_list = axes_arr.flatten().tolist()
    else:
        axes_arr = None
        if isinstance(ax, plt.Axes):
            axes_list = [ax]
        else:
            axes_list = list(ax)
        fig = axes_list[0].figure

    _vmin_fixed = vmin is not None
    _vmax_fixed = vmax is not None

    all_log_bf = []
    for result in results_dict.values():
        landscape = result.lml_landscape
        log_bf_grid = landscape["lml_grid"] - landscape["null_lml"]
        all_log_bf.append(log_bf_grid)

    if vmin is None:
        vmin = min(np.nanmin(bf) for bf in all_log_bf)
    if vmax is None:
        vmax = max(np.nanmax(bf) for bf in all_log_bf)

    _extend_below = _vmin_fixed and any(np.nanmin(bf) < vmin for bf in all_log_bf)
    _extend_above = _vmax_fixed and any(np.nanmax(bf) > vmax for bf in all_log_bf)
    if _extend_below and _extend_above:
        _extend = "both"
    elif _extend_below:
        _extend = "min"
    elif _extend_above:
        _extend = "max"
    else:
        _extend = "neither"

    all_sigma_vals = np.concatenate(
        [r.lml_landscape["sigma_grid"] for r in results_dict.values()]
    )
    all_ell_vals = np.concatenate(
        [r.lml_landscape["length_scale_grid"] for r in results_dict.values()]
    )
    global_xlim = (
        float(
            max(
                r.lml_landscape["length_scale_grid"].min()
                for r in results_dict.values()
            )
        ),
        float(
            min(
                r.lml_landscape["length_scale_grid"].max()
                for r in results_dict.values()
            )
        ),
    )
    global_ylim = (
        float(max(r.lml_landscape["sigma_grid"].min() for r in results_dict.values())),
        float(min(r.lml_landscape["sigma_grid"].max() for r in results_dict.values())),
    )
    if global_xlim[0] >= global_xlim[1]:
        global_xlim = (float(all_ell_vals.min()), float(all_ell_vals.max()))
    if global_ylim[0] >= global_ylim[1]:
        global_ylim = (float(all_sigma_vals.min()), float(all_sigma_vals.max()))

    if contour_levels is None:
        contour_levels = [0.68, 0.95]

    images = []
    for idx, (name, result) in enumerate(results_dict.items()):
        if idx >= len(axes_list):
            break

        ax_curr = axes_list[idx]
        landscape = result.lml_landscape

        sigma_grid = landscape["sigma_grid"]
        length_scale_grid = landscape["length_scale_grid"]
        log_bf_grid = landscape["lml_grid"] - landscape["null_lml"]

        L, S = np.meshgrid(length_scale_grid, sigma_grid)

        im = ax_curr.contourf(
            L,
            S,
            log_bf_grid,
            levels=levels,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            extend=_extend,
        )
        images.append(im)

        if show_optimal:
            ax_curr.plot(
                landscape["optimal_length_scale"],
                landscape["optimal_sigma"],
                marker=optimal_marker,
                markersize=optimal_markersize,
                color=optimal_color,
                markeredgecolor="white",
                markeredgewidth=1.0,
                zorder=10,
            )

        if show_contours:
            from scipy.stats import chi2

            log_bf_opt = landscape["log_bayes_factor"]
            contour_colors = ["#000000", "#6F6E6C", "#D9D8CE"]
            contour_styles = ["-", "--", ":"]

            for i, conf_level in enumerate(contour_levels):
                chi2_threshold = chi2.ppf(conf_level, df=2)
                bf_threshold = log_bf_opt - 0.5 * chi2_threshold

                try:
                    contour = ax_curr.contour(
                        L,
                        S,
                        log_bf_grid,
                        levels=[bf_threshold],
                        colors=[contour_colors[i % len(contour_colors)]],
                        linewidths=1.5,
                        linestyles=[contour_styles[i % len(contour_styles)]],
                        alpha=0.8,
                    )

                    if i < 2:
                        ax_curr.clabel(
                            contour,
                            contour.levels,
                            inline=True,
                            fmt={contour.levels[0]: rf"${i + 1}\sigma$"},
                            fontsize=fontsize,
                        )
                except ValueError:
                    pass

        if n_results == 1:
            ax_curr.set_xlabel(r"Length scale $\ell_f$", fontsize=fontsize)
            ax_curr.set_ylabel(r"Signal amplitude $\sigma_f$", fontsize=fontsize)

        label_text = name if name else title
        if label_text:
            loc_map = {
                "upper_left": (0.05, 0.95, "left", "top"),
                "upper_right": (0.95, 0.95, "right", "top"),
                "lower_left": (0.05, 0.05, "left", "bottom"),
                "lower_right": (0.95, 0.05, "right", "bottom"),
            }
            if label_loc not in loc_map:
                valid = ", ".join(loc_map.keys())
                raise ValueError(f"label_loc must be one of {valid}, got '{label_loc}'")
            x, y, ha, va = loc_map[label_loc]
            ax_curr.text(
                x,
                y,
                label_text,
                transform=ax_curr.transAxes,
                ha=ha,
                va=va,
                fontsize=fontsize,
                bbox=dict(
                    boxstyle="round,pad=0.3",
                    facecolor="white",
                    edgecolor="black",
                    alpha=0.7,
                    linewidth=1.0,
                ),
            )

    for idx in range(len(results_dict), len(axes_list)):
        axes_list[idx].set_visible(False)

    if n_results > 1 and axes_arr is not None:
        axes_list[0].set_xlim(global_xlim)
        axes_list[0].set_ylim(global_ylim)
        for a in axes_arr.flatten():
            a.label_outer()
        fig.supxlabel(r"Length scale $\ell_f$", fontsize=fontsize)
        fig.supylabel(r"Signal amplitude $\sigma_f$", fontsize=fontsize)

    import matplotlib.colors as mcolors
    from matplotlib.cm import ScalarMappable
    from matplotlib.ticker import MaxNLocator

    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    sm = ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])

    if axes_arr is not None:
        cbar = fig.colorbar(sm, ax=axes_arr, pad=0.01, extend=_extend)
    else:
        cbar = fig.colorbar(sm, ax=axes_list[0], pad=0.01, extend=_extend)

    cbar.set_label("")
    cbar.ax.set_xlabel(colorbar_label, fontsize=fontsize, rotation=0, labelpad=6)
    cbar.ax.xaxis.set_label_position("bottom")
    cbar.ax.xaxis.tick_bottom()
    cbar.ax.yaxis.set_major_locator(MaxNLocator(integer=True))

    return fig
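
The confidence contours drawn when show_contours is True sit at log(BF) levels of log_bf_opt − χ²(p, df=2) / 2. For two degrees of freedom the χ² quantile has the closed form −2 ln(1 − p), so the thresholds can be checked without SciPy. The sketch below uses that closed form instead of chi2.ppf; the helper name `wilks_log_bf_threshold` and the value of log_bf_opt are hypothetical.

```python
import math


def wilks_log_bf_threshold(log_bf_opt, conf_level):
    """log(BF) contour level enclosing `conf_level` for two parameters (df=2)."""
    # For df=2 the chi-square CDF is 1 - exp(-x / 2), hence ppf(p) = -2 ln(1 - p).
    chi2_threshold = -2.0 * math.log(1.0 - conf_level)
    return log_bf_opt - 0.5 * chi2_threshold


log_bf_opt = 4.0  # hypothetical log(BF) at the optimum
thresholds = {p: wilks_log_bf_threshold(log_bf_opt, p) for p in (0.68, 0.95)}
```

For p = 0.68 and p = 0.95 the drops from the optimum are about 1.14 and 3.00, consistent with the Δ LML values quoted for plot_lml_landscape.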