API Reference
- class nbi.engine.NBI(flow=None, featurizer=None, state_dict=None, simulator=None, priors=None, device='cpu', path='test', n_jobs=1, labels=None, tqdm_notebook=False, network_reinit=False, scale_reinit=True)
Bases: object
Neural Bayesian Inference Engine.
NBI is an open-source package that supports both amortized and sequential Neural Posterior Estimation (NPE) methods. It is tailored to astronomical inference problems, such as those involving light curves and spectra.
The design of NBI addresses critical issues in the adaptation of NPE methods in astronomy. It provides built-in “featurizer” networks with demonstrated efficacy on sequential data, removing the need for custom featurizer networks by users. It also employs a modified algorithm, SNPE-IS, which enables asymptotically exact inference by using the surrogate posterior under NPE as a proposal distribution for importance sampling.
Parameters
- flow : dict or nn.Module
Dictionary containing hyperparameters for the Masked Autoregressive Flow. If a dictionary, the keys include:
‘n_dims’, dimension of model parameter space.
‘num_blocks’, number of Masked Autoregressive Flow (MAF) blocks.
‘flow_hidden’, hidden dimension for each MAF block
‘perm_seed’, random seed for dimension permutation of each MAF block
‘n_mog’, number of Gaussian mixture components in the base density. A value > 1 is recommended for multi-modal and/or non-Gaussian posterior distributions. Rule of thumb: twice the maximum number of posterior modes.
Only ‘n_dims’ is required. See the class attribute default_flow_config for default values.
If nn.Module, it’s a custom normalizing flow.
Not required in inference-only mode, when state_dict is supplied.
- featurizer : dict or nn.Module
Dictionary of hyperparameters for the pre-built neural network for dimensionality reduction or (not yet supported) a custom PyTorch network module.
If a dictionary, the keys include:
‘type’: Name of pre-built network architecture. Default: ‘resnet-gru’.
‘norm’: Either ‘weight_norm’ or ‘batch_norm’.
‘dim_in’: Number of input data channels.
‘dim_out’: Dimension of output feature vector.
‘dim_conv_max’: Maximum hidden dimensions of ResNet layers.
‘depth’: ResNet network depth.
‘n_rnn’: Number of GRU layers.
If nn.Module, it’s a custom featurizer network that maps an input sequence of shape [Batch, Channel, Length] to an output feature vector of shape [Batch, Dimension].
Not required in inference-only mode, when state_dict is supplied.
- state_dict : str or dictionary, optional
State dict, or path to a saved state dict, containing three keys: network_state_dict, x_scale, y_scale. Loads a pre-trained model.
- simulator : function, optional
Simulator function to generate data. Takes model parameters as input and returns simulated data.
- priors : list of scipy.stats objects, optional
Prior distributions for Bayesian inference. If not provided, a uniform prior is assumed.
- device : {‘cpu’, ‘cuda’, ‘mps’} str, optional
Device for the neural network. Default is ‘cpu’.
- path : str, optional
Path to save the training set and model checkpoints.
- n_jobs : int, optional
Number of parallel jobs for computation.
- labels : list of str, optional
Names of parameters for inference. Must be in the same order as priors.
- tqdm_notebook : bool, optional
If True, uses the notebook version of tqdm for progress bars.
- network_reinit : bool, optional
If True, re-initializes the network weights every round. Default is False, which often yields better results than True.
- scale_reinit : bool, optional
If True, re-initializes data pre-processing scales every round. Default is True.
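As a rough sketch of how the constructor arguments above might be assembled, the snippet below builds a flow dictionary, a featurizer dictionary, and a toy simulator. The dictionary keys and keyword names come from the parameter list above; every numeric value, the sine-curve simulator, and its n_points argument are illustrative assumptions, not documented defaults. The constructor call is left as a comment so the snippet runs without nbi installed.

```python
import math

# Flow hyperparameters. Only 'n_dims' is documented as required;
# the other values are illustrative, not the library defaults.
flow_config = {
    "n_dims": 2,        # two model parameters: amplitude and period
    "num_blocks": 4,
    "flow_hidden": 64,
    "n_mog": 4,         # rule of thumb: twice the expected number of modes (2 here)
}

# Featurizer hyperparameters for the pre-built 'resnet-gru' network.
# All sizes below are placeholders chosen for this sketch.
featurizer_config = {
    "type": "resnet-gru",
    "norm": "weight_norm",
    "dim_in": 1,        # one input data channel
    "dim_out": 32,
    "dim_conv_max": 128,
    "depth": 5,
    "n_rnn": 2,
}


def simulator(params, n_points=50):
    """Hypothetical simulator: a noiseless sine curve from (amplitude, period)."""
    amplitude, period = params
    return [amplitude * math.sin(2 * math.pi * t / period) for t in range(n_points)]


curve = simulator([2.0, 10.0])

# With nbi and scipy installed, the pieces would be passed as keyword arguments:
# from nbi.engine import NBI
# from scipy.stats import uniform
# priors = [uniform(loc=0.5, scale=2.5), uniform(loc=5.0, scale=15.0)]
# engine = NBI(flow=flow_config, featurizer=featurizer_config,
#              simulator=simulator, priors=priors, labels=["amplitude", "period"])
```

The output format expected from a real simulator (list, array, number of channels) is not specified here, so the plain list returned above is only a stand-in.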
- corner(x, y=None, weights=None, color='k', y_true=None, plot_datapoints=True, plot_density=False, range_=None, truth_color='r', n=5000)
Wrapper function to make corner plot.
Parameters
- x : ndarray
Input data for inference.
- y : ndarray, optional
Parameters to plot. If None, parameters are drawn from the surrogate posterior.
- weights : ndarray, optional
Importance weights for reweighting.
- color : str, optional
Color of the corner plot.
- y_true : ndarray, optional
True parameters for crosshairs.
- plot_datapoints : bool, optional
If True, plot data points as scatter.
- plot_density : bool, optional
If True, plot 2D densities.
- range_ : list, optional
Percentile of data to plot in corner plot.
- truth_color : str, optional
Color of crosshairs.
- n : int, optional
Number of samples to draw from the surrogate posterior, if y is not specified.
Returns
- fit(x=None, y=None, noise=None, log_like=None, n_epochs=10, n_rounds=1, n_sims=-1, x_obs=None, y_true=None, n_reuse=0, batch_size=64, project='test', use_wandb=False, neff_stop=-1, early_stop_train=False, early_stop_patience=-1, f_val=0.1, lr=0.001, min_lr=None, decay_type='SGDR', plot=True, f_accept_min=-1, workers=8)
Fit the Neural Bayesian Inference Engine.
Trains the network based on provided data and parameters.
Parameters
- x : ndarray of paths to individual simulations, optional
First-round training simulations. Only required when simulator and priors are not specified during engine initialization.
- y : ndarray of shape (N, D) where D is parameter space dimension, optional
First-round training parameters. Only required when simulator and priors are not specified during engine initialization.
- noise : ndarray or function, optional
Measurement error and/or data augmentation during training. For fixed iid Gaussian noise, provide an array of Gaussian errorbars. For ANPE, provide a function that takes in noiseless data and parameters (x, y) and outputs the noisified data and parameters (x’, y’); this is the last pre-processing step before the data are fed into the neural network.
- log_like : function, optional
Log-likelihood function that takes in (x, x_path, y) and returns the log likelihood. Required for importance sampling (SNPE), except when noise is iid Gaussian and specified as an errorbar array.
- n_epochs : int, optional
Number of training epochs.
- n_rounds : int, optional
Number of training rounds.
- n_sims : int, optional
Number of simulations.
- x_obs : ndarray, optional
Observed data.
- y_true : ndarray, optional
True target values.
- n_reuse : int, optional
Number of previous rounds whose training data are reused for the current round.
- batch_size : int, optional
Batch size for training and validation.
- project : str, optional
Name of the project for logging.
- use_wandb : bool, optional
If True, enables wandb logging.
- neff_stop : int, optional
Early stopping criterion based on Effective Sample Size (ESS). Terminates inference when the ESS exceeds this value.
- early_stop_train : bool, optional
If True, terminates inference when the surrogate posterior (as measured by the ESS) does not improve for the current round.
- early_stop_patience : int, optional
Number of epochs without improvement to trigger early stopping.
- f_val : float, optional
Fraction of data to use for validation. Default: 0.1
- lr : float, optional
Learning rate. Default: 0.001
- min_lr : float, optional
Minimum learning rate for learning rate decay. Automatically calculated when not specified.
- decay_type : {‘SGDR’} str, optional
Type of learning rate decay. Default is cosine annealing decay (“SGDR”).
- plot : bool, optional
If True, plots results after training.
- f_accept_min : float, optional
Minimum round sampling efficiency (defined as the ratio of the effective sample size to the total sample size) to terminate inference early.
- workers : int, optional
Number of workers for data loading.
Returns
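The function form of the noise argument is described above only by its contract: it takes noiseless data and parameters (x, y) and returns (x’, y’). A minimal sketch of such a function follows. The name add_noise, the uniform range for the noise level, and the rng argument are all hypothetical choices for illustration; a real noise model would reflect the instrument in question.

```python
import random


def add_noise(x, y, sigma_min=0.01, sigma_max=0.1, rng=random):
    """Hypothetical ANPE-style noise function mapping (x, y) to (x', y').

    A noise level is drawn per call and iid Gaussian noise of that
    level is added to every data point. Parameters pass through unchanged.
    """
    sigma = rng.uniform(sigma_min, sigma_max)
    x_noisy = [xi + rng.gauss(0.0, sigma) for xi in x]
    return x_noisy, y


# Example call with a seeded generator for reproducibility.
x_clean = [0.0] * 100
y_params = [1.0, 2.0]
x_noisy, y_out = add_noise(x_clean, y_params, rng=random.Random(0))
```

Drawing a fresh noise level on every call is one way to expose the network to a range of noise conditions; whether the parameters should also be modified (for example, by appending the noise level) depends on the inference problem.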
- get_network()
Returns the network module without DataParallel wrapper, if any.
Returns
- nn.Module
Network module without the DataParallel wrapper.
- get_round_data(n_reuse)
Returns training data for the current round.
Parameters
- n_reuse : int
Number of previous rounds whose training data are reused for the current round.
Returns
- x_round : ndarray
Training data for the current round.
- y_round : ndarray
Training parameters for the current round.
- importance_reweight(x_obs, x, y)
SNPE: Calculate importance weights for the current round.
Parameters
- x_obs : ndarray
Observed data.
- x : ndarray
Simulated data.
- y : ndarray
Simulated parameters.
Returns
- weights : ndarray
Importance weights.
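For orientation, importance sampling with the surrogate posterior as the proposal assigns each sample a weight proportional to prior times likelihood divided by the surrogate density. The sketch below computes normalized weights from log values and the Kish effective sample size, (sum w)^2 / sum w^2. The function names are hypothetical, and the assumption that nbi uses exactly this ESS definition is not confirmed by this page.

```python
import math


def importance_weights(log_prior, log_like, log_q):
    """Normalized weights w proportional to prior * likelihood / proposal."""
    log_w = [lp + ll - lq for lp, ll, lq in zip(log_prior, log_like, log_q)]
    shift = max(log_w)  # subtract the maximum before exponentiating for stability
    w = [math.exp(v - shift) for v in log_w]
    total = sum(w)
    return [v / total for v in w]


def effective_sample_size(weights):
    """Kish effective sample size: (sum w)^2 / sum(w^2)."""
    return sum(weights) ** 2 / sum(w * w for w in weights)


# Four samples where the proposal matches the target exactly: equal weights.
weights = importance_weights([0.0] * 4, [-1.0] * 4, [-1.0] * 4)
ess = effective_sample_size(weights)
efficiency = ess / len(weights)  # ratio of ESS to total sample size
```

When the proposal matches the target, all weights are equal and the ESS equals the number of samples; a single dominant weight drives the ESS toward 1.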
- importance_reweight_like_only(x_obs, x, y)
SNPE: Calculate importance weights for the current round, using only the likelihood.
Parameters
- x_obs : ndarray
Observed data.
- x : ndarray
Simulated data.
- y : ndarray
Simulated parameters.
Returns
- weights : ndarray
Importance weights.
- log_like(x_obs, x, y)
Calculate log likelihood.
Parameters
- x_obs : ndarray
Observed data.
- x : ndarray
Simulated data.
- y : ndarray
Simulated parameters.
Returns
- log_prob : ndarray
Log likelihood.
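For the iid Gaussian noise case, where errorbars are supplied as an array, the log likelihood of observed data given a noiseless simulation has a closed form. The sketch below writes it out. The helper gaussian_log_like and its x_err argument are hypothetical and do not reproduce the signature of the method above; whether nbi includes the normalization terms is an assumption.

```python
import math


def gaussian_log_like(x_obs, x, x_err):
    """Log likelihood of x_obs given simulation x under iid Gaussian errors.

    Each point contributes -0.5 * ((obs - sim) / err)^2 - log(err) - 0.5 * log(2*pi).
    """
    return sum(
        -0.5 * ((o - s) / e) ** 2 - math.log(e) - 0.5 * math.log(2 * math.pi)
        for o, s, e in zip(x_obs, x, x_err)
    )


# A simulation that matches the data scores higher than one that is offset.
x_obs = [1.0, 2.0, 3.0]
x_err = [0.5, 0.5, 0.5]
ll_match = gaussian_log_like(x_obs, [1.0, 2.0, 3.0], x_err)
ll_offset = gaussian_log_like(x_obs, [1.5, 2.5, 3.5], x_err)
```

The normalization terms cancel when weights are normalized across samples that share the same errorbars, so they matter only when the errorbars differ between samples.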
- log_prior(y)
Calculate log prior probability.
Parameters
- y : ndarray
Parameters to calculate prior for.
Returns
- log_prob : ndarray
Log prior probability.
- log_prob(x, y)
Calculate log probability under surrogate posterior.
Parameters
- x : ndarray
Observations.
- y : ndarray
Parameters.
Returns
- log_prob : ndarray
Log probability under surrogate posterior.
- predict(x, x_err=None, y_true=None, log_like=None, n_samples=1000, neff_min=0, n_max=-1, corner=False, corner_reweight=False, seed=None)
Generates the posterior distribution of parameters given input data.
Parameters
- x : ndarray
Input data for inference.
- x_err : ndarray, optional
Measurement error for input data. Required for importance sampling. If not specified, use log_like instead.
- y_true : ndarray, optional
True parameters, if known.
- log_like : function, optional
Log-likelihood function that takes in (x, x_path, y) and returns the log likelihood. Required for importance sampling when x_err is not specified.
- n_samples : int, optional
Number of posterior samples to generate.
- neff_min : int, optional
Minimum effective sample size required. If neff_min > n_samples, additional simulations will be generated until an ESS of neff_min is reached. Default: 0
- n_max : int, optional
Maximum number of simulations to generate to achieve neff_min.
- corner : bool, optional
If True, generates a corner plot of the posterior before reweighting.
- corner_reweight : bool, optional
If True, generates a corner plot of the posterior after reweighting.
- seed : int, optional
Random seed for generating parameters.
Returns
- ys : ndarray
Posterior samples.
- weights : ndarray
Importance weights.
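Because predict returns samples together with importance weights, downstream summaries need to take the weights into account. The sketch below shows two common ways to do so on plain Python lists: a weighted mean per parameter, and resampling in proportion to the weights to obtain equally weighted draws. The helper names and the toy values are hypothetical; the arrays returned by predict would take the place of ys and weights.

```python
import random


def weighted_mean(ys, weights):
    """Weighted mean of each parameter; ys is a list of parameter vectors."""
    total = sum(weights)
    dim = len(ys[0])
    return [sum(w * y[d] for y, w in zip(ys, weights)) / total for d in range(dim)]


def resample(ys, weights, n, seed=None):
    """Draw n equally weighted samples with probability proportional to weights."""
    rng = random.Random(seed)
    return rng.choices(ys, weights=weights, k=n)


# Toy stand-ins for the (ys, weights) pair returned by predict.
ys = [[0.0, 1.0], [2.0, 3.0]]
weights = [1.0, 3.0]

mean = weighted_mean(ys, weights)
draws = resample(ys, weights, n=10, seed=0)
```

Resampling discards information when a few weights dominate, so the effective sample size is worth checking before relying on the resampled draws.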
- prepare_data(x_obs, n_sims)
Generate training data for the current round.
Parameters
- x_obs : ndarray
Observed data for producing simulations.
- n_sims : int
Number of simulations.
Returns
- result()
SNPE: Returns the reweighted posterior from all rounds.
Returns
- all_thetas : ndarray
Parameter values from all rounds.
- all_weights : ndarray
Importance weights from all rounds.
- sample(x, y=None, n=5000, corner=False)
Generates samples from the surrogate posterior.
Parameters
- x : ndarray
Input data for inference.
- y : ndarray, optional
True parameters (for corner plot), if known.
- n : int, optional
Number of samples to generate.
- corner : bool, optional
If True, generates a corner plot of the surrogate posterior samples.
Returns
- samples : ndarray
Samples from the surrogate posterior.
- scale_x(x, back=False)
Scale data to zero mean and unit variance, and vice versa.
Parameters
- x : ndarray
Data to be scaled.
- back : bool, optional
If True, scales data back to original values.
Returns
- scale_y(y, back=False)
Scale parameters to zero mean and unit variance, and vice versa.
Parameters
- y : ndarray
Parameters to be scaled.
- back : bool, optional
If True, scales parameters back to original values.
Returns
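The scaling performed by scale_x and scale_y, and its inverse when back=True, can be pictured as ordinary standardization. The sketch below is a stand-alone illustration on a list of floats. The helper names are hypothetical, and how nbi actually stores and applies x_scale and y_scale is not described on this page.

```python
import statistics


def fit_scale(values):
    """Return the (mean, standard deviation) used for standardization."""
    return statistics.fmean(values), statistics.pstdev(values)


def scale(values, mean, std, back=False):
    """Standardize values, or map standardized values back when back=True."""
    if back:
        return [v * std + mean for v in values]
    return [(v - mean) / std for v in values]


data = [1.0, 2.0, 3.0, 4.0]
mean, std = fit_scale(data)
scaled = scale(data, mean, std)                   # zero mean, unit variance
restored = scale(scaled, mean, std, back=True)    # back to original values
```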
- set_params(state_dict)
Load engine parameters from disk, including network weights and data pre-processing scales.
Parameters
- state_dict : str or state dict
State dict, or path to a saved state dict, containing three keys: network_state_dict, x_scale, y_scale.
Returns