Skip to content

Kernels module

kernels

Kernel construction and validation for Gaussian Processes.

This module handles kernel configuration, building, and validation. It provides a unified interface for working with multiple kernel types across numpy and JAX backends.

build_kernel(config)

Build sklearn kernel from configuration.

This is the single source of truth for kernel construction, ensuring consistency across the codebase.

PARAMETER DESCRIPTION
config

Kernel configuration specifying kernel type, \(\sigma\), \(\ell\), and parameters.

TYPE: KernelConfig

RETURNS DESCRIPTION
Kernel

sklearn Kernel object (without noise component).

RAISES DESCRIPTION
ValueError

If kernel type is unknown or params are invalid.

Examples:

>>> config = KernelConfig(KernelType.RBF, sigma=0.1, length_scale=0.5)
>>> kernel = build_kernel(config)
>>> K = kernel(log_k)  # Evaluate covariance matrix
Source code in src/primefeat/gp/kernels.py
def build_kernel(config: KernelConfig) -> Kernel:
    """
    Build sklearn kernel from configuration.

    This is the single source of truth for kernel construction,
    ensuring consistency across the codebase. Every hyperparameter is
    passed with ``"fixed"`` bounds, and the signal kernel is scaled by a
    constant kernel of value ``config.sigma**2``.

    Args:
        config: Kernel configuration specifying kernel type, $\\sigma$, $\\ell$, and parameters.
            Extra entries are read from ``config.params``:

            - MATERN: optional ``"nu"`` (defaults to 1.5).
            - RATIONAL_QUADRATIC: required ``"alpha"``.
            - PERIODIC: required ``"period"``.
            - LOCALLY_PERIODIC: required ``"period"`` and ``"length_scale_rbf"``.

    Returns:
        sklearn Kernel object (without noise component).

    Raises:
        ValueError: If kernel type is unknown, or a required entry of
            ``config.params`` is missing.

    Examples:
        >>> config = KernelConfig(KernelType.RBF, sigma=0.1, length_scale=0.5)
        >>> kernel = build_kernel(config)
        >>> K = kernel(log_k)  # Evaluate covariance matrix
    """
    # An unset/empty params mapping is treated as "no extra parameters".
    params = config.params or {}

    def _required(name: str):
        # Missing hyperparameters must surface as the documented ValueError,
        # not as a bare KeyError from the dict lookup.
        try:
            return params[name]
        except KeyError as err:
            raise ValueError(
                f"Kernel type {config.kernel_type} requires params['{name}']"
            ) from err

    if config.kernel_type == KernelType.RBF:
        signal_kernel = RBF(
            length_scale=config.length_scale, length_scale_bounds="fixed"
        )

    elif config.kernel_type == KernelType.MATERN:
        nu = params.get("nu", 1.5)
        signal_kernel = Matern(
            length_scale=config.length_scale,
            nu=nu,
            length_scale_bounds="fixed",
        )

    elif config.kernel_type == KernelType.RATIONAL_QUADRATIC:
        alpha = _required("alpha")
        signal_kernel = RationalQuadratic(
            length_scale=config.length_scale,
            alpha=alpha,
            length_scale_bounds="fixed",
            alpha_bounds="fixed",
        )

    elif config.kernel_type == KernelType.PERIODIC:
        period = _required("period")
        signal_kernel = ExpSineSquared(
            length_scale=config.length_scale,
            periodicity=period,
            length_scale_bounds="fixed",
            periodicity_bounds="fixed",
        )

    elif config.kernel_type == KernelType.LOCALLY_PERIODIC:
        period = _required("period")
        length_scale_rbf = _required("length_scale_rbf")

        # RBF envelope (its own length scale) times a periodic component
        # that uses the config's main length scale.
        rbf_kernel = RBF(length_scale=length_scale_rbf, length_scale_bounds="fixed")

        periodic_kernel = ExpSineSquared(
            length_scale=config.length_scale,
            periodicity=period,
            length_scale_bounds="fixed",
            periodicity_bounds="fixed",
        )

        signal_kernel = rbf_kernel * periodic_kernel

    else:
        raise ValueError(f"Unknown kernel type: {config.kernel_type}")

    # Amplitude enters as a variance: sigma**2 multiplies the signal kernel.
    full_kernel = (
        ConstantKernel(config.sigma**2, constant_value_bounds="fixed") * signal_kernel
    )

    return full_kernel

build_noise_covariance(n, noise_level=None, noise_cov=None)

Build noise covariance matrix.

Supports two modes:

  1. Diagonal noise (simple): \(\sigma_n^2 I\)
  2. Full posterior covariance (recommended for MCMC): \(\Sigma_{\mathrm{post}}\)
PARAMETER DESCRIPTION
n

Number of data points.

TYPE: int

noise_level

Diagonal noise standard deviation \(\sigma_n\) (used if noise_cov=None).

TYPE: Optional[float] DEFAULT: None

noise_cov

Full N×N posterior covariance matrix \(\Sigma_{\mathrm{post}}\).

TYPE: Optional[ndarray] DEFAULT: None

RETURNS DESCRIPTION
ndarray

Noise covariance matrix of shape (n, n).

RAISES DESCRIPTION
ValueError

If neither noise_level nor noise_cov provided, or shape mismatch.

Source code in src/primefeat/gp/kernels.py
def build_noise_covariance(
    n: int,
    noise_level: Optional[float] = None,
    noise_cov: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Return the (n, n) noise covariance matrix.

    Two ways of specifying the noise are supported:

    1. Diagonal noise (simple): $\\sigma_n^2 I$
    2. Full posterior covariance (recommended for MCMC): $\\Sigma_{\\mathrm{post}}$

    When both are given, ``noise_cov`` takes precedence and ``noise_level``
    is ignored.

    Args:
        n: Number of data points.
        noise_level: Diagonal noise standard deviation $\\sigma_n$ (used if noise_cov=None).
        noise_cov: Full N×N posterior covariance matrix $\\Sigma_{\\mathrm{post}}$.
            It is returned as-is (no copy) after its shape is checked.

    Returns:
        Noise covariance matrix of shape (n, n).

    Raises:
        ValueError: If neither noise_level nor noise_cov provided, or shape mismatch.
    """
    # Diagonal mode: only reached when no full covariance was supplied.
    if noise_cov is None:
        if noise_level is None:
            raise ValueError("Either noise_level or noise_cov must be provided")
        return noise_level**2 * np.eye(n)

    # Full-covariance mode: validate the shape, then hand the matrix back.
    expected_shape = (n, n)
    if noise_cov.shape != expected_shape:
        raise ValueError(
            f"noise_cov must be shape ({n}, {n}), got {noise_cov.shape}"
        )
    return noise_cov

compute_kernel_matrix(log_k, kernel_config)

Compute kernel covariance matrix (signal only, no noise).

Useful for visualizing kernel structure and comparing kernels.

PARAMETER DESCRIPTION
log_k

\(\log(k)\) values, shape (n,) or (n, 1).

TYPE: ndarray

kernel_config

Kernel configuration specifying type and hyperparameters.

TYPE: KernelConfig

RETURNS DESCRIPTION
ndarray

Signal covariance matrix \(K_{\mathrm{signal}}\) of shape (n, n).

Source code in src/primefeat/gp/kernels.py
def compute_kernel_matrix(
    log_k: np.ndarray,
    kernel_config: KernelConfig,
) -> np.ndarray:
    """
    Evaluate the signal-only covariance matrix of a configured kernel.

    No noise term is added. Handy for inspecting kernel structure or
    comparing kernels side by side.

    Args:
        log_k: $\\log(k)$ values, shape (n,) or (n, 1).
        kernel_config: Kernel configuration specifying type and hyperparameters.

    Returns:
        Signal covariance matrix $K_{\\mathrm{signal}}$ of shape (n, n).
    """
    # The input is flattened into a single column, one row per sample,
    # before the kernel object is called on it.
    points = np.reshape(np.asarray(log_k), (-1, 1))
    return build_kernel(kernel_config)(points)

compare_kernels(log_k, configs)

Compute kernel matrices for multiple configurations.

Useful for comparing how different kernels represent correlation structure.

PARAMETER DESCRIPTION
log_k

\(\log(k)\) values, shape (n,) or (n, 1).

TYPE: ndarray

configs

Dictionary mapping names to KernelConfig objects.

TYPE: Dict[str, KernelConfig]

RETURNS DESCRIPTION
Dict[str, ndarray]

Dictionary mapping names to kernel matrices.

Examples:

>>> log_k = np.linspace(-7, -1.5, 20)
>>> configs = {
...     'RBF': KernelConfig(KernelType.RBF, 0.1, 0.5),
...     'RQ_low_alpha': KernelConfig(KernelType.RATIONAL_QUADRATIC, 0.1, 0.5, {'alpha': 0.5}),
...     'RQ_high_alpha': KernelConfig(KernelType.RATIONAL_QUADRATIC, 0.1, 0.5, {'alpha': 10.0}),
... }
>>> matrices = compare_kernels(log_k, configs)
Source code in src/primefeat/gp/kernels.py
def compare_kernels(
    log_k: np.ndarray,
    configs: Dict[str, KernelConfig],
) -> Dict[str, np.ndarray]:
    """
    Evaluate the kernel matrix of every named configuration on the same grid.

    Useful for comparing how different kernels represent correlation structure.

    Args:
        log_k: $\\log(k)$ values, shape (n,) or (n, 1).
        configs: Dictionary mapping names to KernelConfig objects.

    Returns:
        Dictionary mapping names to kernel matrices.

    Examples:
        >>> log_k = np.linspace(-7, -1.5, 20)
        >>> configs = {
        ...     'RBF': KernelConfig(KernelType.RBF, 0.1, 0.5),
        ...     'RQ_low_alpha': KernelConfig(KernelType.RATIONAL_QUADRATIC, 0.1, 0.5, {'alpha': 0.5}),
        ...     'RQ_high_alpha': KernelConfig(KernelType.RATIONAL_QUADRATIC, 0.1, 0.5, {'alpha': 10.0}),
        ... }
        >>> matrices = compare_kernels(log_k, configs)
    """
    # Reshape once to a single column; every kernel is evaluated on this grid.
    points = np.asarray(log_k).reshape(-1, 1)

    matrices: Dict[str, np.ndarray] = {}
    for label, cfg in configs.items():
        matrices[label] = build_kernel(cfg)(points)
    return matrices

compute_bin_resolution(nbins=None, k_start=None, k_end=None, binning=None)

Compute resolution limits imposed by finite binning.

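With finite bins over a finite \(k\)-range, we cannot resolve arbitrarily small correlation lengths. This function computes the minimum resolvable and maximum sensible length scales implied by the binning. It does not emit warnings itself; use validate_hyperparameters to check a requested length scale against these limits.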

PARAMETER DESCRIPTION
nbins

Number of bins (default: 20).

TYPE: Optional[int] DEFAULT: None

k_start

Start of \(k\)-range in \(\mathrm{Mpc}^{-1}\) (default: 1e-3).

TYPE: Optional[float] DEFAULT: None

k_end

End of \(k\)-range in \(\mathrm{Mpc}^{-1}\) (default: 0.34).

TYPE: Optional[float] DEFAULT: None

binning

Optional BinningScheme object (supersedes nbins, k_start, and k_end if provided).

TYPE: Optional[BinningScheme] DEFAULT: None

RETURNS DESCRIPTION
Dict[str, Any]

Dictionary with the following keys:

- delta_log_k: Bin spacing in \(\log(k)\) space
- log_k_range: Total range in \(\log(k)\) space
- min_resolvable_length: Minimum length scale we can constrain (~2 bins)
- max_sensible_length: Maximum useful length scale (~half range)
- recommended_length_range: Tuple (min_resolvable_length, max_sensible_length)

Examples:

>>> res = compute_bin_resolution(20, 0.001, 0.23)
>>> print(f"Min length scale: {res['min_resolvable_length']:.3f}")
>>> print(f"Max length scale: {res['max_sensible_length']:.3f}")
Source code in src/primefeat/gp/kernels.py
def compute_bin_resolution(
    nbins: Optional[int] = None,
    k_start: Optional[float] = None,
    k_end: Optional[float] = None,
    binning: Optional[BinningScheme] = None,
) -> Dict[str, Any]:
    """
    Compute the length-scale limits implied by a finite binning.

    With finite bins over a finite $k$-range, arbitrarily small correlation
    lengths cannot be resolved. This function only computes the limits; it
    does not emit warnings itself.

    Args:
        nbins: Number of bins (default: 20).
        k_start: Start of $k$-range in $\\mathrm{Mpc}^{-1}$ (default: 1e-3).
        k_end: End of $k$-range in $\\mathrm{Mpc}^{-1}$ (default: 0.34).
        binning: Optional BinningScheme object (supersedes nbins, k_start, k_end if provided).

    Returns:
        Dictionary with:
            - delta_log_k: Bin spacing in $\\log(k)$ space
            - log_k_range: Total range in $\\log(k)$ space
            - min_resolvable_length: Minimum length scale we can constrain (~2 bins)
            - max_sensible_length: Maximum useful length scale (~half range)
            - recommended_length_range: Tuple
              ``(min_resolvable_length, max_sensible_length)``

    Examples:
        >>> res = compute_bin_resolution(20, 0.001, 0.23)
        >>> print(f"Min length scale: {res['min_resolvable_length']:.3f}")
        >>> print(f"Max length scale: {res['max_sensible_length']:.3f}")
    """
    # Falsy arguments (None, but also 0) fall back to the documented defaults
    # before being handed to the binning resolver.
    scheme = _resolve_binning(binning, k_start or 1e-3, k_end or 0.34, nbins or 20)

    total_span = np.log(scheme.k_end) - np.log(scheme.k_start)
    spacing = total_span / scheme.nbins

    # Limits: two bin widths at the short end, half the full span at the long end.
    shortest = 2 * spacing
    longest = total_span / 2

    return dict(
        delta_log_k=spacing,
        log_k_range=total_span,
        min_resolvable_length=shortest,
        max_sensible_length=longest,
        recommended_length_range=(shortest, longest),
    )

validate_hyperparameters(sigma, length_scale, nbins=None, k_start=None, k_end=None, warn=True, binning=None)

Validate GP hyperparameters against bin resolution limits.

PARAMETER DESCRIPTION
sigma

Signal standard deviation \(\sigma\).

TYPE: float

length_scale

RBF kernel length scale \(\ell\) in \(\log(k)\) space.

TYPE: float

nbins

Number of bins (default: 20).

TYPE: Optional[int] DEFAULT: None

k_start

Start of \(k\)-range in \(\mathrm{Mpc}^{-1}\) (default: 1e-3).

TYPE: Optional[float] DEFAULT: None

k_end

End of \(k\)-range in \(\mathrm{Mpc}^{-1}\) (default: 0.34).

TYPE: Optional[float] DEFAULT: None

warn

Whether to emit warnings (via warnings.warn) when a check fails.

TYPE: bool DEFAULT: True

binning

Optional BinningScheme object (supersedes nbins, k_start, and k_end if provided).

TYPE: Optional[BinningScheme] DEFAULT: None

RETURNS DESCRIPTION
bool

True if parameters are within reasonable bounds, False otherwise.

Source code in src/primefeat/gp/kernels.py
def validate_hyperparameters(
    sigma: float,
    length_scale: float,
    nbins: Optional[int] = None,
    k_start: Optional[float] = None,
    k_end: Optional[float] = None,
    warn: bool = True,
    binning: Optional[BinningScheme] = None,
) -> bool:
    """
    Validate GP hyperparameters against bin resolution limits.

    Args:
        sigma: Signal standard deviation $\\sigma$.
        length_scale: RBF kernel length scale $\\ell$ in $\\log(k)$ space.
        nbins: Number of bins (default: 20).
        k_start: Start of $k$-range in $\\mathrm{Mpc}^{-1}$ (default: 1e-3).
        k_end: End of $k$-range in $\\mathrm{Mpc}^{-1}$ (default: 0.34).
        warn: Whether to emit warnings (via ``warnings.warn``) when a check fails.
        binning: Optional BinningScheme object (supersedes nbins, k_start, k_end if provided).

    Returns:
        True if ``length_scale`` lies between the minimum resolvable and the
        maximum sensible length scale, False otherwise. A large ``sigma``
        (> 1.0) only triggers a warning and does not affect the result.

    Raises:
        ValueError: If ``sigma`` is negative.
    """
    # Reject an invalid amplitude up front so that no resolution warnings are
    # emitted for a call that is about to fail anyway.
    if sigma < 0:
        raise ValueError(f"Signal variance $\\sigma$ must be non-negative, got {sigma}")

    res = compute_bin_resolution(nbins, k_start, k_end, binning)

    # Bin count actually in effect, for the warning text. The raw ``nbins``
    # argument is None whenever the default or a ``binning`` object is used.
    # Mirrors compute_bin_resolution: binning supersedes nbins, default is 20.
    # NOTE(review): assumes BinningScheme exposes ``.nbins`` — confirm.
    effective_nbins = binning.nbins if binning is not None else (nbins or 20)

    is_valid = True

    if length_scale < res["min_resolvable_length"]:
        if warn:
            warnings.warn(
                f"Length scale $\\ell$={length_scale:.3f} is below minimum resolvable "
                f"scale {res['min_resolvable_length']:.3f}. "
                f"With {effective_nbins} bins, features narrower than ~2 bins cannot be distinguished "
                f"from noise. Consider using $\\ell$ >= {res['min_resolvable_length']:.3f}.",
                stacklevel=2,  # attribute the warning to the caller
            )
        is_valid = False

    if length_scale > res["max_sensible_length"]:
        if warn:
            warnings.warn(
                f"Length scale $\\ell$={length_scale:.3f} is larger than half the $k$-range "
                f"({res['max_sensible_length']:.3f}). Such broad features are poorly "
                f"constrained by the data. Consider using $\\ell$ <= {res['max_sensible_length']:.3f}.",
                stacklevel=2,
            )
        is_valid = False

    # A large amplitude is only advisory: it warns but leaves is_valid alone.
    if sigma > 1.0:
        if warn:
            warnings.warn(
                f"Signal amplitude $\\sigma$={sigma:.3f} is very large (>1.0). "
                f"This implies order-unity deviations from the power-law, "
                f"which may not be physically motivated.",
                stacklevel=2,
            )

    return is_valid