Principal Component Analysis module

pca

Dimensionality reduction and feature extraction for primordial features analysis.

This module provides tools to reduce the N-dimensional bin parameter space to a smaller number of interpretable components, revealing the dominant modes of variation in \(\delta(k)\).

Supports multiple backends:

  • NumPy/sklearn (default): CPU-based computation using sklearn's PCA
  • JAX: GPU-accelerated, JIT-compiled, autodiff-compatible

Use the backend parameter to select:

>>> results = perform_pca(chains, backend="numpy")  # Default
>>> results = perform_pca(chains, backend="jax")    # GPU-accelerated
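
When backend=None, the module auto-detects the backend, preferring JAX when it is importable. A minimal sketch of the explicit equivalent (assuming chains is the usual dict of MCMC chains used throughout these examples):

# Mirror the documented auto-detection: prefer JAX if importable,
# otherwise fall back to the NumPy/sklearn backend.
try:
    import jax  # noqa: F401
    backend = "jax"
except ImportError:
    backend = "numpy"

results = perform_pca(chains, backend=backend)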

PCAResults(n_components, explained_variance_ratio, cumulative_variance, components, transformed_data, pca_model, scaler, effective_dim, dataset_labels=None, backend='numpy', mean=None) dataclass

Results from Principal Component Analysis.

Container for all PCA outputs including components, scores, variance statistics, and metadata. Supports both NumPy and JAX backends.

ATTRIBUTE DESCRIPTION
n_components

Number of principal components computed.

TYPE: int

explained_variance_ratio

Fraction of variance explained by each PC, shape (n_components,).

TYPE: ndarray

cumulative_variance

Cumulative variance explained, shape (n_components,).

TYPE: ndarray

components

Principal component vectors (eigenvectors), shape (n_components, nbins).

TYPE: ndarray

transformed_data

Data projected to PC space (PC scores), shape (n_samples, n_components).

TYPE: ndarray

pca_model

sklearn PCA model for compatibility (None for JAX backend).

TYPE: Optional[PCA]

scaler

StandardScaler used to normalize data before PCA.

TYPE: StandardScaler

effective_dim

Number of PCs explaining 95% of variance.

TYPE: int

dataset_labels

Labels identifying which dataset each sample belongs to.

TYPE: Optional[List[str]]

backend

Backend used ("numpy" or "jax").

TYPE: str

mean

Data mean before centering, shape (nbins,). Used for backend-agnostic reconstruction.

TYPE: Optional[ndarray]

Examples:

>>> results = perform_pca(chains, nbins=20)
>>> print(f"Effective dimensionality: {results.effective_dim}")
>>> print(f"PC1 explains {results.explained_variance_ratio[0]:.1%}")
>>> print(f"Components shape: {results.components.shape}")

collect_delta_samples(chains_dict, nbins=20, param_pattern='delta_{i}')

Collect all \(\delta\) samples from all chains into a single array.

Aggregates samples from multiple MCMC chains into a single data matrix suitable for PCA or other dimensionality reduction methods.

PARAMETER DESCRIPTION
chains_dict

Dictionary mapping dataset labels to MCMC chains, or a single chain object (will be wrapped in dict).

TYPE: Dict

nbins

Number of bins (default: 20).

TYPE: int DEFAULT: 20

param_pattern

Parameter name pattern (default: "delta_{i}").

TYPE: str DEFAULT: 'delta_{i}'

RETURNS DESCRIPTION
Tuple[ndarray, List[str]]

Tuple of (X, labels):

  • X: Array of shape (n_total_samples, nbins) with all \(\delta\) values
  • labels: List of dataset labels for each sample

RAISES DESCRIPTION
KeyError

If parameter names are not found in chains.

Examples:

>>> # Collect from multiple chains
>>> X, labels = collect_delta_samples(chains, nbins=20)
>>> print(f"Total samples: {len(X)}")
>>> print(f"Unique datasets: {set(labels)}")
>>> # Single chain (auto-wrapped)
>>> X, labels = collect_delta_samples(single_chain, nbins=20)
Source code in src/primefeat/pca.py
def collect_delta_samples(
    chains_dict: Dict, nbins: int = 20, param_pattern: str = "delta_{i}"
) -> Tuple[np.ndarray, List[str]]:
    """
    Collect all $\\delta$ samples from all chains into a single array.

    Aggregates samples from multiple MCMC chains into a single data matrix
    suitable for PCA or other dimensionality reduction methods.

    Args:
        chains_dict: Dictionary mapping dataset labels to MCMC chains,
                     or a single chain object (will be wrapped in dict).
        nbins: Number of bins (default: 20).
        param_pattern: Parameter name pattern (default: "delta_{i}").

    Returns:
        Tuple of (X, labels):
            - X: Array of shape (n_total_samples, nbins) with all $\\delta$ values
            - labels: List of dataset labels for each sample

    Raises:
        KeyError: If parameter names are not found in chains.

    Examples:
        >>> # Collect from multiple chains
        >>> X, labels = collect_delta_samples(chains, nbins=20)
        >>> print(f"Total samples: {len(X)}")
        >>> print(f"Unique datasets: {set(labels)}")

        >>> # Single chain (auto-wrapped)
        >>> X, labels = collect_delta_samples(single_chain, nbins=20)
    """
    # Handle single chain object (wrap in dict)
    if not isinstance(chains_dict, dict):
        chains_dict = {"chain": chains_dict}

    all_deltas = []
    labels = []

    for label, chain in chains_dict.items():
        # Extract delta parameters for this chain
        delta_matrix = _extract_delta_from_chain(chain, nbins, param_pattern)

        all_deltas.append(delta_matrix)
        labels.extend([label] * len(delta_matrix))

    X = np.vstack(all_deltas)
    return X, labels
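
The extraction helper _extract_delta_from_chain is private and not shown on this page. For orientation only, a hypothetical sketch of what such an extractor can look like, assuming getdist-style chains (suggested by the getEffectiveSamples() call in variance_decomposition) and 1-based bin indices; the real helper may differ:

import numpy as np

def extract_delta_sketch(chain, nbins, param_pattern="delta_{i}"):
    """Hypothetical stand-in for the private _extract_delta_from_chain."""
    params = chain.getParams()  # getdist: one array attribute per parameter
    cols = []
    for i in range(1, nbins + 1):  # assumes 1-based bin naming
        name = param_pattern.format(i=i)
        try:
            cols.append(getattr(params, name))
        except AttributeError as exc:
            raise KeyError(f"Parameter '{name}' not found in chain") from exc
    return np.column_stack(cols)  # shape (n_samples, nbins)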

perform_pca(chains_dict, nbins=20, n_components=None, param_pattern='delta_{i}', mode='pooled', backend=None, verbose=True)

Perform Principal Component Analysis on \(\delta\) parameters.

Identifies the dominant modes of variation in the primordial power spectrum deviations across datasets. Supports pooled analysis (combining all chains) or individual analysis (separate PCA per chain).

PARAMETER DESCRIPTION
chains_dict

Dictionary mapping dataset labels to MCMC chains, or a single chain object (will be wrapped in dict).

TYPE: Dict

nbins

Number of bins (default: 20).

TYPE: int DEFAULT: 20

n_components

Number of components to compute (default: nbins).

TYPE: Optional[int] DEFAULT: None

param_pattern

Parameter name pattern (default: "delta_{i}").

TYPE: str DEFAULT: 'delta_{i}'

mode

Analysis mode (default: "pooled"):

  • "pooled": Perform PCA on pooled samples from all chains
  • "individual": Perform PCA separately on each chain

TYPE: Literal['pooled', 'individual'] DEFAULT: 'pooled'

backend

Computation backend (default: None for auto-detection):

  • "numpy": NumPy/sklearn backend (CPU)
  • "jax": JAX backend (GPU-accelerated, autodiff-compatible)
  • None: Auto-detect (prefer JAX if available)

TYPE: Optional[str] DEFAULT: None

verbose

Whether to print progress and results (default: True). Set to False for silent operation in scripts/pipelines.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
Union[PCAResults, Dict[str, PCAResults]]

If mode="pooled": PCAResults object containing analysis results.
If mode="individual": Dictionary mapping chain labels to PCAResults objects.

RAISES DESCRIPTION
ValueError

If mode is not "pooled" or "individual".

ImportError

If backend="jax" but JAX is not installed.

Examples:

>>> # Pooled mode (default)
>>> results = perform_pca(chains, nbins=20)
>>> print(f"Effective dimensionality: {results.effective_dim}")
>>> print(f"Top 5 PCs explain {results.cumulative_variance[4]:.1%}")
>>> # Individual mode - analyze each chain separately
>>> results_dict = perform_pca(chains, nbins=20, mode="individual")
>>> for label, result in results_dict.items():
...     print(f"{label}: {result.effective_dim} effective dims")
>>> # Silent operation (no console output)
>>> results = perform_pca(chains, verbose=False)
>>> # Use JAX backend for GPU acceleration
>>> results = perform_pca(chains, backend="jax")
Source code in src/primefeat/pca.py
def perform_pca(
    chains_dict: Dict,
    nbins: int = 20,
    n_components: Optional[int] = None,
    param_pattern: str = "delta_{i}",
    mode: Literal["pooled", "individual"] = "pooled",
    backend: Optional[str] = None,
    verbose: bool = True,
) -> Union[PCAResults, Dict[str, PCAResults]]:
    """
    Perform Principal Component Analysis on $\\delta$ parameters.

    Identifies the dominant modes of variation in the primordial power
    spectrum deviations across datasets. Supports pooled analysis (combining
    all chains) or individual analysis (separate PCA per chain).

    Args:
        chains_dict: Dictionary mapping dataset labels to MCMC chains,
                     or a single chain object (will be wrapped in dict).
        nbins: Number of bins (default: 20).
        n_components: Number of components to compute (default: nbins).
        param_pattern: Parameter name pattern (default: "delta_{i}").
        mode: Analysis mode (default: "pooled"):
              - "pooled": Perform PCA on pooled samples from all chains
              - "individual": Perform PCA separately on each chain
        backend: Computation backend (default: None for auto-detection):
              - "numpy": NumPy/sklearn backend (CPU)
              - "jax": JAX backend (GPU-accelerated, autodiff-compatible)
              - None: Auto-detect (prefer JAX if available)
        verbose: Whether to print progress and results (default: True).
                 Set to False for silent operation in scripts/pipelines.

    Returns:
        If mode="pooled":
            PCAResults object containing analysis results.
        If mode="individual":
            Dictionary mapping chain labels to PCAResults objects.

    Raises:
        ValueError: If mode is not "pooled" or "individual".
        ImportError: If backend="jax" but JAX is not installed.

    Examples:
        >>> # Pooled mode (default)
        >>> results = perform_pca(chains, nbins=20)
        >>> print(f"Effective dimensionality: {results.effective_dim}")
        >>> print(f"Top 5 PCs explain {results.cumulative_variance[4]:.1%}")

        >>> # Individual mode - analyze each chain separately
        >>> results_dict = perform_pca(chains, nbins=20, mode="individual")
        >>> for label, result in results_dict.items():
        ...     print(f"{label}: {result.effective_dim} effective dims")

        >>> # Silent operation (no console output)
        >>> results = perform_pca(chains, verbose=False)

        >>> # Use JAX backend for GPU acceleration
        >>> results = perform_pca(chains, backend="jax")
    """
    # Validate mode
    if mode not in ["pooled", "individual"]:
        raise ValueError(f"mode must be 'pooled' or 'individual', got '{mode}'")

    # Handle single chain object (wrap in dict)
    if not isinstance(chains_dict, dict):
        chains_dict = {"chain": chains_dict}

    # Set default number of components
    if n_components is None:
        n_components = nbins

    # Get backend once (eliminate duplication)
    pca_backend = _get_pca_backend(backend)
    backend_name = pca_backend.backend_name
    backend_suffix = f" [{backend_name}]" if backend_name != "numpy" else ""

    # MODE: POOLED (default, backward compatible)
    if mode == "pooled":
        # Collect all delta samples
        X, labels = collect_delta_samples(chains_dict, nbins, param_pattern)

        # Display input info
        if verbose:
            input_panel = Panel(
                f"[key]Samples:[/key] [value]{X.shape[0]:,}[/value] from "
                f"[value]{len(chains_dict)}[/value] dataset(s)\n"
                f"[key]Dimensionality:[/key] [value]{X.shape[1]}[/value] bins\n"
                f"[key]Backend:[/key] [value]{backend_name}[/value]",
                title=f"[header]PCA Input Data (Pooled){backend_suffix}[/header]",
                border_style="cyan",
            )
            console.print(input_panel)

        # Perform PCA on pooled data (input panel already shown above)
        return _perform_pca_on_data(
            X=X,
            labels=labels,
            n_components=n_components,
            nbins=nbins,
            dataset_name=None,
            show_input_panel=False,
            show_results=True,
            backend=backend,
            verbose=verbose,
        )

    # MODE: INDIVIDUAL
    else:  # mode == "individual"
        results_dict = {}

        if verbose:
            console.print(
                Panel(
                    f"[key]Analyzing:[/key] [value]{len(chains_dict)}[/value] "
                    f"dataset(s) individually\n"
                    f"[key]Backend:[/key] [value]{backend_name}[/value]",
                    title=f"[header]PCA Individual Mode{backend_suffix}[/header]",
                    border_style="magenta",
                )
            )
            console.print()

        for label, chain in chains_dict.items():
            # Extract delta parameters for this chain (using helper)
            delta_matrix = _extract_delta_from_chain(chain, nbins, param_pattern)

            # Create labels for this chain
            chain_labels = [label] * len(delta_matrix)

            # Perform PCA on this chain
            result = _perform_pca_on_data(
                X=delta_matrix,
                labels=chain_labels,
                n_components=n_components,
                nbins=nbins,
                dataset_name=label,
                show_input_panel=True,
                show_results=True,
                backend=backend,
                verbose=verbose,
            )

            results_dict[label] = result

        # Print summary across all chains
        if verbose:
            summary_table = Table(
                title="Summary Across Datasets",
                show_header=True,
            )
            summary_table.add_column("Dataset", style="cyan")
            summary_table.add_column("Samples", justify="right", style="yellow")
            summary_table.add_column("d_eff", justify="right", style="green")
            summary_table.add_column("Top PC Var", justify="right", style="blue")

            for label, result in results_dict.items():
                summary_table.add_row(
                    label,
                    f"{result.transformed_data.shape[0]:,}",
                    str(result.effective_dim),
                    f"{result.explained_variance_ratio[0]:.2%}",
                )

            console.print()
            console.print(summary_table)
            console.print()

        return results_dict
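
A caveat when comparing individual-mode results: PCA component signs are arbitrary, so leading modes should be sign-aligned before their shapes are compared. A sketch, assuming chains holds at least two datasets:

import numpy as np

results_dict = perform_pca(chains, nbins=20, mode="individual", verbose=False)

# Align each dataset's PC1 to a common reference; with unit-norm
# components the dot product is the cosine similarity.
ref = next(iter(results_dict.values())).components[0]
for label, res in results_dict.items():
    pc1 = res.components[0]
    if np.dot(pc1, ref) < 0:
        pc1 = -pc1  # flip the arbitrary sign
    print(f"{label}: cos(PC1, ref) = {np.dot(pc1, ref):.3f}")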

perform_ica(chains_dict, nbins=20, n_components=10, param_pattern='delta_{i}', random_state=42)

Perform Independent Component Analysis on \(\delta\) parameters.

ICA finds statistically independent patterns, which can be better than PCA for identifying localized features or non-Gaussian structures in the primordial power spectrum deviations.

PARAMETER DESCRIPTION
chains_dict

Dictionary mapping dataset labels to MCMC chains, or a single chain object (will be wrapped in dict).

TYPE: Dict

nbins

Number of bins (default: 20).

TYPE: int DEFAULT: 20

n_components

Number of independent components to extract (default: 10).

TYPE: int DEFAULT: 10

param_pattern

Parameter name pattern (default: "delta_{i}").

TYPE: str DEFAULT: 'delta_{i}'

random_state

Random seed for reproducibility (default: 42).

TYPE: int DEFAULT: 42

RETURNS DESCRIPTION
Tuple[FastICA, ndarray, ndarray]

Tuple of (ica_model, X_ica, components):

  • ica_model: Fitted sklearn FastICA object
  • X_ica: Transformed data in IC space, shape (n_samples, n_components)
  • components: Independent components (unmixing matrix), shape (n_components, nbins)

RAISES DESCRIPTION
KeyError

If parameter names are not found in chains.

Examples:

>>> ica, X_ica, components = perform_ica(chains, nbins=20, n_components=10)
>>> print(f"Converged in {ica.n_iter_} iterations")
>>> # Plot IC1
>>> plt.plot(components[0])
Source code in src/primefeat/pca.py
def perform_ica(
    chains_dict: Dict,
    nbins: int = 20,
    n_components: int = 10,
    param_pattern: str = "delta_{i}",
    random_state: int = 42,
) -> Tuple[FastICA, np.ndarray, np.ndarray]:
    """
    Perform Independent Component Analysis on $\\delta$ parameters.

    ICA finds statistically independent patterns, which can be better than PCA
    for identifying localized features or non-Gaussian structures in the
    primordial power spectrum deviations.

    Args:
        chains_dict: Dictionary mapping dataset labels to MCMC chains,
                     or a single chain object (will be wrapped in dict).
        nbins: Number of bins (default: 20).
        n_components: Number of independent components to extract (default: 10).
        param_pattern: Parameter name pattern (default: "delta_{i}").
        random_state: Random seed for reproducibility (default: 42).

    Returns:
        Tuple of (ica_model, X_ica, components):
            - ica_model: Fitted sklearn FastICA object
            - X_ica: Transformed data in IC space, shape (n_samples, n_components)
            - components: Independent components (unmixing matrix),
                         shape (n_components, nbins)

    Raises:
        KeyError: If parameter names are not found in chains.

    Examples:
        >>> ica, X_ica, components = perform_ica(chains, nbins=20, n_components=10)
        >>> print(f"Converged in {ica.n_iter_} iterations")
        >>> # Plot IC1
        >>> plt.plot(components[0])
    """
    # Handle single chain object (wrap in dict)
    if not isinstance(chains_dict, dict):
        chains_dict = {"chain": chains_dict}

    # Collect samples (dataset labels are not needed here)
    X, _ = collect_delta_samples(chains_dict, nbins, param_pattern)

    from rich.status import Status

    with Status("[cyan]Performing ICA decomposition...", console=console) as status:
        # Standardize first (recommended for ICA)
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)

        # Fit ICA
        ica = FastICA(
            n_components=n_components, random_state=random_state, max_iter=500, tol=1e-4
        )
        X_ica = ica.fit_transform(X_scaled)

        # Get components (sources)
        components = ica.components_

        status.update(f"[cyan]ICA converged after {ica.n_iter_} iterations")

    print_success(f"ICA converged after {ica.n_iter_} iterations")

    return ica, X_ica, components
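
FastICA factorizes the standardized data as sources times a mixing matrix, so the decomposition can be inverted with sklearn's fitted attributes. A sketch of the identity (note the StandardScaler is internal to perform_ica, so this lives in standardized space):

import numpy as np

ica, X_ica, components = perform_ica(chains, nbins=20, n_components=10)

# X_scaled ≈ X_ica @ ica.mixing_.T + ica.mean_, which is exactly
# what sklearn's inverse_transform computes.
X_scaled_recon = ica.inverse_transform(X_ica)
assert np.allclose(X_scaled_recon, X_ica @ ica.mixing_.T + ica.mean_)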

compute_reconstruction_error(results, X_original, n_components)

Compute reconstruction error using only n_components PCs.

Quantifies how much information is lost by using fewer components. Useful for determining the optimal number of PCs to retain.

Note

Requires NumPy backend (uses pca_model.inverse_transform).

PARAMETER DESCRIPTION
results

PCAResults from perform_pca().

TYPE: PCAResults

X_original

Original (unstandardized) data, shape (n_samples, nbins).

TYPE: ndarray

n_components

Number of PCs to use for reconstruction.

TYPE: int

RETURNS DESCRIPTION
float

Root mean squared error (RMSE) of reconstruction in standardized space.

RAISES DESCRIPTION
ValueError

If n_components > results.n_components.

AttributeError

If results.pca_model is None (JAX backend).

Examples:

>>> results = perform_pca(chains, nbins=20, backend="numpy")
>>> X, _ = collect_delta_samples(chains, nbins=20)
>>> rmse_5pc = compute_reconstruction_error(results, X, n_components=5)
>>> rmse_10pc = compute_reconstruction_error(results, X, n_components=10)
>>> print(f"5 PCs: RMSE={rmse_5pc:.4f}, 10 PCs: RMSE={rmse_10pc:.4f}")
Source code in src/primefeat/pca.py
def compute_reconstruction_error(
    results: PCAResults, X_original: np.ndarray, n_components: int
) -> float:
    """
    Compute reconstruction error using only n_components PCs.

    Quantifies how much information is lost by using fewer components.
    Useful for determining the optimal number of PCs to retain.

    Note:
        Requires NumPy backend (uses pca_model.inverse_transform).

    Args:
        results: PCAResults from perform_pca().
        X_original: Original (unstandardized) data, shape (n_samples, nbins).
        n_components: Number of PCs to use for reconstruction.

    Returns:
        Root mean squared error (RMSE) of reconstruction in standardized space.

    Raises:
        ValueError: If n_components > results.n_components.
        AttributeError: If results.pca_model is None (JAX backend).

    Examples:
        >>> results = perform_pca(chains, nbins=20, backend="numpy")
        >>> X, _ = collect_delta_samples(chains, nbins=20)
        >>> rmse_5pc = compute_reconstruction_error(results, X, n_components=5)
        >>> rmse_10pc = compute_reconstruction_error(results, X, n_components=10)
        >>> print(f"5 PCs: RMSE={rmse_5pc:.4f}, 10 PCs: RMSE={rmse_10pc:.4f}")
    """
    # Standardize
    X_scaled = results.scaler.transform(X_original)

    # Project to PC space and back
    X_proj = results.pca_model.transform(X_scaled)[:, :n_components]
    X_reconstructed = results.pca_model.inverse_transform(
        np.column_stack(
            [X_proj, np.zeros((X_proj.shape[0], results.n_components - n_components))]
        )
    )

    # Compute RMSE
    rmse = np.sqrt(np.mean((X_scaled - X_reconstructed) ** 2))

    return rmse
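
The zero-padding above works because inverse_transform is affine in the scores: zeroing the trailing scores removes exactly those components' contributions. A self-contained verification on synthetic data:

import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 20))
pca = PCA(n_components=20).fit(X)
scores = pca.transform(X)

k = 5
padded = np.column_stack([scores[:, :k], np.zeros((len(X), 20 - k))])
recon_pad = pca.inverse_transform(padded)
recon_direct = scores[:, :k] @ pca.components_[:k] + pca.mean_  # affine form
assert np.allclose(recon_pad, recon_direct)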

RMSE_vs_n_components(chains_dict, nbins=20, max_components=None)

Compute reconstruction error (RMSE) as a function of number of PCs.

Evaluates how reconstruction quality improves as more principal components are included. Useful for determining the optimal number of PCs to retain.

PARAMETER DESCRIPTION
chains_dict

Dictionary mapping dataset labels to MCMC chains, or a single chain object (will be wrapped in dict).

nbins

Number of bins (default: 20).

DEFAULT: 20

max_components

Maximum number of components to test (default: nbins).

DEFAULT: None

RETURNS DESCRIPTION

Tuple of (error_values, pca_result):

  • error_values: List of dicts, one per component count, containing:

    • n_components: Number of PCs used
    • mean_rmse: Mean RMSE across samples
    • median_rmse: Median RMSE across samples
    • std_rmse: Standard deviation of RMSE
    • percentile_95: 95th percentile of RMSE
    • variance_explained: Fraction of variance explained
    • total_mse: Mean squared error across all samples
  • pca_result: PCAResults object from the analysis

Examples:

>>> errors, pca = RMSE_vs_n_components(chains, nbins=20)
>>> # Find elbow point
>>> for e in errors[:5]:
...     print(f"PCs={e['n_components']}: RMSE={e['mean_rmse']:.4f}")
>>> # Plot reconstruction error curve
>>> import matplotlib.pyplot as plt
>>> plt.plot([e['n_components'] for e in errors],
...          [e['mean_rmse'] for e in errors])
Source code in src/primefeat/pca.py
def RMSE_vs_n_components(chains_dict, nbins=20, max_components=None):
    """
    Compute reconstruction error (RMSE) as a function of number of PCs.

    Evaluates how reconstruction quality improves as more principal components
    are included. Useful for determining the optimal number of PCs to retain.

    Args:
        chains_dict: Dictionary mapping dataset labels to MCMC chains,
                    or a single chain object (will be wrapped in dict).
        nbins: Number of bins (default: 20).
        max_components: Maximum number of components to test (default: nbins).

    Returns:
        Tuple of (error_values, pca_result):

            - error_values: List of dicts, one per component count, containing:

                - n_components: Number of PCs used
                - mean_rmse: Mean RMSE across samples
                - median_rmse: Median RMSE across samples
                - std_rmse: Standard deviation of RMSE
                - percentile_95: 95th percentile of RMSE
                - variance_explained: Fraction of variance explained
                - total_mse: Mean squared error across all samples

            - pca_result: PCAResults object from the analysis

    Examples:
        >>> errors, pca = RMSE_vs_n_components(chains, nbins=20)
        >>> # Find elbow point
        >>> for e in errors[:5]:
        ...     print(f"PCs={e['n_components']}: RMSE={e['mean_rmse']:.4f}")
        >>> # Plot reconstruction error curve
        >>> import matplotlib.pyplot as plt
        >>> plt.plot([e['n_components'] for e in errors],
        ...          [e['mean_rmse'] for e in errors])
    """
    # Perform PCA to get the model
    pca_result = perform_pca(chains_dict, nbins=nbins)

    # Collect all delta samples
    X, labels = collect_delta_samples(chains_dict, nbins=nbins)

    max_components = max_components or nbins
    error_values = []

    # IMPORTANT: PCA uses StandardScaler, so we need to work in scaled space
    # then transform back to original space for interpretable errors

    for k in range(1, max_components + 1):
        # Reconstruct in scaled space from the stored PC scores.
        # Forward: X_pca = scaler.transform(X) @ components.T
        # Inverse: X_scaled_recon = X_pca[:, :k] @ components[:k, :]
        #          X_recon = scaler.inverse_transform(X_scaled_recon)
        X_pca_k = pca_result.transformed_data[:, :k]  # First k PC scores

        # Reconstruct in scaled space
        X_scaled_recon = X_pca_k @ pca_result.components[:k, :]

        # Transform back to original space
        X_recon = pca_result.scaler.inverse_transform(X_scaled_recon)

        # Compute reconstruction error (RMSE per sample)
        residuals = X - X_recon
        mse_per_sample = np.mean(residuals**2, axis=1)  # Mean over bins
        rmse_per_sample = np.sqrt(mse_per_sample)

        # Also compute in terms of variance explained
        total_variance = np.var(X)
        residual_variance = np.var(residuals)
        variance_explained = 1 - (residual_variance / total_variance)

        error_values.append(
            {
                "n_components": k,
                "mean_rmse": np.mean(rmse_per_sample),
                "median_rmse": np.median(rmse_per_sample),
                "std_rmse": np.std(rmse_per_sample),
                "percentile_95": np.percentile(rmse_per_sample, 95),
                "variance_explained": variance_explained,
                "total_mse": np.mean(mse_per_sample),
            }
        )

    return error_values, pca_result
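
One way to consume the returned diagnostics: keep the smallest component count that clears a variance threshold. The 95% cut below is a conventional choice, matching the definition of effective_dim; a sketch:

errors, pca_result = RMSE_vs_n_components(chains, nbins=20)

# Smallest k whose reconstruction explains >= 95% of the variance,
# falling back to the largest k tested.
k_opt = next(
    (e["n_components"] for e in errors if e["variance_explained"] >= 0.95),
    errors[-1]["n_components"],
)
print(f"Selected {k_opt} components")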

reconstruct_delta_from_pcs(results, pc_indices, sample_indices=None, return_mean=False)

Reconstruct \(\delta(k)\) using only specified principal components.

Allows selective reconstruction to study the impact of individual PCs or groups of PCs. For example, reconstruct using PCs 2-8 to see the effect of dropping PC1.

Backend-agnostic: works with both NumPy and JAX PCA results.

PARAMETER DESCRIPTION
results

PCAResults object from perform_pca()

TYPE: PCAResults

pc_indices

List of PC indices to use (1-indexed, e.g., [2, 3, 4, 5, 6, 7, 8]). Uses 1-indexing to match standard PC naming (PC1, PC2, ...).

TYPE: List[int]

sample_indices

Optional sample index/indices to reconstruct:

  • None: reconstruct all samples
  • int: single sample
  • List[int]: multiple specific samples

TYPE: Optional[Union[int, List[int]]] DEFAULT: None

return_mean

If True, average reconstruction across selected samples

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
Dict[str, Any]

Dictionary containing:

  • 'delta': Reconstructed \(\delta(k)\) values, shape depends on inputs
  • 'pc_indices': PC indices used (1-indexed)
  • 'n_pcs_used': Number of PCs used
  • 'variance_explained': Cumulative variance explained by selected PCs
  • 'n_samples': Number of samples reconstructed
  • 'shape': Shape of delta array

Examples:

>>> # Reconstruct using only PCs 2-8 (drop PC1)
>>> result = reconstruct_delta_from_pcs(pca_results, pc_indices=[2,3,4,5,6,7,8])
>>> delta_partial = result['delta']  # Shape: (n_samples, nbins)
>>> # See impact of PC1 alone
>>> result = reconstruct_delta_from_pcs(pca_results, pc_indices=[1])
>>> delta_pc1_only = result['delta']
>>> # Reconstruct single sample using PCs 1-5
>>> result = reconstruct_delta_from_pcs(pca_results, [1,2,3,4,5], sample_indices=42)
>>> delta_sample = result['delta']  # Shape: (nbins,)
>>> # Get mean reconstruction across all samples (for plotting)
>>> result = reconstruct_delta_from_pcs(pca_results, [2,3,4,5,6,7,8], return_mean=True)
>>> delta_mean = result['delta']  # Shape: (nbins,)
Source code in src/primefeat/pca.py
def reconstruct_delta_from_pcs(
    results: PCAResults,
    pc_indices: List[int],
    sample_indices: Optional[Union[int, List[int]]] = None,
    return_mean: bool = False,
) -> Dict[str, Any]:
    """
    Reconstruct $\\delta(k)$ using only specified principal components.

    Allows selective reconstruction to study impact of individual PCs or groups.
    For example, reconstruct using PCs 2-8 to see the effect of dropping PC1.

    Backend-agnostic: works with both NumPy and JAX PCA results.

    Args:
        results: PCAResults object from perform_pca()
        pc_indices: List of PC indices to use (1-indexed, e.g., [2, 3, 4, 5, 6, 7, 8])
                   Uses 1-indexing to match standard PC naming (PC1, PC2, ...)
        sample_indices: Optional sample index/indices to reconstruct:
                       - None: reconstruct all samples
                       - int: single sample
                       - List[int]: multiple specific samples
        return_mean: If True, average reconstruction across selected samples

    Returns:
        Dictionary containing:
        - 'delta': Reconstructed $\\delta(k)$ values, shape depends on inputs
        - 'pc_indices': PC indices used (1-indexed)
        - 'n_pcs_used': Number of PCs used
        - 'variance_explained': Cumulative variance explained by selected PCs
        - 'n_samples': Number of samples reconstructed
        - 'shape': Shape of delta array

    Examples:
        >>> # Reconstruct using only PCs 2-8 (drop PC1)
        >>> result = reconstruct_delta_from_pcs(pca_results, pc_indices=[2,3,4,5,6,7,8])
        >>> delta_partial = result['delta']  # Shape: (n_samples, nbins)

        >>> # See impact of PC1 alone
        >>> result = reconstruct_delta_from_pcs(pca_results, pc_indices=[1])
        >>> delta_pc1_only = result['delta']

        >>> # Reconstruct single sample using PCs 1-5
        >>> result = reconstruct_delta_from_pcs(pca_results, [1,2,3,4,5], sample_indices=42)
        >>> delta_sample = result['delta']  # Shape: (nbins,)

        >>> # Get mean reconstruction across all samples (for plotting)
        >>> result = reconstruct_delta_from_pcs(pca_results, [2,3,4,5,6,7,8], return_mean=True)
        >>> delta_mean = result['delta']  # Shape: (nbins,)
    """
    # Validate PC indices (1-indexed)
    pc_indices = np.asarray(pc_indices)
    if np.any(pc_indices < 1) or np.any(pc_indices > results.n_components):
        raise ValueError(
            f"PC indices must be between 1 and {results.n_components}, "
            f"got {pc_indices.tolist()}"
        )

    # Convert to 0-indexed for array operations
    pc_indices_0 = pc_indices - 1

    # Get PC scores for selected components
    pc_scores_selected = results.transformed_data[:, pc_indices_0]

    # Get corresponding components
    components_selected = results.components[pc_indices_0, :]

    # Reconstruct in scaled space
    # X_scaled = PC_scores @ Components
    X_scaled_recon = pc_scores_selected @ components_selected

    # Transform back to original space
    # Use scaler.inverse_transform for backend compatibility
    delta_reconstructed = results.scaler.inverse_transform(X_scaled_recon)

    # Select samples if requested (an int index yields a 1-D row,
    # a list of indices keeps the 2-D shape)
    if sample_indices is not None:
        delta_reconstructed = delta_reconstructed[sample_indices, :]

    # Compute mean if requested
    if return_mean:
        delta_reconstructed = np.mean(delta_reconstructed, axis=0)

    # Compute variance explained by selected PCs
    variance_explained = np.sum(results.explained_variance_ratio[pc_indices_0])

    # Build result dictionary
    result = {
        "delta": delta_reconstructed,
        "pc_indices": pc_indices.tolist(),
        "n_pcs_used": len(pc_indices),
        "variance_explained": float(variance_explained),
        "n_samples": delta_reconstructed.shape[0]
        if delta_reconstructed.ndim > 1
        else 1,
        "shape": delta_reconstructed.shape,
    }

    return result
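
Because scaler.inverse_transform re-adds the data mean on every call, partial reconstructions are not additive in \(\delta\)-space. The contribution of a single PC is therefore best isolated as a difference of two reconstructions, where the offset cancels. A sketch using the pca_results from the examples above:

# Contribution of PC1 = full reconstruction minus drop-PC1 reconstruction;
# the mean added by inverse_transform cancels in the difference.
full = reconstruct_delta_from_pcs(pca_results, list(range(1, 11)), return_mean=True)
drop1 = reconstruct_delta_from_pcs(pca_results, list(range(2, 11)), return_mean=True)
pc1_contribution = full["delta"] - drop1["delta"]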

compare_pc_reconstructions(results, pc_sets, k_values=None, sample_indices=None, figsize=(12, 6), title=None)

Compare \(\delta(k)\) reconstructions using different PC subsets.

Useful for visualizing the impact of specific PCs or understanding how reconstruction quality changes with different PC selections.

PARAMETER DESCRIPTION
results

PCAResults object from perform_pca()

TYPE: PCAResults

pc_sets

Dictionary mapping labels to PC index lists. Example: {'Full (1-10)': [1,2,3,4,5,6,7,8,9,10], 'Without PC1': [2,3,4,5,6,7,8,9,10], 'PC1 only': [1], 'PCs 2-5': [2,3,4,5]}

TYPE: Dict[str, List[int]]

k_values

Optional k-values for the x-axis (Mpc\(^{-1}\)). If None, uses bin indices.

TYPE: Optional[ndarray] DEFAULT: None

sample_indices

Which samples to plot (None = mean over all)

TYPE: Optional[Union[int, List[int]]] DEFAULT: None

figsize

Figure size (width, height)

TYPE: Tuple[float, float] DEFAULT: (12, 6)

title

Optional plot title

TYPE: Optional[str] DEFAULT: None

RETURNS DESCRIPTION
Figure

Figure object

Example:

>>> pc_sets = {
...     'Full (PCs 1-10)': list(range(1, 11)),
...     'Drop PC1 (PCs 2-10)': list(range(2, 11)),
...     'PC1 only': [1],
...     'PCs 2-8': [2,3,4,5,6,7,8]
... }
>>> fig = compare_pc_reconstructions(pca_results, pc_sets, k_values=k)
>>> plt.show()
Source code in src/primefeat/pca.py
def compare_pc_reconstructions(
    results: PCAResults,
    pc_sets: Dict[str, List[int]],
    k_values: Optional[np.ndarray] = None,
    sample_indices: Optional[Union[int, List[int]]] = None,
    figsize: Tuple[float, float] = (12, 6),
    title: Optional[str] = None,
) -> plt.Figure:
    """
    Compare $\\delta(k)$ reconstructions using different PC subsets.

    Useful for visualizing the impact of specific PCs or understanding
    how reconstruction quality changes with different PC selections.

    Args:
        results: PCAResults object from perform_pca()
        pc_sets: Dictionary mapping labels to PC index lists
                Example: {
                    'Full (1-10)': [1,2,3,4,5,6,7,8,9,10],
                    'Without PC1': [2,3,4,5,6,7,8,9,10],
                    'PC1 only': [1],
                    'PCs 2-5': [2,3,4,5]
                }
        k_values: Optional k-values for x-axis (Mpc$^{-1}$)
                 If None, uses bin indices
        sample_indices: Which samples to plot (None = mean over all)
        figsize: Figure size (width, height)
        title: Optional plot title

    Returns:
        Figure object

    Example:

        >>> pc_sets = {
        ...     'Full (PCs 1-10)': list(range(1, 11)),
        ...     'Drop PC1 (PCs 2-10)': list(range(2, 11)),
        ...     'PC1 only': [1],
        ...     'PCs 2-8': [2,3,4,5,6,7,8]
        ... }
        >>> fig = compare_pc_reconstructions(pca_results, pc_sets, k_values=k)
        >>> plt.show()
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Use bin indices if k_values not provided
    nbins = results.components.shape[1]
    x_values = k_values if k_values is not None else np.arange(1, nbins + 1)
    x_label = r"$k$ [Mpc$^{-1}$]" if k_values is not None else "Bin Index"

    # Colors for different PC sets
    colors = plt.cm.tab10(np.linspace(0, 1, len(pc_sets)))

    # Left panel: Reconstructions
    for (label, pc_list), color in zip(pc_sets.items(), colors):
        recon = reconstruct_delta_from_pcs(
            results, pc_list, sample_indices=sample_indices, return_mean=True
        )
        delta = recon["delta"]
        var_exp = recon["variance_explained"]

        label_with_var = f"{label} ({var_exp:.1%} var)"
        ax1.plot(x_values, delta, label=label_with_var, color=color, linewidth=2)

    ax1.axhline(0, color="k", linestyle="--", alpha=0.3, linewidth=1)
    ax1.set_xlabel(x_label)
    ax1.set_ylabel(r"$\delta(k)$")
    ax1.legend(fontsize=9, framealpha=0.9)
    ax1.grid(True, alpha=0.3)
    if k_values is not None:
        ax1.set_xscale("log")

    # Right panel: Differences from first reconstruction (typically "full")
    reference_label = list(pc_sets.keys())[0]
    reference_recon = reconstruct_delta_from_pcs(
        results,
        pc_sets[reference_label],
        sample_indices=sample_indices,
        return_mean=True,
    )
    reference_delta = reference_recon["delta"]

    for (label, pc_list), color in zip(list(pc_sets.items())[1:], colors[1:]):
        recon = reconstruct_delta_from_pcs(
            results, pc_list, sample_indices=sample_indices, return_mean=True
        )
        delta = recon["delta"]
        difference = delta - reference_delta

        ax2.plot(
            x_values,
            difference,
            label=f"{label} - {reference_label}",
            color=color,
            linewidth=2,
        )

    ax2.axhline(0, color="k", linestyle="--", alpha=0.3, linewidth=1)
    ax2.set_xlabel(x_label)
    ax2.set_ylabel(r"$\Delta\delta(k)$ from reference")
    ax2.legend(fontsize=9, framealpha=0.9)
    ax2.grid(True, alpha=0.3)
    if k_values is not None:
        ax2.set_xscale("log")

    if title:
        fig.suptitle(title, fontsize=14, fontweight="bold")

    plt.tight_layout()
    return fig

analyze_bin_correlations(chains_dict, nbins=20)

Compute correlation matrix between \(\delta\) bins.

Shows which bins are correlated, typically due to smoothness constraints in the primordial power spectrum reconstruction or cosmic variance. High correlations between adjacent bins indicate the data prefers smooth \(\delta(k)\) variations.

PARAMETER DESCRIPTION
chains_dict

Dictionary mapping dataset labels to MCMC chains, or a single chain object (will be wrapped in dict).

TYPE: Dict

nbins

Number of bins (default: 20).

TYPE: int DEFAULT: 20

RETURNS DESCRIPTION
ndarray

Correlation matrix of shape (nbins, nbins). Element [i,j] is the Pearson correlation coefficient between bins i and j.

Examples:

>>> corr = analyze_bin_correlations(chains, nbins=20)
>>> # Visualize correlation structure
>>> plt.imshow(corr, cmap='RdBu_r', vmin=-1, vmax=1)
>>> plt.colorbar(label='Correlation')
>>> # Check adjacent bin correlations
>>> adj_corr = np.diag(corr, k=1).mean()
>>> print(f"Mean adjacent correlation: {adj_corr:.3f}")
Source code in src/primefeat/pca.py
def analyze_bin_correlations(chains_dict: Dict, nbins: int = 20) -> np.ndarray:
    """
    Compute correlation matrix between $\\delta$ bins.

    Shows which bins are correlated, typically due to smoothness constraints
    in the primordial power spectrum reconstruction or cosmic variance.
    High correlations between adjacent bins indicate the data prefers
    smooth $\\delta(k)$ variations.

    Args:
        chains_dict: Dictionary mapping dataset labels to MCMC chains,
                     or a single chain object (will be wrapped in dict).
        nbins: Number of bins (default: 20).

    Returns:
        Correlation matrix of shape (nbins, nbins). Element [i,j] is the
        Pearson correlation coefficient between bins i and j.

    Examples:
        >>> corr = analyze_bin_correlations(chains, nbins=20)
        >>> # Visualize correlation structure
        >>> plt.imshow(corr, cmap='RdBu_r', vmin=-1, vmax=1)
        >>> plt.colorbar(label='Correlation')
        >>> # Check adjacent bin correlations
        >>> adj_corr = np.diag(corr, k=1).mean()
        >>> print(f"Mean adjacent correlation: {adj_corr:.3f}")
    """
    # Handle single chain object (wrap in dict)
    if not isinstance(chains_dict, dict):
        chains_dict = {"chain": chains_dict}

    X, _ = collect_delta_samples(chains_dict, nbins)

    # Compute correlation matrix
    corr_matrix = np.corrcoef(X.T)

    return corr_matrix
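
Beyond adjacent bins, the full correlation-versus-separation profile is often informative; each off-diagonal of the matrix collects all pairs at a fixed lag. A sketch extending the example above:

import numpy as np

corr = analyze_bin_correlations(chains, nbins=20)

# Mean correlation at each bin separation (lag): the k-th off-diagonal
# holds all pairs with |i - j| = k.
decay = [np.diag(corr, k=lag).mean() for lag in range(1, 20)]
print(decay[:3])  # adjacent, next-to-adjacent, ... correlations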

variance_decomposition(pca_pooled, N_pcs=10, chains_dict=None)

Decompose variance into between-dataset and within-dataset components.

Analyzes how much of the variance in each principal component is due to differences between datasets versus variation within individual datasets. Uses ANOVA F-statistics to quantify dataset separation and computes signal-to-noise ratios based on effective sample sizes.
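
Concretely, for the scores \(s_i\) of PC \(i\) grouped by dataset label \(d\), the function estimates the two terms of the law of total variance with sample statistics (within = mean of the per-dataset variances, between = variance of the per-dataset means) and reports their ratio:

\[
\operatorname{Var}(s_i) = \underbrace{\mathbb{E}_d\!\left[\operatorname{Var}(s_i \mid d)\right]}_{\text{within}} + \underbrace{\operatorname{Var}_d\!\left(\mathbb{E}[s_i \mid d]\right)}_{\text{between}},
\qquad
F_i = \frac{\mathrm{between}_i}{\mathrm{within}_i},
\]

with the p-values coming from scipy's one-way ANOVA on the per-dataset score groups.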

PARAMETER DESCRIPTION
pca_pooled

PCAResults object from perform_pca() with mode="pooled".

N_pcs

Number of principal components to analyze (default: 10).

DEFAULT: 10

chains_dict

Optional dict of chains to compute effective sample sizes. If None, uses total sample count for SNR calculation.

DEFAULT: None

RETURNS DESCRIPTION

Tuple of (variance_between, variance_within, variance_total, f_statistics, snr_values, p_values):

  • variance_between: List of between-dataset variance for each PC
  • variance_within: List of within-dataset variance for each PC
  • variance_total: List of total variance for each PC
  • f_statistics: List of F-statistics (between/within ratio)
  • snr_values: List of signal-to-noise ratios
  • p_values: List of ANOVA p-values testing mean differences

Examples:

>>> results = perform_pca(chains, mode="pooled")
>>> var_b, var_w, var_t, f_stats, snr, pvals = variance_decomposition(
...     results, N_pcs=10, chains_dict=chains
... )
>>> # High F-stat indicates PC separates datasets well
>>> print(f"PC1 F-statistic: {f_stats[0]:.3f}")
>>> # Low p-value indicates significant dataset separation
>>> print(f"PC1 p-value: {pvals[0]:.2e}")
Source code in src/primefeat/pca.py
def variance_decomposition(pca_pooled, N_pcs=10, chains_dict=None):
    """
    Decompose variance into between-dataset and within-dataset components.

    Analyzes how much of the variance in each principal component is due to
    differences between datasets versus variation within individual datasets.
    Uses ANOVA F-statistics to quantify dataset separation and computes
    signal-to-noise ratios based on effective sample sizes.

    Args:
        pca_pooled: PCAResults object from perform_pca() with mode="pooled".
        N_pcs: Number of principal components to analyze (default: 10).
        chains_dict: Optional dict of chains to compute effective sample sizes.
                    If None, uses total sample count for SNR calculation.

    Returns:
        Tuple of (variance_between, variance_within, variance_total,
                 f_statistics, snr_values, p_values):
            - variance_between: List of between-dataset variance for each PC
            - variance_within: List of within-dataset variance for each PC
            - variance_total: List of total variance for each PC
            - f_statistics: List of F-statistics (between/within ratio)
            - snr_values: List of signal-to-noise ratios
            - p_values: List of ANOVA p-values testing mean differences

    Examples:
        >>> results = perform_pca(chains, mode="pooled")
        >>> var_b, var_w, var_t, f_stats, snr, pvals = variance_decomposition(
        ...     results, N_pcs=10, chains_dict=chains
        ... )
        >>> # High F-stat indicates PC separates datasets well
        >>> print(f"PC1 F-statistic: {f_stats[0]:.3f}")
        >>> # Low p-value indicates significant dataset separation
        >>> print(f"PC1 p-value: {pvals[0]:.2e}")
    """
    from scipy.stats import f_oneway

    # Get transformed data and labels
    transformed = pca_pooled.transformed_data
    labels_array = np.array(pca_pooled.dataset_labels)

    # Compute average effective sample size across chains
    if chains_dict is not None:
        if not isinstance(chains_dict, dict):
            chains_dict = {"chain": chains_dict}
        n_eff_list = [chain.getEffectiveSamples() for chain in chains_dict.values()]
        avg_n_eff = np.mean(n_eff_list)
    else:
        # Fallback: use total number of samples
        avg_n_eff = len(transformed)

    # Get explained variance ratio from PCA model (fraction of total variance)
    explained_variance_ratio = pca_pooled.explained_variance_ratio

    variance_between = []
    variance_within = []
    variance_total = []
    f_statistics = []
    snr_values = []
    p_values = []

    n_pcs = min(N_pcs, pca_pooled.n_components)

    # Create variance decomposition table
    decomp_table = Table(title="Variance Decomposition", show_header=True)
    decomp_table.add_column("PC", justify="right", style="cyan", width=4)
    decomp_table.add_column("Var(Between)", justify="right", width=13)
    decomp_table.add_column("Var(Within)", justify="right", width=12)
    decomp_table.add_column("F-stat", justify="right", style="yellow", width=10)
    decomp_table.add_column("SNR", justify="right", style="green", width=10)
    decomp_table.add_column("Interpretation", style="dim")

    for i in range(n_pcs):
        pc_scores = transformed[:, i]

        # Between-dataset variance
        dataset_means = np.array(
            [
                pc_scores[labels_array == l].mean()
                for l in np.unique(pca_pooled.dataset_labels)
            ]
        )
        var_between = np.var(dataset_means, ddof=1) if len(dataset_means) > 1 else 0

        # Within-dataset variance
        var_within = np.mean(
            [
                pc_scores[labels_array == l].var(ddof=1)
                for l in np.unique(pca_pooled.dataset_labels)
            ]
        )

        # Total variance
        var_total = np.var(pc_scores, ddof=1)

        # F-statistic (ratio of between to within variance)
        f_stat = var_between / var_within if var_within > 0 else 0

        # Signal-to-noise ratio: fraction of variance explained, scaled by measurement quality
        # Signal: explained variance ratio (0-1, fraction of total variance)
        # Statistical quality factor: sqrt(N_eff / 100) normalized to ~1 for N_eff=100
        signal = explained_variance_ratio[i]
        quality_factor = np.sqrt(avg_n_eff / 100.0) if avg_n_eff > 0 else 1.0
        snr = signal * quality_factor

        # ANOVA p-value to test if means differ across datasets
        dataset_groups = [
            pc_scores[labels_array == l] for l in np.unique(pca_pooled.dataset_labels)
        ]
        _, p_value = f_oneway(*dataset_groups)

        variance_between.append(var_between)
        variance_within.append(var_within)
        variance_total.append(var_total)
        f_statistics.append(f_stat)
        snr_values.append(snr)
        p_values.append(p_value)

        # Interpretation
        if f_stat > 1.0:
            interpretation = "Dataset separation dominates"
            row_style = "green"
        elif f_stat > 0.1:
            interpretation = "Mixed (both sources contribute)"
            row_style = "yellow"
        else:
            interpretation = "Within-dataset variance dominates"
            row_style = "dim"

        decomp_table.add_row(
            f"[{row_style}]{i + 1}[/{row_style}]",
            f"[{row_style}]{var_between:.4f}[/{row_style}]",
            f"[{row_style}]{var_within:.4f}[/{row_style}]",
            f"[{row_style}]{f_stat:.3f}[/{row_style}]",
            f"[{row_style}]{snr:.2f}[/{row_style}]",
            f"[{row_style}]{interpretation}[/{row_style}]",
        )

    console.print(decomp_table)

    # Print interpretation guide
    console.print("\n[header]Interpretation Guide:[/header]")
    console.print(
        "  • [green]High F-stat (>1)[/green]: PC captures differences between datasets"
    )
    console.print("  • [yellow]Mid F-stat (0.1-1)[/yellow]: PC captures both sources")
    console.print(
        "  • [dim]Low F-stat (<0.1)[/dim]: PC captures shared variation within datasets"
    )
    console.print(
        "  • [green]High SNR[/green]: PC is well-measured (signal >> measurement noise)"
    )
    console.print("  • [dim]Low SNR[/dim]: PC is dominated by MCMC sampling noise\n")
    return (
        variance_between,
        variance_within,
        variance_total,
        f_statistics,
        snr_values,
        p_values,
    )