API Reference

class nbi.engine.NBI(flow=None, featurizer=None, state_dict=None, simulator=None, priors=None, device='cpu', path='test', n_jobs=1, labels=None, tqdm_notebook=False, network_reinit=False, scale_reinit=True)

Bases: object

Neural Bayesian Inference Engine.

NBI is open-source software that supports both amortized and sequential Neural Posterior Estimation (NPE) methods. It is particularly tailored to astronomical inference problems, such as those involving light curves and spectra.

The design of NBI addresses critical issues in adapting NPE methods to astronomy. It provides built-in “featurizer” networks with demonstrated efficacy on sequential data, so users do not need to build custom featurizer networks. It also employs a modified algorithm, SNPE-IS, which enables asymptotically exact inference by using the NPE surrogate posterior as a proposal distribution for importance sampling.
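The importance-sampling step of SNPE-IS can be illustrated with self-normalized weights w_i ∝ p(y_i) p(x_obs | y_i) / q(y_i | x_obs). Here q is the surrogate posterior acting as the proposal. The sketch below is a generic, dependency-free version that works on precomputed log densities. The function name snis_weights is hypothetical and is not NBI's internal implementation.

```python
import math


def snis_weights(log_prior, log_like, log_q):
    """Self-normalized importance weights from per-sample log densities.

    Hypothetical helper, not part of NBI. For each proposal draw y_i:
      log_prior[i] = log p(y_i)
      log_like[i]  = log p(x_obs | y_i)
      log_q[i]     = log q(y_i | x_obs), the surrogate posterior density
    """
    log_w = [lp + ll - lq for lp, ll, lq in zip(log_prior, log_like, log_q)]
    # Shift by the maximum before exponentiating to avoid overflow/underflow.
    m = max(log_w)
    w = [math.exp(v - m) for v in log_w]
    total = sum(w)
    return [v / total for v in w]


# Two draws: the second has twice the likelihood, so it gets twice the weight.
weights = snis_weights([0.0, 0.0], [0.0, math.log(2.0)], [0.0, 0.0])
```

If the surrogate matched the true posterior exactly, every log-weight would equal the log evidence and all weights would be equal.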

Parameters

flow : dict or nn.Module

Dictionary containing hyperparameters for the Masked Autoregressive Flow. If a dictionary, the keys include:

  • ‘n_dims’, dimension of model parameter space.

  • ‘num_blocks’, number of Masked Autoregressive Flow (MAF) blocks.

  • ‘flow_hidden’, hidden dimension for each MAF block.

  • ‘perm_seed’, random seed for the dimension permutation of each MAF block.

  • ‘n_mog’, number of Gaussian mixture components in the base density. Values > 1 are recommended for multi-modal and/or non-Gaussian posterior distributions. Rule of thumb: twice the maximum number of posterior modes.

Only ‘n_dims’ is required. See the class attribute default_flow_config for default values.

If nn.Module, it is a custom normalizing flow.

Not required in inference-only mode when state_dict is supplied.
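A minimal sketch of building the flow dictionary. Only ‘n_dims’ is required, and omitted keys are left for NBI to fill from default_flow_config. The helper make_flow_config is hypothetical. It only checks key names against the list above.

```python
# Documented flow keys; omitted ones are assumed to fall back to NBI's defaults.
FLOW_KEYS = {"n_dims", "num_blocks", "flow_hidden", "perm_seed", "n_mog"}


def make_flow_config(n_dims, **options):
    """Build a flow dict with the required n_dims plus optional keys (hypothetical helper)."""
    unknown = set(options) - FLOW_KEYS
    if unknown:
        raise KeyError(f"unknown flow keys: {sorted(unknown)}")
    return {**options, "n_dims": n_dims}


# Three parameters and up to two posterior modes: n_mog = 2 * 2 = 4 by the rule of thumb.
flow = make_flow_config(3, n_mog=4)
```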

featurizer : dict or nn.Module

Dictionary of hyperparameters for the pre-built neural network used for dimensionality reduction, or (not yet supported) a custom PyTorch network module.

If a dictionary, the keys include:
  • ‘type’: Name of pre-built network architecture. Default: ‘resnet-gru’.

  • ‘norm’: Either ‘weight_norm’ or ‘batch_norm’.

  • ‘dim_in’: Number of input data channels.

  • ‘dim_out’: Dimension of output feature vector.

  • ‘dim_conv_max’: Maximum hidden dimensions of ResNet layers.

  • ‘depth’: ResNet network depth.

  • ‘n_rnn’: Number of GRU layers.

If nn.Module, it is a custom featurizer network that maps an input sequence of shape [Batch, Channel, Length] to an output feature vector of shape [Batch, Dimension].

Not required in inference-only mode when state_dict is supplied.
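A sketch of a featurizer dictionary for a single-channel sequence. The numeric sizes are placeholders, not recommended settings. check_featurizer_config is a hypothetical helper that only checks two things: ‘norm’ is one of the documented options, and the size keys are positive integers.

```python
FEATURIZER_NORMS = {"weight_norm", "batch_norm"}
INT_KEYS = ("dim_in", "dim_out", "dim_conv_max", "depth", "n_rnn")


def check_featurizer_config(cfg):
    """Sanity-check a featurizer dict (hypothetical helper, not part of NBI)."""
    if "norm" in cfg and cfg["norm"] not in FEATURIZER_NORMS:
        raise ValueError(f"norm must be one of {sorted(FEATURIZER_NORMS)}, got {cfg['norm']!r}")
    for key in INT_KEYS:
        if key in cfg and not (isinstance(cfg[key], int) and cfg[key] > 0):
            raise ValueError(f"{key} must be a positive integer, got {cfg[key]!r}")
    return cfg


# Placeholder sizes for a single-channel sequence such as a light curve.
featurizer = check_featurizer_config({
    "type": "resnet-gru",   # documented default architecture
    "norm": "weight_norm",
    "dim_in": 1,            # number of input channels
    "dim_out": 32,          # length of the output feature vector
    "dim_conv_max": 256,
    "depth": 3,
    "n_rnn": 2,
})
```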

state_dict : str or dict, optional

State dict, or path to a saved state dict, containing three keys: network_state_dict, x_scale, y_scale. Loads a pre-trained model.

simulator : function, optional

Simulator function used to generate data. Takes model parameters as input and returns simulated data.
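A toy simulator of the kind that could be passed as simulator. The sketch makes two assumptions: the function receives one parameter vector, and it returns a single simulation laid out as [Channel, Length]. The sinusoidal model and the name toy_simulator are both illustrative. NBI's exact calling convention (single vector vs. batch) should be checked against its examples.

```python
import math


def toy_simulator(theta, n_points=100):
    """Hypothetical light-curve simulator: a noiseless sinusoid.

    theta = (amplitude, period, phase). Returns a nested list of shape
    [Channel, Length] = [1, n_points], matching the featurizer input layout.
    """
    amplitude, period, phase = theta
    times = [10.0 * i / (n_points - 1) for i in range(n_points)]  # evenly spaced on [0, 10]
    flux = [amplitude * math.sin(2.0 * math.pi * t / period + phase) for t in times]
    return [flux]  # a single channel


x = toy_simulator((1.0, 2.0, 0.0))
```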

priors : list of scipy.stats objects, optional

Prior distribution for Bayesian inference. If not provided, a uniform prior is assumed.

device : {‘cpu’, ‘cuda’, ‘mps’} str, optional

Device for the neural network. Default is ‘cpu’.

path : str, optional

Path for saving the training set and model checkpoints.

n_jobs : int, optional

Number of parallel jobs for computation.

labels : list of str, optional

Names of the inferred parameters. Must be in the same order as priors.

tqdm_notebook : bool, optional

If True, uses the notebook version of tqdm for progress bars.

network_reinit : bool, optional

If True, re-initializes the network weights every round. Default is False, which often yields better results than True.

scale_reinit : bool, optional

If True, re-initializes the data pre-processing scales every round. Default is True.

corner(x, y=None, weights=None, color='k', y_true=None, plot_datapoints=True, plot_density=False, range_=None, truth_color='r', n=5000)

Wrapper function for making a corner plot.

Parameters

x : ndarray

Input data for inference.

y : ndarray, optional

Parameters to plot. If None, parameters are drawn from the surrogate posterior.

weights : ndarray, optional

Importance weights for reweighting.

color : str, optional

Color of the corner plot.

y_true : ndarray, optional

True parameters, shown as crosshairs.

plot_datapoints : bool, optional

If True, plots data points as a scatter plot.

plot_density : bool, optional

If True, plots 2D densities.

range_ : list, optional

Percentile of data to plot in the corner plot.

truth_color : str, optional

Color of the crosshairs.

n : int, optional

Number of samples to draw from the surrogate posterior if y is not specified.

Returns

fit(x=None, y=None, noise=None, log_like=None, n_epochs=10, n_rounds=1, n_sims=-1, x_obs=None, y_true=None, n_reuse=0, batch_size=64, project='test', use_wandb=False, neff_stop=-1, early_stop_train=False, early_stop_patience=-1, f_val=0.1, lr=0.001, min_lr=None, decay_type='SGDR', plot=True, f_accept_min=-1, workers=8)

Fit the Neural Bayesian Inference Engine.

Trains the network based on provided data and parameters.

Parameters

x : ndarray of paths to individual simulations, optional

First-round training simulations. Only required when the simulator and priors were not specified at engine initialization.

y : ndarray of shape (N, D), where D is the parameter space dimension, optional

First-round training parameters. Only required when the simulator and priors were not specified at engine initialization.

noise : ndarray or function, optional

Measurement error and/or data augmentation during training. For fixed iid Gaussian noise, provide an array of Gaussian errorbars. For ANPE, provide a function that takes in noiseless data and parameters (x, y) and outputs the noisified data and parameters (x’, y’). This is the last pre-processing step before the data are fed into the neural network.
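A sketch of an ANPE-style noise function that follows the (x, y) → (x’, y’) contract described above. It assumes x is one noiseless simulation stored as [Channel, Length], and it adds fixed per-point Gaussian noise. make_gaussian_noise is a hypothetical name. Real augmentation schemes, for example ones that randomize the noise level, may differ.

```python
import random


def make_gaussian_noise(sigma, seed=None):
    """Return a hypothetical ANPE-style noise function noise(x, y) -> (x', y').

    Assumes x is one noiseless simulation stored as [Channel, Length] nested lists and
    sigma has the same shape (per-point Gaussian errorbars). Parameters y pass through.
    """
    rng = random.Random(seed)

    def noise(x, y):
        x_noisy = [
            [value + rng.gauss(0.0, s) for value, s in zip(channel, sigma_channel)]
            for channel, sigma_channel in zip(x, sigma)
        ]
        return x_noisy, y

    return noise


clean = [[1.0, 2.0, 3.0]]
noisy, params = make_gaussian_noise([[0.1, 0.1, 0.1]], seed=0)(clean, (0.5,))
```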

log_like : function, optional

Log-likelihood function that takes in (x, x_path, y) and returns the log likelihood. Required for importance sampling (SNPE), but not required when the noise is iid Gaussian and specified as an errorbar array.

n_epochs : int, optional

Number of training epochs.

n_rounds : int, optional

Number of training rounds.

n_sims : int, optional

Number of simulations.

x_obs : ndarray, optional

Observed data.

y_true : ndarray, optional

True target values.

n_reuse : int, optional

Number of previous round training data to be reused for the current round.

batch_size : int, optional

Batch size for training and validation.

project : str, optional

Name of the project for logging.

use_wandb : bool, optional

If True, enables wandb logging.

neff_stop : int, optional

Early stopping criterion based on the Effective Sample Size (ESS). Inference terminates when the ESS exceeds this value.

early_stop_train : bool, optional

If True, terminates inference when the surrogate posterior (as measured by the ESS) does not improve in the current round.

early_stop_patience : int, optional

Number of epochs without improvement that triggers early stopping.
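One common way to implement patience-based early stopping is sketched below. It makes two assumptions: the monitored quantity is a validation loss, and a non-positive patience (such as the default -1) disables stopping. EarlyStopping is a hypothetical class, not part of NBI.

```python
class EarlyStopping:
    """Hypothetical patience-based early stopping on a validation loss."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # A non-positive patience (e.g. -1) is assumed to disable early stopping.
        return self.patience > 0 and self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
flags = [stopper.step(loss) for loss in [1.0, 0.9, 0.95, 0.96]]
```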

f_val : float, optional

Fraction of data to use for validation. Default: 0.1

lr : float, optional

Learning rate. Default: 0.001

min_lr : float, optional

Minimum learning rate for learning rate decay. Automatically calculated when not specified.

decay_type : {‘SGDR’} str, optional

Type of learning rate decay. Default is cosine annealing decay (“SGDR”).
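For reference, cosine annealing with warm restarts (SGDR) works in cycles. Within each cycle the learning rate decays from lr to min_lr along half a cosine, then resets to lr. The cycle length and min_lr values below are placeholders, since this page does not document how NBI chooses them. PyTorch provides a comparable scheduler, torch.optim.lr_scheduler.CosineAnnealingWarmRestarts; whether NBI uses it is not stated.

```python
import math


def sgdr_lr(epoch, cycle_len, lr=1e-3, min_lr=1e-5):
    """Cosine annealing with warm restarts (SGDR).

    Within each cycle of cycle_len epochs the rate decays from lr toward min_lr
    along half a cosine, then jumps back to lr at the start of the next cycle.
    """
    t = epoch % cycle_len
    return min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(math.pi * t / cycle_len))


schedule = [sgdr_lr(e, cycle_len=4) for e in range(8)]
```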

plot : bool, optional

If True, plots results after training.

f_accept_min : float, optional

Minimum round sampling efficiency (defined as the ratio of the effective sample size to the total sample size) required to terminate inference early.
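A sketch of the ESS and round sampling efficiency that neff_stop and f_accept_min refer to. It uses Kish's estimator, ESS = (Σw)² / Σw², which is a common choice. That NBI computes the ESS exactly this way is an assumption.

```python
def effective_sample_size(weights):
    """Kish's effective sample size, (sum w)^2 / sum(w^2), for importance weights."""
    s = sum(weights)
    s2 = sum(w * w for w in weights)
    return s * s / s2


def sampling_efficiency(weights):
    """ESS divided by the total number of samples; lies in (0, 1]."""
    return effective_sample_size(weights) / len(weights)


uniform = sampling_efficiency([1.0, 1.0, 1.0, 1.0])     # every draw counts fully
degenerate = sampling_efficiency([1.0, 0.0, 0.0, 0.0])  # a single draw carries all weight
```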

workers : int, optional

Number of workers for data loading.

Returns

get_network()

Returns the network module without DataParallel wrapper, if any.

Returns

nn.Module

Network module without the DataParallel wrapper.

get_params()

Saves the network weights and pre-processing scales to disk.

Returns

get_round_data(n_reuse)

Returns training data for the current round.

Parameters

n_reuse : int

Number of previous round training data to be reused for the current round.

Returns

x_round : ndarray

Training data for the current round.

y_round : ndarray

Training parameters for the current round.
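A sketch of how data from the current and preceding rounds might be pooled. It assumes n_reuse counts whole previous rounds, which is only one reading of the description above. get_round_data_sketch and the per-round lists are hypothetical.

```python
def get_round_data_sketch(x_rounds, y_rounds, n_reuse):
    """Pool the current round with up to n_reuse preceding rounds (hypothetical helper).

    Assumes x_rounds[-1] / y_rounds[-1] hold the current round and that n_reuse
    counts whole rounds; NBI's actual bookkeeping may differ.
    """
    start = max(0, len(x_rounds) - 1 - n_reuse)
    x_round = [item for rnd in x_rounds[start:] for item in rnd]
    y_round = [item for rnd in y_rounds[start:] for item in rnd]
    return x_round, y_round


x_rounds = [["sim_r0_a"], ["sim_r1_a"], ["sim_r2_a", "sim_r2_b"]]
y_rounds = [[(0.1,)], [(0.2,)], [(0.3,), (0.4,)]]
x_cur, y_cur = get_round_data_sketch(x_rounds, y_rounds, n_reuse=1)
```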

importance_reweight(x_obs, x, y)

SNPE: Calculate importance weights for the current round.

Parameters

x_obs : ndarray

Observed data.

x : ndarray

Simulated data.

y : ndarray

Simulated parameters.

Returns

weights : ndarray

Importance weights.

importance_reweight_like_only(x_obs, x, y)

SNPE: Calculate importance weights for the current round, using only the likelihood.

Parameters

x_obs : ndarray

Observed data.

x : ndarray

Simulated data.

y : ndarray

Simulated parameters.

Returns

weights : ndarray

init_env()

Initialize environment for training.

Returns

log_like(x_obs, x, y)

Calculate log likelihood.

Parameters

x_obs : ndarray

Observed data.

x : ndarray

Simulated data.

y : ndarray

Simulated parameters.

Returns

log_prob : ndarray

Log likelihood.
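When the measurement noise is iid Gaussian with known errorbars, the log likelihood of the observation given a noiseless simulation has the closed form below. The function gaussian_log_like is a hypothetical, dependency-free sketch that operates on flattened sequences. It is not NBI's implementation.

```python
import math


def gaussian_log_like(x_obs, x_sim, x_err):
    """Sum over points of log N(x_obs | x_sim, x_err^2), assuming independent errors.

    All three arguments are flattened sequences of equal length (hypothetical layout).
    """
    total = 0.0
    for obs, sim, err in zip(x_obs, x_sim, x_err):
        resid = (obs - sim) / err
        total += -0.5 * resid * resid - math.log(err) - 0.5 * math.log(2.0 * math.pi)
    return total


ll_exact = gaussian_log_like([0.0], [0.0], [1.0])  # zero residual, unit errorbar
ll_off = gaussian_log_like([1.0], [0.0], [1.0])    # one-sigma residual
```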

log_prior(y)

Calculate log prior probability.

Parameters

y : ndarray

Parameters for which to calculate the prior.

Returns

log_prob : ndarray

Log prior probability.
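A sketch of a log prior built from independent per-parameter priors. Under that assumption, the log prior is the sum of the individual log densities. To keep the example dependency-free, the Uniform class below mimics only the logpdf method of a frozen scipy.stats.uniform(loc, scale). With SciPy installed, scipy.stats objects can be used in its place.

```python
import math


class Uniform:
    """Stand-in for a frozen scipy.stats.uniform(loc, scale), exposing only logpdf."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def logpdf(self, v):
        inside = self.loc <= v <= self.loc + self.scale
        return -math.log(self.scale) if inside else -math.inf


def log_prior(y, priors):
    """Sum of per-parameter log densities, assuming independent priors in label order."""
    return sum(p.logpdf(v) for p, v in zip(priors, y))


priors = [Uniform(0.0, 2.0), Uniform(-1.0, 4.0)]
lp_inside = log_prior([1.0, 0.0], priors)   # -log 2 - log 4
lp_outside = log_prior([3.0, 0.0], priors)  # first parameter outside its support
```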

log_prob(x, y)

Calculate log probability under surrogate posterior.

Parameters

x : ndarray

Observations.

y : ndarray

Parameters.

Returns

log_prob : ndarray

Log probability under the surrogate posterior.

predict(x, x_err=None, y_true=None, log_like=None, n_samples=1000, neff_min=0, n_max=-1, corner=False, corner_reweight=False, seed=None)

Generates the posterior distribution of parameters given input data.

Parameters

x : ndarray

Input data for inference.

x_err : ndarray, optional

Measurement error for input data. Required for importance sampling; if not specified, provide log_like instead.

y_true : ndarray, optional

True parameters, if known.

log_like : function, optional

Log-likelihood function that takes in (x, x_path, y) and returns the log likelihood. Required for importance sampling when x_err is not specified.

n_samples : int, optional

Number of posterior samples to generate.

neff_min : int, optional

Minimum effective sample size required. If neff_min > n_samples, additional simulations are generated until an ESS of neff_min is reached. Default: 0

n_max : int, optional

Maximum number of simulations to generate to achieve neff_min.

corner : bool, optional

If True, generates a corner plot of the posterior before reweighting.

corner_reweight : bool, optional

If True, generates a corner plot of the posterior after reweighting.

seed : int, optional

Random seed for generating parameters.

Returns

ys : ndarray

Posterior samples.

weights : ndarray

Importance weights.
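The neff_min and n_max behavior of predict can be sketched as a loop. It keeps drawing weighted batches until either the ESS target or the simulation budget is reached. draw_batch is a hypothetical callable standing in for "sample from the surrogate, simulate, and reweight". The weights are assumed to be unnormalized and on a common scale, so that batches can be pooled.

```python
def effective_sample_size(weights):
    """Kish's effective sample size, (sum w)^2 / sum(w^2)."""
    s = sum(weights)
    return s * s / sum(w * w for w in weights)


def sample_until_neff(draw_batch, n_samples, neff_min=0, n_max=-1):
    """Draw weighted batches until ESS >= neff_min or n_max total draws (n_max < 0: no cap).

    draw_batch(n) is a hypothetical callable returning (samples, weights) for n new draws.
    """
    ys, ws = draw_batch(n_samples)
    ys, ws = list(ys), list(ws)
    while effective_sample_size(ws) < neff_min:
        n_new = n_samples if n_max < 0 else min(n_samples, n_max - len(ys))
        if n_new <= 0:
            break  # simulation budget exhausted
        new_ys, new_ws = draw_batch(n_new)
        ys.extend(new_ys)
        ws.extend(new_ws)
    return ys, ws


def equal_weight_batch(n):
    """Stand-in proposal whose draws all carry weight 1, so ESS equals the draw count."""
    return [0.0] * n, [1.0] * n


ys_a, _ = sample_until_neff(equal_weight_batch, 10, neff_min=25)
ys_b, _ = sample_until_neff(equal_weight_batch, 10, neff_min=25, n_max=15)
```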

prepare_data(x_obs, n_sims)

Generate training data for the current round.

Parameters

x_obs : ndarray

Observed data for producing simulations.

n_sims : int

Number of simulations.

Returns

result()

SNPE: Returns the reweighted posterior from all rounds.

Returns

all_thetas : ndarray

Parameter values from all rounds.

all_weights : ndarray

Importance weights from all rounds.
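Importance-weighted samples such as those returned by result() are usually summarized with weighted statistics. The sketch below computes a weighted mean and a weighted quantile (using a step-function CDF) for one parameter column. The helper names are hypothetical, and all_thetas is assumed to have one column per parameter.

```python
def weighted_mean(values, weights):
    """Importance-weighted mean of one parameter column."""
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


def weighted_quantile(values, weights, q):
    """Smallest value whose cumulative normalized weight reaches q (step-function CDF)."""
    pairs = sorted(zip(values, weights))
    total = sum(w for _, w in pairs)
    cum = 0.0
    for v, w in pairs:
        cum += w
        if cum >= q * total:
            return v
    return pairs[-1][0]


# One parameter column of all_thetas with its all_weights (toy numbers).
theta_col = [1.0, 2.0, 3.0, 4.0]
w = [1.0, 1.0, 1.0, 1.0]
median = weighted_quantile(theta_col, w, 0.5)
mean = weighted_mean(theta_col, w)
```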

sample(x, y=None, n=5000, corner=False)

Generates samples from the surrogate posterior.

Parameters

x : ndarray

Input data for inference.

y : ndarray, optional

True parameters (for the corner plot), if known.

n : int, optional

Number of samples to generate.

corner : bool, optional

If True, generates a corner plot of the surrogate posterior samples.

Returns

samples : ndarray

Samples from the surrogate posterior.

save_params(path)

Saves the network weights and pre-processing scales to disk.

Returns

scale_x(x, back=False)

Scale data to zero mean and unit variance, and vice versa.
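A sketch of the standardization described here and its inverse (back=True). This page does not say whether NBI computes scales globally, per channel, or per time step. It also does not say whether the population or sample standard deviation is used. This version assumes a single global mean and the population standard deviation.

```python
import statistics


def fit_scale(data):
    """Mean and population standard deviation used for standardization."""
    return statistics.fmean(data), statistics.pstdev(data)


def scale(data, mean, std, back=False):
    """Standardize to zero mean / unit variance, or undo it when back=True."""
    if back:
        return [v * std + mean for v in data]
    return [(v - mean) / std for v in data]


raw = [1.0, 2.0, 3.0, 4.0]
mu, sd = fit_scale(raw)
z = scale(raw, mu, sd)
restored = scale(z, mu, sd, back=True)
```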

Parameters

x : ndarray

Data to be scaled.

back : bool, optional

If True, scales data back to original values.

Returns

scale_y(y, back=False)

Scale parameters to zero mean and unit variance, and vice versa.

Parameters

y : ndarray

Parameters to be scaled.

back : bool, optional

If True, scales parameters back to original values.

Returns

set_params(state_dict)

Load engine parameters from disk, including network weights and data pre-processing scales.

Parameters

state_dict : str or dict

State dict, or path to a saved state dict, containing three keys: network_state_dict, x_scale, y_scale.

Returns

simulate(thetas)

Generates simulations for the provided parameters and saves them to disk. Returns an array containing the paths to the simulations.

Parameters

thetas : ndarray

Parameters to generate simulations for.

Returns

x_path : ndarray

Paths to generated simulations.
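A sketch of the save-to-disk pattern that simulate describes. It runs the simulator for each parameter vector, writes each result to its own file, and returns the file paths. JSON files and the naming scheme sim_{i}.json are stand-ins chosen to keep the example dependency-free. NBI's actual file format and layout may differ.

```python
import json
import os
import tempfile


def simulate_to_disk(thetas, simulator, out_dir):
    """Run simulator on each parameter vector, save each result, and return the paths."""
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for i, theta in enumerate(thetas):
        path = os.path.join(out_dir, f"sim_{i}.json")
        with open(path, "w") as f:
            json.dump(simulator(theta), f)
        paths.append(path)
    return paths


# Usage: save two toy simulations into a temporary directory and read them back.
with tempfile.TemporaryDirectory() as tmp:
    paths = simulate_to_disk([[1.0], [2.0]], lambda t: [2.0 * t[0]], tmp)
    names = [os.path.basename(p) for p in paths]
    saved = []
    for p in paths:
        with open(p) as f:
            saved.append(json.load(f))
```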

weighted_corner(x_obs, y_true)

SNPE: Reweighted corner plot for the current round.

Parameters

x_obs : ndarray

Observed data.

y_true : ndarray

True parameters, if known.

Returns