Skip to content

Optimize module

optimize

Gradient-based Gaussian Process hyperparameter optimization.

Implements optimization and sampling of GP hyperparameters using JAX, optax, and blackjax backends.

optimize(log_k, delta, noise_cov, initial_config, *, max_steps=500, learning_rate=0.05, convergence_tol=1e-05, patience=30, history_interval=10, verbose=False)

Gradient-based GP hyperparameter optimisation via JAX autodiff.

Dispatches to JAXGPBackend.optimize_kernel_hyperparameters() when the JAX backend and optax are available. Returns None gracefully when JAX or optax is not installed, allowing callers to fall back to the grid search result.

All free hyperparameters are optimised jointly in log-space using Adam with a cosine-decay learning rate. Gradients are computed through tinygp.GaussianProcess.log_probability().

Parameters

- log_k : array-like, shape (n,) — Natural-log wavenumber values \(\log(k)\) (training inputs).
- delta : array-like, shape (n,) — Posterior-mean power spectrum deviations \(\delta(k)\) (targets).
- noise_cov : array-like, shape (n, n) — Full posterior covariance matrix used as the GP noise term. If you only have a diagonal noise level \(\sigma_n\), pass sigma_n**2 * np.eye(n).
- initial_config : KernelConfig — Starting hyperparameter values. Best practice is to warm-start from the grid-search maximum.
- max_steps : int, optional — Maximum number of Adam gradient steps (default 500).
- learning_rate : float, optional — Peak learning rate for Adam with cosine decay (default 0.05).
- convergence_tol : float, optional — Absolute LML change threshold for early stopping (default 1e-5).
- patience : int, optional — Number of consecutive steps below convergence_tol that trigger early stopping (default 30).
- history_interval : int, optional — Record the LML and gradient norm every this many steps (default 10).
- verbose : bool, optional — Print per-step progress (default False).

Returns

OptimizationResult or None — an OptimizationResult with optimized_config, final_lml, converged, n_steps, lml_history, and grad_norm_history, or None if JAX or optax is not available.

Examples

>>> from primefeat.gp import compute_lml_landscape, optimize
>>> from primefeat.backends.base import KernelConfig, KernelType
>>>
>>> # Step 1: coarse grid to get an initial point
>>> landscape = compute_lml_landscape(
...     delta_mean, log_k,
...     kernel_type=KernelType.LOCALLY_PERIODIC,
...     kernel_params={'period': 0.8, 'length_scale_rbf': 2.0},
...     nbins=20, k_start=0.001, k_end=0.23,
... )
>>> init_config = KernelConfig(
...     KernelType.LOCALLY_PERIODIC,
...     sigma=landscape['optimal_sigma'],
...     length_scale=landscape['optimal_length_scale'],
...     params={'period': 0.8, 'length_scale_rbf': 2.0},
... )
>>>
>>> # Step 2: gradient refinement of all 4 hyperparameters
>>> result = optimize(
...     log_k.ravel(), delta_mean, posterior_cov, init_config,
...     max_steps=400, verbose=True,
... )
>>>
>>> if result is not None:
...     console.print(result.summary())
...     console.print(f"Optimal config: {result.optimized_config.describe()}")

Source code in src/primefeat/gp/optimize.py
def optimize(
    log_k: np.ndarray,
    delta: np.ndarray,
    noise_cov: np.ndarray,
    initial_config: "KernelConfig",
    *,
    max_steps: int = 500,
    learning_rate: float = 0.05,
    convergence_tol: float = 1e-5,
    patience: int = 30,
    history_interval: int = 10,
    verbose: bool = False,
) -> Optional[Any]:
    """
    Gradient-based GP hyperparameter optimisation via JAX autodiff.

    Dispatches to ``JAXGPBackend.optimize_kernel_hyperparameters()`` (through
    ``get_jax_backend()``) when the JAX backend and ``optax`` are available.
    Returns ``None`` with a warning when JAX or optax is not installed,
    allowing callers to fall back to the grid search result.

    All free hyperparameters are optimised jointly in log-space using Adam
    with a cosine-decay learning rate. Gradients are computed through
    ``tinygp.GaussianProcess.log_probability()``.

    Parameters
    ----------
    log_k : array-like, shape (n,)
        Natural-log wavenumber values $\\log(k)$ (training inputs).
        Flattened with ``ravel()`` before use.
    delta : array-like, shape (n,)
        Posterior-mean power spectrum deviations $\\delta(k)$ (targets).
        Flattened with ``ravel()`` before use.
    noise_cov : array-like, shape (n, n)
        Full posterior covariance matrix used as the GP noise term.
        If you only have a diagonal noise level $\\sigma_n$, pass
        ``sigma_n**2 * np.eye(n)``.
    initial_config : KernelConfig
        Starting hyperparameter values. Best practice is to warm-start from
        the grid-search maximum.
    max_steps : int, optional
        Maximum number of Adam gradient steps (default 500).
    learning_rate : float, optional
        Peak learning rate for Adam with cosine decay (default 0.05).
    convergence_tol : float, optional
        Absolute LML change threshold for early stopping (default 1e-5).
    patience : int, optional
        Number of consecutive steps below ``convergence_tol`` that trigger
        early stopping (default 30).
    history_interval : int, optional
        Record LML and gradient-norm every this many steps (default 10).
    verbose : bool, optional
        Print per-step progress (default False).

    Returns
    -------
    OptimizationResult or None
        ``OptimizationResult`` with ``optimized_config``, ``final_lml``,
        ``converged``, ``n_steps``, ``lml_history``, and
        ``grad_norm_history``. Returns ``None`` (after emitting a
        ``UserWarning``) if JAX or optax is not available.

    Raises
    ------
    ValueError
        If ``log_k`` is empty, if ``log_k`` and ``delta`` have different
        numbers of elements, or if ``noise_cov`` does not have shape
        ``(n, n)``. Shapes are checked before the backend import, so invalid
        input is reported the same way whether or not JAX is installed.

    Examples
    --------

    >>> from primefeat.gp import compute_lml_landscape, optimize
    >>> from primefeat.backends.base import KernelConfig, KernelType
    >>>
    >>> # Step 1: coarse grid to get an initial point
    >>> landscape = compute_lml_landscape(
    ...     delta_mean, log_k,
    ...     kernel_type=KernelType.LOCALLY_PERIODIC,
    ...     kernel_params={'period': 0.8, 'length_scale_rbf': 2.0},
    ...     nbins=20, k_start=0.001, k_end=0.23,
    ... )
    >>> init_config = KernelConfig(
    ...     KernelType.LOCALLY_PERIODIC,
    ...     sigma=landscape['optimal_sigma'],
    ...     length_scale=landscape['optimal_length_scale'],
    ...     params={'period': 0.8, 'length_scale_rbf': 2.0},
    ... )
    >>>
    >>> # Step 2: gradient refinement of all 4 hyperparameters
    >>> result = optimize(
    ...     log_k.ravel(), delta_mean, posterior_cov, init_config,
    ...     max_steps=400, verbose=True,
    ... )
    >>>
    >>> if result is not None:
    ...     console.print(result.summary())
    ...     console.print(f"Optimal config: {result.optimized_config.describe()}")
    """

    log_k = np.asarray(log_k).ravel()
    delta = np.asarray(delta).ravel()
    noise_cov = np.asarray(noise_cov)

    # Validate shapes up front: otherwise a mismatch is either masked by the
    # "JAX not installed" early return or surfaces as an opaque backend error.
    n = log_k.size
    if n == 0:
        raise ValueError("log_k and delta must contain at least one point")
    if delta.size != n:
        raise ValueError(
            f"log_k and delta must have the same number of elements; "
            f"got {n} and {delta.size}"
        )
    if noise_cov.shape != (n, n):
        raise ValueError(
            f"noise_cov must have shape ({n}, {n}); got {noise_cov.shape}"
        )

    # Optional dependency: the JAX backend is imported lazily so that this
    # module stays importable without JAX.
    try:
        from ..backends.jax.gp_jax import get_jax_backend, _OPTAX_AVAILABLE
    except ImportError:
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn(
            "JAX backend not available; skipping gradient-based optimization. "
            "Install with: pip install jax jaxlib tinygp",
            stacklevel=2,
        )
        return None

    if not _OPTAX_AVAILABLE:
        warnings.warn(
            "optax not installed; skipping gradient-based optimization. "
            "Install with: pip install optax",
            stacklevel=2,
        )
        return None

    backend = get_jax_backend()
    return backend.optimize_kernel_hyperparameters(
        log_k,
        delta,
        noise_cov,
        initial_config,
        max_steps=max_steps,
        learning_rate=learning_rate,
        convergence_tol=convergence_tol,
        patience=patience,
        history_interval=history_interval,
        verbose=verbose,
    )

sample_hyperparameters(log_k, delta, noise_cov, initial_config, *, num_warmup=500, num_samples=1000, target_accept_rate=0.8, prior_scale=1.5, initial_step_size=0.1, seed=0, progress_bar=False)

Sample GP kernel hyperparameters via NUTS (HMC) using blackjax.

Unlike `optimize`, which returns a single point estimate (the marginal-likelihood maximum), this function returns the full posterior distribution over the hyperparameters, enabling uncertainty quantification and credible intervals. This is particularly valuable for kernels with 3–4 free parameters (Rational Quadratic, Locally Periodic), where the posterior can be multimodal or banana-shaped.

The log-posterior sampled is

\[
\log p(\theta \mid \text{data}) = \mathrm{LML}(\theta) + \sum_i \log \mathcal{N}\!\left(\log\theta_i;\, 0,\, \text{prior\_scale}^2\right) + \text{const.}
\]

All hyperparameters are sampled in log-space (guaranteeing positivity). Warmup uses blackjax.window_adaptation for automatic step-size and diagonal mass-matrix tuning (Stan-style dual averaging + Welford online covariance).

Parameters

- log_k : ndarray, shape (N,) — Log-wavenumber bin centres.
- delta : ndarray, shape (N,) — Posterior mean of the δ parameters from the MCMC chain.
- noise_cov : ndarray, shape (N, N) — Full posterior covariance matrix Σ_post.
- initial_config : KernelConfig — Kernel configuration used to initialise the chain position. For best results, warm-start from the output of `optimize`.
- num_warmup : int — Number of NUTS adaptation steps. Default 500.
- num_samples : int — Number of posterior samples. Default 1000.
- target_accept_rate : float — Target NUTS acceptance probability. Default 0.80.
- prior_scale : float — Standard deviation of the Normal prior on each log-parameter. Default 1.5 (≈ 2 orders of magnitude around the initial value).
- initial_step_size : float — Initial NUTS step size before adaptation. Default 0.1.
- seed : int — JAX random seed. Default 0.
- progress_bar : bool — Show the blackjax warmup progress bar. Default False.

Returns

HMCSamplingResult or None — posterior samples and diagnostics, or None (with a warning) if JAX or blackjax is not installed.

Examples

>>> from primefeat.gp import optimize, sample_hyperparameters
>>> from primefeat.backends.base import KernelConfig, KernelType
>>>
>>> config = KernelConfig(KernelType.LOCALLY_PERIODIC,
...                       sigma=0.1, length_scale=0.3,
...                       params={"period": 0.8, "length_scale_rbf": 2.0})
>>>
>>> # Optional: warm-start from gradient optimisation
>>> opt = optimize(log_k, delta, noise_cov, config)
>>> init = opt.optimized_config if opt else config
>>>
>>> result = sample_hyperparameters(
...     log_k, delta, noise_cov, init,
...     num_warmup=500, num_samples=1000,
... )
>>> if result is not None:
...     console.print(result.summary())
...     lo, hi = result.credible_interval("log_sigma")
...     console.print(f"σ 95% CI: [{lo:.4f}, {hi:.4f}]")

Source code in src/primefeat/gp/optimize.py
def sample_hyperparameters(
    log_k: np.ndarray,
    delta: np.ndarray,
    noise_cov: np.ndarray,
    initial_config: "KernelConfig",
    *,
    num_warmup: int = 500,
    num_samples: int = 1000,
    target_accept_rate: float = 0.80,
    prior_scale: float = 1.5,
    initial_step_size: float = 0.1,
    seed: int = 0,
    progress_bar: bool = False,
) -> Optional[Any]:
    """
    Sample GP kernel hyperparameters via NUTS (HMC) using blackjax.

    Unlike :func:`optimize`, which returns a single point estimate (the
    marginal-likelihood maximum), this function returns the full posterior
    distribution over the hyperparameters, enabling uncertainty
    quantification and credible intervals. This is particularly valuable for
    kernels with 3–4 free parameters (Rational Quadratic, Locally Periodic)
    where the posterior can be multimodal or banana-shaped.

    The log-posterior sampled is::

        log p(θ | data) = LML(θ) + Σ_i log N(log θ_i ; 0, prior_scale²) + const

    All hyperparameters are sampled in log-space (guaranteeing positivity).
    Warmup uses ``blackjax.window_adaptation`` for automatic step-size and
    diagonal mass-matrix tuning (Stan-style dual averaging + Welford online
    covariance).

    Parameters
    ----------
    log_k : ndarray, shape (N,)
        Log-wavenumber bin centres. Flattened with ``ravel()`` before use.
    delta : ndarray, shape (N,)
        Posterior mean of the δ parameters from the MCMC chain. Flattened
        with ``ravel()`` before use.
    noise_cov : ndarray, shape (N, N)
        Full posterior covariance matrix Σ_post.
    initial_config : KernelConfig
        Kernel configuration used to initialise the chain position. For best
        results, warm-start from the output of :func:`optimize`.
    num_warmup : int
        Number of NUTS adaptation steps (must be >= 0). Default 500.
    num_samples : int
        Number of posterior samples (must be >= 1). Default 1000.
    target_accept_rate : float
        Target NUTS acceptance probability, strictly between 0 and 1.
        Default 0.80.
    prior_scale : float
        Standard deviation (> 0) of the Normal prior on each log-parameter.
        Default 1.5 (≈ 2 orders of magnitude around the initial value).
    initial_step_size : float
        Initial NUTS step size (> 0) before adaptation. Default 0.1.
    seed : int
        JAX random seed. Default 0.
    progress_bar : bool
        Show blackjax warmup progress bar. Default False.

    Returns
    -------
    HMCSamplingResult or None
        Posterior samples and diagnostics, or ``None`` (with a
        ``UserWarning``) if JAX or blackjax is not installed.

    Raises
    ------
    ValueError
        If ``log_k`` is empty, if ``log_k`` and ``delta`` have different
        numbers of elements, if ``noise_cov`` does not have shape ``(N, N)``,
        or if any sampler setting is outside the range documented above.
        These checks run before the backend import, so invalid input is
        reported the same way whether or not JAX/blackjax is installed.

    Examples
    --------
    >>> from primefeat.gp import optimize, sample_hyperparameters
    >>> from primefeat.backends.base import KernelConfig, KernelType
    >>>
    >>> config = KernelConfig(KernelType.LOCALLY_PERIODIC,
    ...                       sigma=0.1, length_scale=0.3,
    ...                       params={"period": 0.8, "length_scale_rbf": 2.0})
    >>>
    >>> # Optional: warm-start from gradient optimisation
    >>> opt = optimize(log_k, delta, noise_cov, config)
    >>> init = opt.optimized_config if opt else config
    >>>
    >>> result = sample_hyperparameters(
    ...     log_k, delta, noise_cov, init,
    ...     num_warmup=500, num_samples=1000,
    ... )
    >>> if result is not None:
    ...     console.print(result.summary())
    ...     lo, hi = result.credible_interval("log_sigma")
    ...     console.print(f"σ 95% CI: [{lo:.4f}, {hi:.4f}]")
    """
    log_k = np.asarray(log_k).ravel()
    delta = np.asarray(delta).ravel()
    noise_cov = np.asarray(noise_cov)

    # Validate data shapes up front so a mismatch is not masked by the
    # "dependency missing" early return below.
    n = log_k.size
    if n == 0:
        raise ValueError("log_k and delta must contain at least one point")
    if delta.size != n:
        raise ValueError(
            f"log_k and delta must have the same number of elements; "
            f"got {n} and {delta.size}"
        )
    if noise_cov.shape != (n, n):
        raise ValueError(
            f"noise_cov must have shape ({n}, {n}); got {noise_cov.shape}"
        )

    # Sampler settings: `not (x > 0)` form also rejects NaN.
    if num_warmup < 0:
        raise ValueError(f"num_warmup must be >= 0; got {num_warmup}")
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1; got {num_samples}")
    if not (0.0 < target_accept_rate < 1.0):
        raise ValueError(
            f"target_accept_rate must be in (0, 1); got {target_accept_rate}"
        )
    if not (prior_scale > 0.0):
        raise ValueError(f"prior_scale must be > 0; got {prior_scale}")
    if not (initial_step_size > 0.0):
        raise ValueError(
            f"initial_step_size must be > 0; got {initial_step_size}"
        )

    # Optional dependency: imported lazily so this module works without JAX.
    try:
        from ..backends.jax.gp_jax import JAXGPBackend, _BLACKJAX_AVAILABLE
    except ImportError:
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn(
            "JAX backend not available; skipping HMC sampling. "
            "Install with: pip install jax jaxlib tinygp",
            stacklevel=2,
        )
        return None

    if not _BLACKJAX_AVAILABLE:
        warnings.warn(
            "blackjax not installed; skipping HMC sampling. "
            "Install with: pip install blackjax",
            stacklevel=2,
        )
        return None

    # NOTE(review): called on the class rather than an instance (optimize()
    # goes through get_jax_backend()); presumably a static/class method —
    # confirm. Also confirm whether the prior is centred at 0 (as in the
    # formula above) or at the initial value (as the prior_scale text says).
    return JAXGPBackend.sample_hyperparameters_hmc(
        log_k,
        delta,
        noise_cov,
        initial_config,
        num_warmup=num_warmup,
        num_samples=num_samples,
        target_accept_rate=target_accept_rate,
        prior_scale=prior_scale,
        initial_step_size=initial_step_size,
        seed=seed,
        progress_bar=progress_bar,
    )