Source code for nbi.engine

# ruff: noqa: E402 F401
import copy
import os

# this needs to go in before importing torch
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

import corner
import matplotlib.pyplot as plt
import multiprocess as mp
import numpy as np
import torch
import wandb
from multiprocess import Pool
from torch import optim
from torch.optim.lr_scheduler import (
    CosineAnnealingWarmRestarts,
    MultiStepLR,
    ReduceLROnPlateau,
)

# this seems to be required for some environments
from torch.utils.data import DataLoader, dataloader
from tqdm import tqdm
from tqdm.notebook import tqdm as tqdmn

dataloader.multiprocessing = mp

from .data import BaseContainer
from .model import DataParallelFlow, get_featurizer, get_flow
from .utils import iid_gaussian, log_like_iidg, parallel_simulate


class NBI:
    """Neural Bayesian Inference Engine.

    NBI is an open-source software introduced to support both amortized and
    sequential Neural Posterior Estimation (NPE) methods, particularly tailored
    for astronomical inference problems, such as those involving light curves
    and spectra.

    The design of NBI addresses critical issues in the adaptation of NPE
    methods in astronomy. It provides built-in "featurizer" networks with
    demonstrated efficacy on sequential data, removing the need for custom
    featurizer networks by users. It also employs a modified algorithm,
    SNPE-IS, which enables asymptotically exact inference by using the
    surrogate posterior under NPE as a proposal distribution for importance
    sampling.

    Parameters
    ----------
    flow : dict or nn.Module
        Dictionary containing hyperparameters for the Masked Autoregressive
        Flow. If dictionary, the keys include:

        - 'n_dims', dimension of model parameter space.
        - 'num_blocks', number of Masked Autoregressive Flow (MAF) blocks.
        - 'flow_hidden', hidden dimension for each MAF block
        - 'perm_seed', random seed for dimension permutation of each MAF block
        - 'n_mog', number of mixture of Gaussians as the base density.
          Recommended > 1 for multi-modal and/or non-Gaussian posterior
          distributions. Rule of thumb: twice the maximum number of posterior
          modes.

        Only 'n_dims' required. See class attribute default_flow_config for
        default values. If nn.Module, it's a custom normalizing flow.
        Not required for inference only mode when state_dict is supplied
    featurizer : dict or nn.Module
        Dictionary of hyperparameters for the pre-built neural network for
        dimensionality reduction or (currently not supported yet) a custom
        PyTorch network module. If dictionary, the keys include:

        - 'type': Name of pre-built network architecture. Default: 'resnet-gru'.
        - 'norm': Either 'weight_norm' or 'batch_norm'.
        - 'dim_in': Number of input data channels.
        - 'dim_out': Dimension of output feature vector.
        - 'dim_conv_max': Maximum hidden dimensions of ResNet layers.
        - 'depth': ResNet network depth.
        - 'n_rnn': Number of GRU layers.

        If nn.Module, it's a custom featurizer network that maps input
        sequence of shape [Batch, Channel, Length] to output feature vector of
        shape [Batch, Dimension]. Not required for inference only mode when
        state_dict is supplied
    state_dict : str or dictionary, optional
        State dict or path to saved state dict containing three keys:
        network_state_dict, x_scale, y_scale. Loads pre-trained model.
    simulator : function, optional
        Simulator function to generate data. Requires input of model
        parameters and returns simulated data.
    priors : list of scipy.stats objects, optional
        Prior distribution for Bayesian inference. If not provided, uniform
        prior is assumed.
    device : {'cpu', 'cuda', 'mps'} str, optional
        Device for neural network, default is 'cpu'.
    path : str, optional
        Path to save training set and model checkpoints.
    n_jobs : int, optional
        Number of parallel jobs for computation.
    labels : list of str, optional
        Names of parameters for inference. Must be in the same order as priors.
    tqdm_notebook : bool, optional
        If True, uses notebook version of tqdm for progress bars.
    network_reinit : bool, optional
        If True, re-initializes the network weights every round. Default is
        False, which often yield better results than True.
    scale_reinit : bool, optional
        If True, re-initializes data pre-processing scales every round.
        Default is True.
    """

    # Class-level defaults; each instance gets its own copy in __init__.
    corner_kwargs = {
        "quantiles": [0.16, 0.5, 0.84],
        "show_titles": True,
        "title_kwargs": {"fontsize": 16},
        "fill_contours": True,
        "levels": 1.0 - np.exp(-0.5 * np.arange(0.5, 2.6, 0.5) ** 2),
    }

    default_flow_config = {
        "flow_hidden": 64,
        "num_blocks": 5,
        "perm_seed": 3,
        "n_mog": 1,
    }

    def __init__(
        self,
        flow=None,
        featurizer=None,
        state_dict=None,
        simulator=None,
        priors=None,
        device="cpu",
        path="test",
        n_jobs=1,
        labels=None,
        tqdm_notebook=False,
        network_reinit=False,
        scale_reinit=True,
    ):
        self.device = device
        self.init_env()

        if state_dict is not None:
            # Accept both a path and an already-loaded dictionary, as documented.
            if isinstance(state_dict, str):
                state_dict = torch.load(
                    state_dict, map_location=self.device, weights_only=False
                )
            flow_config_all = state_dict["flow_config"]
            if featurizer is None:
                featurizer = state_dict["featurizer_config"]
        else:
            if flow is None:
                raise TypeError(
                    "`flow` must be provided when `state_dict` is not supplied"
                )
            flow_config_all = copy.copy(self.default_flow_config)
            flow_config_all.update(flow)

        if isinstance(featurizer, dict):
            featurizer = get_featurizer(featurizer["type"], featurizer)
        flow_config_all["num_cond_inputs"] = featurizer.num_outputs

        # NOTE: at this point `featurizer` is the network object (not the
        # config dict), so the saved "featurizer_config" holds a copy of it.
        self.featurizer_config = copy.copy(featurizer)
        self.flow_config = flow_config_all

        self.network = get_flow(featurizer, **flow_config_all)
        self.network = DataParallelFlow(self.network).to(
            self.device, dtype=torch.float32
        )

        # Build a per-instance dict so labels do not leak into other engines
        # through the shared class attribute.
        self.corner_kwargs = {**self.corner_kwargs, "labels": labels}

        self.network_reinit = network_reinit
        self.scale_reinit = scale_reinit

        self.epoch = 0
        self.prev_clip = 1e8
        self.train_losses = []
        self.val_losses = []
        self.x_mean = None
        self.x_std = None
        self.y_mean = None
        self.y_std = None
        self.norm = []

        self.prior = priors
        self.param_names = labels
        self.simulator = simulator
        self.directory = path
        self.n_jobs = n_jobs

        self.x_obs = None
        self.y_true = None
        self.x = None
        self.y = None
        self.process = None
        self.like = None

        self.round = 0
        self.early_stop_count = 0
        self.x_all = []
        self.y_all = []
        self.weights = []
        self.neff = []

        # Snapshot of the freshly initialised weights, used by network_reinit.
        self.state_dict_0 = copy.deepcopy(self.get_network().state_dict())
        self.best_params = None

        self.prev_state = []
        self.prev_x_mean = []
        self.prev_x_std = []
        self.prev_y_mean = []
        self.prev_y_std = []

        os.makedirs(self.directory, exist_ok=True)

        self.tqdm = tqdmn if tqdm_notebook else tqdm

        if state_dict is not None:
            self.set_params(state_dict)

    def fit(
        self,
        x=None,
        y=None,
        noise=None,
        log_like=None,
        n_epochs=10,
        n_rounds=1,
        n_sims=-1,
        x_obs=None,
        y_true=None,
        n_reuse=0,
        batch_size=64,
        project="test",
        use_wandb=False,
        neff_stop=-1,
        early_stop_train=False,
        early_stop_patience=-1,
        f_val=0.1,
        lr=0.001,
        min_lr=None,
        decay_type="SGDR",
        plot=True,
        f_accept_min=-1,
        workers=8,
    ):
        """
        Fit the Neural Bayesian Inference Engine.

        Trains the network based on provided data and parameters.

        Parameters
        ----------
        x : ndarray of paths to individual simulations, optional
            First round training simulations. Only required when simulation
            and prior not specified during engine initialization.
        y : ndarray of shape (N, D) where D is parameter space dimension, optional
            First round training parameters. Only required when simulation and
            prior not specified during engine initialization.
        noise : ndarray or function, optional
            Measurement error and/or data augmentation during training. Array
            of Gaussian errorbars for fixed iid Gaussian noise. For ANPE,
            provide a function that takes in noiseless data and parameters
            (x, y) and outputs the noisified data and parameters (x', y'),
            which is the last pre-processing step before feeding into the
            neural network.
        log_like : function, optional
            Log-likelihood function that takes in (x, x_path, y) and returns
            the log likelihood. Required for importance sampling (SNPE) but
            not required when noise is iid Gaussian and specified as an
            errorbar array.
        n_epochs : int, optional
            Number of training epochs.
        n_rounds : int, optional
            Number of training rounds.
        n_sims : int, optional
            Number of simulations.
        x_obs : ndarray, optional
            Observed data.
        y_true : ndarray, optional
            True target values.
        n_reuse : int, optional
            Number of previous round training data to be reused for the
            current round.
        batch_size : int, optional
            Batch size for training and validation.
        project : str, optional
            Name of the project for logging.
        use_wandb : bool, optional
            If True, enables wandb logging.
        neff_stop : int, optional
            Early stopping criteria based on Effective Sample Size (ESS).
            Terminate inference when ESS exceeds this value.
        early_stop_train : bool, optional
            If True, terminates inference when the surrogate posterior (as
            measured by the ESS) does not improve for the current round.
        early_stop_patience : int, optional
            Number of epochs without improvement to trigger early stopping.
        f_val : float, optional
            Fraction of data to use for validation. Default: 0.1
        lr : float, optional
            Learning rate. Default: 0.001
        min_lr : float, optional
            Minimum learning rate for learning rate decay. Automatically
            calculated when not specified.
        decay_type : {'SGDR'} str, optional
            Type of learning rate decay. Default is Cosine annealing decay
            ("SGDR").
        plot : bool, optional
            If True, plots results after training.
        f_accept_min : float, optional
            Minimum round sampling efficiency (defined as the ratio from the
            effective sample size to the total sample size) to terminate
            inference early.
        workers : int, optional
            Number of workers for data loading.

        Returns
        -------
        """
        # either simulate n_sims or provide pre computed samples
        assert n_sims > 0 or y is not None

        if isinstance(noise, np.ndarray):
            # for i.i.d. gaussian noise
            self.process = iid_gaussian(noise)
            self.like = log_like_iidg(noise)
        else:
            # for custom noise
            self.process = noise
            self.like = log_like

        self.n_epochs = n_epochs
        self.x_obs = x_obs
        self.y_true = y_true
        self.x = x
        self.y = y

        # this needs revision in another version
        self.wandb = use_wandb
        if self.wandb:
            self._init_wandb(project)

        if min_lr is None:
            # With precomputed data n_sims keeps its -1 default, which would
            # yield a negative learning-rate floor; use the data size instead.
            n_train = n_sims if n_sims > 0 else len(y)
            min_lr = min(lr, lr / (n_train / batch_size * n_epochs) * 10)
            print("Auto learning rate to min_lr =", min_lr)

        # for restarting training
        if len(self.x_all) == self.round:
            # this is not a restart because
            # data for this round has not been generated
            self.prepare_data(x_obs, n_sims)

        for _ in range(n_rounds):
            print(
                f"\n---------------------- Round: {self.round} ----------------------"
            )
            self._init_train(lr)
            self._init_scheduler(min_lr, decay_type=decay_type)

            x_round, y_round = self.get_round_data(n_reuse)
            data_container = BaseContainer(
                x_round, y_round, f_test=0, f_val=f_val, process=self.process
            )
            self._init_loader(data_container, batch_size, workers=workers)

            for epoch in range(n_epochs):
                self.epoch = epoch
                self._train_step()
                self._step_scheduler()
                self._validate_step()

                if self.wandb:
                    wandb.log(
                        {
                            "Train Loss": self.train_losses[-1][-1],
                            "Val Loss": self.val_losses[-1][-1],
                        }
                    )

                path = os.path.join(
                    self.directory, str(self.round), str(epoch) + ".pth"
                )
                self.save_params(path)

                # if the validation loss has not improved by early_stop_patience
                # epochs load the epoch with lowest validation loss, i.e.,
                # epoch_best
                epoch_best = np.argmin(self.val_losses[-1])
                self.best_params = os.path.join(
                    self.directory, str(self.round), str(epoch_best) + ".pth"
                )
                if early_stop_patience > 0 and epoch_best < (
                    epoch - early_stop_patience
                ):
                    print(
                        "early stopping, loading state dict from epoch",
                        epoch_best,
                    )
                    # load from the epoch with lowest validation loss
                    self.set_params(self.best_params)
                    break

            self.round += 1

            # If we're doing Amortized Neural Posterior Estimation (n_rounds=1)
            # then the rest of the code is irrelevant
            if n_rounds == 1:
                return

            # Produce the training set for the next rounds
            # Even if we're at the last round, still need to produce these
            # simulations for importance sample
            self.prepare_data(x_obs, n_sims)
            if plot:
                self.weighted_corner(x_obs, y_true)

            # stop further rounds of SNPE?
            # option 1: enough effective posterior samples
            if np.sum(self.neff) > neff_stop > 0:
                print("Success: Exceed specified stopping sample size!")
                if plot:
                    self._corner_all()
                return

            # option 2: surrogate posterior is good enough (in terms of
            # efficiency)
            f_accept_round = self.neff[-1] / n_sims
            if f_accept_round > f_accept_min > 0:
                print(f"Success: Sampling efficiency is {f_accept_round:.1f}!")
                if plot:
                    self._corner_all()
                return

            # option 3: surrogate posterior from last round was better
            # do: load state from previous round
            if early_stop_train and self.round > 1:
                if 1 < self.neff[-1] < self.neff[-2]:
                    print(
                        "Early stop: Surrogate posterior did not improve for this round"
                    )
                    epoch_best = np.argmin(self.val_losses[-2])
                    path = os.path.join(
                        self.directory, str(self.round - 2), str(epoch_best) + ".pth"
                    )
                    self.best_params = path
                    self.set_params(path)
                    return

        if plot:
            self._corner_all()

    def _corner_all(self):
        """
        SNPE: Corner plot for the reweighted posterior from all rounds.

        Returns
        -------
        """
        print("reweighted posterior from all rounds")
        all_thetas, all_weights = self.result()
        self.corner(self.x_obs, all_thetas, y_true=self.y_true, weights=all_weights)

    def prepare_data(self, x_obs, n_sims):
        """
        Generate training data for the current round.

        Draws parameters, runs (or reuses) simulations, stores the successful
        ones in ``self.x_all`` / ``self.y_all``, computes importance weights
        and writes all of them to ``<path>/<round>_*.npy``.

        Parameters
        ----------
        x_obs : ndarray
            Observed data for producing simulations.
        n_sims : int
            Number of simulations.

        Returns
        -------
        """
        prefix = os.path.join(self.directory, str(self.round))

        ys = self._draw_params(x_obs, n_sims)
        np.save(prefix + "_y_all.npy", ys)

        x_path, good = self.simulate(ys)
        # Convert before masking so list inputs are handled like arrays.
        x_good = np.array(x_path)[good]
        y_good = np.array(ys)[good]
        np.save(prefix + "_x.npy", x_good)
        np.save(prefix + "_y.npy", y_good)
        self.x_all.append(x_good)
        self.y_all.append(y_good)

        weights = self.importance_reweight(x_obs, x_good, y_good)
        self.weights.append(weights)
        np.save(prefix + "_w.npy", weights)

        if self.like is not None and x_obs is not None:
            neff = 1 / (weights**2).sum() - 1
            self.neff.append(neff)
            print(
                "Effective sample size for current/all rounds",
                f"{neff:.1f}/{np.sum(self.neff):.1f}",
            )

    def weighted_corner(self, x_obs, y_true):
        """
        SNPE: Reweighted corner plot for the current round.

        Plotting failures are reported but never abort training.

        Parameters
        ----------
        x_obs : ndarray
            Observed data.
        y_true : ndarray
            True parameters, if known.

        Returns
        -------
        """
        try:
            self.corner(x_obs, self.y_all[-1], y_true=y_true, weights=self.weights[-1])
        except Exception:
            # Best effort only; do not swallow KeyboardInterrupt/SystemExit.
            print("corner plot failed")

    def result(self):
        """
        SNPE: Returns the reweighted posterior from all rounds.

        Each round's normalized weights are scaled by that round's effective
        sample size before the combined weights are renormalized.

        Returns
        -------
        all_thetas : ndarray
            Parameter values from all rounds.
        all_weights: ndarray
            Importance weights from all rounds.
        """
        # Pair the stored rounds directly instead of indexing by self.round,
        # which over-runs the lists after a single-round fit.
        all_weights = np.concatenate(
            [w * neff for w, neff in zip(self.weights, self.neff)]
        )
        all_weights /= all_weights.sum()
        all_thetas = np.concatenate(self.y_all)
        return all_thetas, all_weights

    def get_round_data(self, n_reuse):
        """
        Returns training data for the current round.

        Parameters
        ----------
        n_reuse : int
            Number of previous round training data to be reused for the
            current round. ``-1`` reuses every stored round.

        Returns
        -------
        x_round : ndarray
            Training data for the current round.
        y_round : ndarray
            Training parameters for the current round.
        """
        if n_reuse == -1:
            return np.concatenate(self.x_all), np.concatenate(self.y_all)

        rounds = slice(max(0, self.round - n_reuse), self.round + 1)
        x_round = np.concatenate(self.x_all[rounds])
        y_round = np.concatenate(self.y_all[rounds])
        return x_round, y_round

    @staticmethod
    def _normalize_log_weights(log_weights):
        """Turn log weights into normalized weights; NaN/Inf entries get 0."""
        log_weights = np.asarray(log_weights, dtype=float)
        valid = np.isfinite(log_weights)
        if not valid.any():
            print("All log weights are NaN or Inf — skipping this round!")
            return np.zeros_like(log_weights)

        weights = np.zeros_like(log_weights)
        # Subtract the largest valid value before exponentiating to avoid
        # overflow; the constant cancels in the normalization.
        weights[valid] = np.exp(log_weights[valid] - log_weights[valid].max())
        return weights / weights.sum()

    def importance_reweight(self, x_obs, x, y):
        """
        SNPE: Calculate importance reweights for the current round.

        Weights are proportional to likelihood * prior / proposal, where the
        proposal is the surrogate posterior (or the prior in round 0).

        Parameters
        ----------
        x_obs : ndarray
            Observed data.
        x : ndarray
            Simulated data.
        y : ndarray
            Simulated parameters.

        Returns
        -------
        weights : ndarray
            Importance weights, or None when no likelihood or observation is
            available.
        """
        if self.like is None or x_obs is None:
            return None

        loglike = self.log_like(x_obs, x, y)
        logprior = self.log_prior(y)
        logproposal = self.log_prob(x_obs, y)
        return self._normalize_log_weights(loglike + logprior - logproposal)

    def importance_reweight_like_only(self, x_obs, x, y):
        """
        SNPE: Calculate importance reweights for the current round, using only
        the likelihood.

        Parameters
        ----------
        x_obs : ndarray
            Observed data.
        x : ndarray
            Simulated data.
        y : ndarray
            Simulated parameters.

        Returns
        -------
        weights : ndarray
            Importance weights, or None when no likelihood or observation is
            available.
        """
        if self.like is None or x_obs is None:
            return None
        return self._normalize_log_weights(self.log_like(x_obs, x, y))

    def init_env(self):
        """
        Initialize environment for training.

        Seeds the torch and numpy random generators and falls back to the CPU
        when the requested accelerator is not usable.

        Returns
        -------
        """
        torch.manual_seed(0)
        np.random.seed(0)

        device_name = str(self.device)
        if device_name == "mps":
            mps_backend = getattr(torch.backends, "mps", None)
            try:
                if mps_backend is None or not mps_backend.is_available():
                    raise RuntimeError("MPS backend unavailable")
                torch.mps.manual_seed(0)
            except Exception:
                print(
                    "MPS not supported by current PyTorch installation. Reverting to CPU"
                )
                self.device = "cpu"
        elif "cuda" in device_name:
            if torch.cuda.is_available():
                torch.cuda.manual_seed(0)
            else:
                print(
                    "CUDA not supported by current PyTorch installation. Reverting to CPU"
                )
                # Actually perform the fallback that the message announces.
                self.device = "cpu"

    def get_network(self):
        """
        Returns the network module without DataParallel wrapper, if any.

        Returns
        -------
        nn.Module
            Network module without the DataParallel wrapper.
        """
        if isinstance(self.network, DataParallelFlow):
            return self.network.module
        return self.network

    def set_params(self, state_dict):
        """
        Load engine parameters, including network weights and data
        pre-processing scales.

        Parameters
        ----------
        state_dict : str or state dict
            State dict or path to saved state dict. The keys read here are
            ``model_state_dict``, ``x_scale`` and ``y_scale``.

        Returns
        -------
        """
        if isinstance(state_dict, str):
            state_dict = torch.load(
                state_dict, map_location=self.device, weights_only=False
            )

        # Move x_scale and y_scale to CPU before converting to numpy arrays
        x_scale = state_dict["x_scale"].cpu().numpy()
        y_scale = state_dict["y_scale"].cpu().numpy()
        self.x_mean, self.x_std = x_scale[0], x_scale[1]
        self.y_mean, self.y_std = y_scale[0], y_scale[1]

        self.get_network().load_state_dict(state_dict["model_state_dict"])

    def get_params(self):
        """
        Collect the network weights, pre-processing scales and configuration.

        Nothing is written to disk here; see ``save_params``.

        Returns
        -------
        state_dict : dict
            Dictionary with keys ``model_state_dict``, ``x_scale``,
            ``y_scale``, ``flow_config`` and ``featurizer_config``.
        """
        x_scale = np.array([self.x_mean, self.x_std], dtype=np.float32)
        y_scale = np.array([self.y_mean, self.y_std], dtype=np.float32)

        return {
            # Deep copy so later training steps do not alter the snapshot.
            "model_state_dict": copy.deepcopy(self.get_network().state_dict()),
            "x_scale": torch.from_numpy(x_scale),
            "y_scale": torch.from_numpy(y_scale),
            "flow_config": self.flow_config,
            "featurizer_config": self.featurizer_config,
        }

    def save_params(self, path):
        """
        Saves the network weights and pre-processing scales to disk

        Parameters
        ----------
        path : str
            Destination file passed to ``torch.save``.

        Returns
        -------
        """
        torch.save(self.get_params(), path)

    def scale_y(self, y, back=False):
        """
        Scale parameters to zero mean and unit variance, and vice versa

        Parameters
        ----------
        y : ndarray
            Parameters to be scaled.
        back : bool, optional
            If True, scales parameters back to original values.

        Returns
        -------
        """
        if back:
            return y * self.y_std + self.y_mean

        if len(y.shape) != 2:
            # Prepend axes so that lower-rank input becomes 2-D.
            y = np.expand_dims(y, axis=list(range(2 - len(y.shape))))
        return (y - self.y_mean) / self.y_std

    def scale_x(self, x, back=False):
        """
        Scale data to zero mean and unit variance, and vice versa.

        Parameters
        ----------
        x : ndarray
            Data to be scaled.
        back : bool, optional
            If True, scales data back to original values.

        Returns
        -------
        """
        if back:
            return x * self.x_std + self.x_mean
        return (x - self.x_mean) / self.x_std

    def predict(
        self,
        x,
        x_err=None,
        y_true=None,
        log_like=None,
        n_samples=1000,
        neff_min=0,
        f_accept_min=0.001,
        corner=False,
        corner_reweight=False,
        seed=None,
        n_max=None,
    ):
        """
        Generates the posterior distribution of parameters given input data.

        Parameters
        ----------
        x : ndarray
            Input data for inference
        x_err : ndarray, optional
            Measurement error for input data. Required for importance
            sampling. If not specified, use log_like instead.
        y_true : ndarray, optional
            True parameters, if known.
        log_like : function, optional
            Log-likelihood function that takes in (x, x_path, y) and returns
            the log likelihood. Required for importance sampling when x_err
            not specified.
        n_samples : int, optional
            Number of posterior samples to generate.
        neff_min : float, optional
            Minimum effective sample size. When the first batch falls short,
            one additional batch of simulations is generated.
        f_accept_min : float, optional
            Minimum acceptance rate for importance sampling required. If the
            acceptance rate is less than f_accept_min, no additional
            simulations will be generated due to low efficiency. Required only
            when neff_min > 0. Default: 0.001
        corner : bool, optional
            If True, generates a corner plot of the posterior before
            reweighting.
        corner_reweight : bool, optional
            If True, generates a corner plot of the posterior after
            reweighting.
        seed : int, optional
            Random seed for generating parameters
        n_max : int, optional
            Maximum number of simulations to generate to achieve neff_min.
            Default None applies no cap.

        Returns
        -------
        ys : ndarray
            Posterior samples.
        weights : ndarray
            Importance weights. Only returned when x_err or log_like is given.
        """
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

        self.like = log_like_iidg(x_err) if isinstance(x_err, np.ndarray) else log_like

        # _draw_params samples the prior in round 0; bump the counter so that
        # inference draws from the surrogate posterior instead.
        if self.round == 0:
            self.round = 1

        ys = self._draw_params(x, n_samples)
        if corner:
            print("surrogate posterior")
            self.corner(x, ys, y_true=y_true)

        # Without a likelihood there is nothing to reweight.
        if x_err is None and log_like is None:
            return ys

        x_path, good = self.simulate(ys)
        x_path = x_path[good]
        ys = ys[good]
        weights = self.importance_reweight(x, x_path, ys)

        neff = 1 / (weights**2).sum() - 1
        f_accept = neff / n_samples
        print(f"Effective Sample Size = {neff:.1f}")
        print(f"Sampling efficiency = {f_accept * 100:.1f}%")

        if neff < neff_min:
            if f_accept > f_accept_min:
                n_required = int(n_samples * (1 / f_accept - 1))
                print("Requires N =", n_required, "more simulations to reach n_samples")
                if n_max is not None:
                    n_required = min(n_required, n_max - n_samples)

                if n_required > 0:
                    ys_extra = self._draw_params(x, n_required)
                    x_path, good = self.simulate(ys_extra)
                    x_path = x_path[good]
                    ys_extra = ys_extra[good]
                    weights_extra = self.importance_reweight(x, x_path, ys_extra)
                    neff_extra = 1 / (weights_extra**2).sum() - 1
                    print(
                        "Total effective sample size N =",
                        "%.1f" % (neff + neff_extra),
                    )
                    # NOTE(review): both weight vectors were normalized
                    # separately, so the concatenation sums to 2 — confirm
                    # whether callers expect a renormalized result.
                    ys = np.concatenate([ys, ys_extra])
                    weights = np.concatenate([weights, weights_extra])
            else:
                print(
                    "Sampling efficiency below f_accept_min. No additional simulations will be generated."
                )

        if corner_reweight:
            self.corner(x, ys, y_true=y_true, weights=weights)

        return ys, weights

    def sample(self, x, y=None, n=5000, corner=False):
        """
        Generates samples from the surrogate posterior.

        Parameters
        ----------
        x : ndarray
            Input data for inference
        y : ndarray, optional
            True parameters (for corner plot), if known.
        n : int, optional
            Number of samples to generate.
        corner : bool, optional
            If True, generates a corner plot of the surrogate posterior
            samples.

        Returns
        -------
        samples : ndarray
            Samples from the surrogate posterior.
        """
        self.network.eval()
        x = self.scale_x(x)
        x = torch.from_numpy(x).to(self.device, dtype=torch.float32)

        with torch.no_grad():
            # GPU memory control (make larger?)
            # NOTE(review): every pass below still requests all `n` samples,
            # so large n is not actually chunked — confirm the intended chunk
            # size and the flow's output layout before changing it.
            if n > 20000:
                chunks = [
                    self.get_network()(x, n=n, sample=True).cpu().numpy()
                    for _ in range(n // 100000 + 1)
                ]
                s = np.concatenate(chunks)[:n]
            else:
                s = self.get_network()(x, n=n, sample=True).cpu().numpy()

        samples = self.scale_y(s, back=True)[0]
        if corner:
            self.corner(x, samples, y_true=y)
        return samples

    def simulate(self, thetas):
        """
        Generates simulations for provided parameters, which are saved to
        disk. An array containing paths to the simulations is returned.

        In round 0 with user-supplied simulations (``self.x``), those are
        returned unchanged and nothing is simulated.

        Parameters
        ----------
        thetas : ndarray
            Parameters to generate simulations for.

        Returns
        -------
        x_path : ndarray
            Paths to generated simulations.
        masks : ndarray
            Per-simulation flags as returned by the simulation workers
            (all True for precomputed simulations).
        """
        path_round = os.path.join(self.directory, str(self.round))
        os.makedirs(path_round, exist_ok=True)

        if self.x is not None and self.round == 0:
            print("Use precomputed simulations for round ", self.round)
            return self.x, np.ones(len(self.x), dtype=bool)

        n = len(thetas)
        paths = np.array([os.path.join(path_round, str(i) + ".npy") for i in range(n)])

        # Split the work as evenly as possible; the first `remainder` jobs
        # take one extra simulation each.
        per_job, remainder = divmod(n, self.n_jobs)
        sizes = np.full(self.n_jobs, per_job, dtype=int)
        sizes[:remainder] += 1
        bounds = np.concatenate([[0], np.cumsum(sizes)])

        jobs = [
            [thetas[lo:hi], paths[lo:hi], self.simulator]
            for lo, hi in zip(bounds[:-1], bounds[1:])
        ]
        with Pool(self.n_jobs) as p:
            masks = p.map(parallel_simulate, jobs)
        return paths, np.concatenate(masks)

    def _train_step(self):
        """
        Single training step (one pass over the training loader).

        Returns
        -------
        """
        np.random.seed(self.epoch)
        self.network.train()
        train_loss = []

        with self.tqdm(total=len(self.train_loader.dataset)) as pbar:
            for x, y in self.train_loader:
                x = self.scale_x(x).to(self.device, dtype=torch.float32)
                y = self.scale_y(y).to(self.device, dtype=torch.float32)

                self.optimizer.zero_grad()
                loss = self.network(x, y).mean()
                train_loss.append(loss.item())
                loss.backward()

                if self.clip > 0:
                    # Clip to the threshold derived from earlier gradient
                    # norms and record the new (pre-clip) norm.
                    self.norm.append(
                        torch.nn.utils.clip_grad_norm_(
                            self.network.parameters(), self.prev_clip
                        ).cpu()
                    )
                self.optimizer.step()

                pbar.update(x.shape[0])
                pbar.set_description(
                    "Epoch {:d}: Train, Loglike in nats: {:.6f}".format(
                        self.epoch, -np.mean(train_loss)
                    )
                )

            if self.clip > 0:
                # Next threshold: the `clip`-th percentile of recorded norms.
                self.prev_clip = np.percentile(np.array(self.norm), self.clip)

        self.train_losses[-1].append(np.array(train_loss).mean())

    def _validate_step(self):
        """
        Single validation step (one pass over the validation loader).

        The running progress-bar value is the mean batch loss; the value
        stored in ``self.val_losses`` is the median batch loss.

        Returns
        -------
        """
        np.random.seed(0)
        self.network.eval()
        val_loss = []

        with self.tqdm(total=len(self.valid_loader.dataset)) as pbar:
            for batch_idx, (x, y) in enumerate(self.valid_loader):
                x = self.scale_x(x).to(self.device, dtype=torch.float32)
                y = self.scale_y(y).to(self.device, dtype=torch.float32)

                self.optimizer.zero_grad()
                with torch.no_grad():
                    loss = self.network(x, y).mean()
                val_loss.append(loss.detach().cpu().numpy())

                pbar.update(x.shape[0])
                pbar.set_description(
                    f"- Val, Loglike in nats: {-np.sum(val_loss) / (batch_idx + 1):.6f}"
                )

            val_loss = np.median(val_loss)
            pbar.set_description(f"- Val, Loglike in nats: {-val_loss:.6f}")

        self.val_losses[-1].append(val_loss)

    def _init_wandb(self, project):
        """
        Initialize weights & biases logging.

        Parameters
        ----------
        project : str
            Project name.

        Returns
        -------
        """
        # `args` and `name` are not set by this class; fall back to None so
        # enabling wandb does not fail with AttributeError.
        wandb.init(
            project=project,
            config=getattr(self, "args", None),
            name=getattr(self, "name", None),
        )
        wandb.watch(self.network)

    def _init_train(self, lr, clip=85):
        """
        Initialize training environment.

        Parameters
        ----------
        lr : float
            Learning rate.
        clip : [0, 100] float, optional
            Gradient clipping percentile for each epoch. If 0, no clipping is
            performed.

        Returns
        -------
        """
        self.clip = clip
        if self.network_reinit:
            self.get_network().load_state_dict(self.state_dict_0)
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr)
        torch.manual_seed(0)
        # Start a fresh loss history for this round.
        self.train_losses.append([])
        self.val_losses.append([])

    def _init_scheduler(
        self, min_lr, decay_type="SGDR", patience=5, decay_threshold=0.01
    ):
        """
        Initialize learning rate scheduler.

        Parameters
        ----------
        min_lr : float
            Minimum learning rate.
        decay_type : str, optional
            Learning rate decay type. Options: "SGDR", "plateau", or a
            comma-separated list/string of epochs to decay at.
        patience : int, optional
            plateau: Number of epochs without improvement to trigger learning
            rate decay.
        decay_threshold : float, optional
            plateau: Threshold for measuring the new optimum, to only focus on
            significant changes.

        Returns
        -------
        """
        self.decay_type = decay_type
        if decay_type == "plateau":
            self.scheduler = ReduceLROnPlateau(
                self.optimizer,
                factor=0.1,
                patience=patience,
                threshold_mode="abs",
                cooldown=0,
                verbose=True,
                threshold=decay_threshold,
                min_lr=1e-6,
            )
        elif "SGDR" in decay_type:
            self.scheduler = CosineAnnealingWarmRestarts(
                self.optimizer, T_0=self.n_epochs, T_mult=1, eta_min=min_lr
            )
        else:
            self.scheduler = MultiStepLR(
                self.optimizer, np.array(decay_type.split(","), dtype=int), gamma=0.1
            )

    def _step_scheduler(self):
        """
        Step learning rate scheduler.

        Returns
        -------
        """
        if self.decay_type == "plateau":
            # Called before validation in fit(), hence the training loss.
            self.scheduler.step(self.train_losses[-1][-1])
        else:
            self.scheduler.step()

    def _init_loader(self, data_container, batch_size, workers=4):
        """
        Initialize data loader.

        Parameters
        ----------
        data_container : DataContainer
            Data container object.
        batch_size : int
            Batch size.
        workers : int, optional
            Number of workers for data loader.

        Returns
        -------
        """
        train_container, val_container, _ = data_container.get_splits()
        kwargs = {
            "num_workers": workers,
            "pin_memory": False,
            "drop_last": True,
            # DataLoader rejects persistent workers when num_workers == 0.
            "persistent_workers": workers > 0,
        }
        self.train_loader = DataLoader(
            train_container, batch_size=batch_size, shuffle=True, **kwargs
        )
        self.valid_loader = DataLoader(val_container, batch_size=batch_size, **kwargs)

        if self.scale_reinit or self.round == 0:
            self._init_scales()

    def _draw_params(self, x, n):
        """
        Draw parameters from prior (ANPE or SNPE first round) or surrogate
        posterior.

        Parameters
        ----------
        x : ndarray
            Input data for inference.
        n : int
            Number of parameters to draw.

        Returns
        -------
        params : ndarray
            Drawn parameters. In round 0 with user-supplied ``y`` that array
            is returned as is, regardless of ``n``.
        """
        # first round: precomputed data or draw from prior
        if self.round == 0:
            if self.y is not None:
                return self.y
            return np.array([prior.rvs(n) for prior in self.prior]).T

        # 2+ round: sample from surrogate posterior
        params = self.sample(x, n=n)
        outside = np.isinf(self.log_prior(params))
        if outside.any():
            print("Samples outside prior N =", outside.sum())
            params = params[~outside]
            # Keep drawing full batches until enough samples fall inside the
            # prior support, then trim to the requested number.
            while len(params) < n:
                new_params = self.sample(x, n=n)
                new_params = new_params[~np.isinf(self.log_prior(new_params))]
                params = np.concatenate([params, new_params])
            params = params[:n]
        return params

    def _init_scales(self):
        """
        Calculate data pre-processing scales from the current round training
        data.

        Uses batches from the training loader until more than 5000 samples
        have been collected.

        Returns
        -------
        """
        x_list = []
        y_list = []
        n = 0
        for x, y in self.train_loader:
            x_list.append(x.cpu().numpy())
            y_list.append(y.cpu().numpy())
            n += x_list[-1].shape[0]
            if n > 5000:
                break

        x_all = np.concatenate(x_list, axis=0)
        y_all = np.concatenate(y_list, axis=0)

        # x: statistics over the last axis, averaged over samples.
        self.x_mean = x_all.mean(-1, keepdims=True).mean(0, keepdims=True)
        self.x_std = x_all.std(-1, keepdims=True).mean(0, keepdims=True)
        # y: statistics per parameter dimension.
        self.y_mean = y_all.mean(0, keepdims=True)
        self.y_std = y_all.std(0, keepdims=True)

    def log_prior(self, y):
        """
        Calculate log prior probability.

        Parameters
        ----------
        y : ndarray
            Parameters to calculate prior for.

        Returns
        -------
        log_prob : ndarray
            Log prior probability (all zeros when no prior was provided).
        """
        log_prob = np.zeros(len(y))
        if self.prior is None:
            return log_prob
        for i, prior in enumerate(self.prior):
            log_prob += prior.logpdf(y[:, i])
        return log_prob

    def log_like(self, x_obs, x, y):
        """
        Calculate log likelihood.

        Parameters
        ----------
        x_obs : ndarray
            Observed data.
        x : ndarray
            Simulated data.
        y : ndarray
            Simulated parameters.

        Returns
        -------
        log_prob : ndarray
            Log likelihood, one value per simulation.
        """
        return np.array([self.like(x_obs, x_i, y_i) for x_i, y_i in zip(x, y)])

    def log_prob(self, x, y):
        """
        Calculate log probability under surrogate posterior.

        In round 0 no surrogate has been trained, so the prior is used.

        Parameters
        ----------
        x : ndarray
            Observations.
        y : ndarray
            Parameters.

        Returns
        -------
        log_prob : ndarray
            Log probability under surrogate posterior
        """
        if self.round == 0:
            return self.log_prior(y)

        x = self.scale_x(x)
        y = self.scale_y(y)
        x = torch.from_numpy(x).to(self.device, dtype=torch.float32)
        y = torch.from_numpy(y).to(self.device, dtype=torch.float32)
        with torch.no_grad():
            # it appears that DataParallel doesn't work properly here
            # The network output is negated to obtain the log probability.
            log_prob = self.network.module(x, y).cpu().numpy()[:, 0] * -1
        return log_prob

    def corner(
        self,
        x,
        y=None,
        weights=None,
        color="k",
        y_true=None,
        plot_datapoints=True,
        plot_density=False,
        range_=None,
        truth_color="r",
        n=5000,
    ):
        """
        Wrapper function to make corner plot.

        Parameters
        ----------
        x : ndarray
            Input data for inference.
        y : ndarray, optional
            Parameters to plot. If None, parameters are drawn from the
            surrogate posterior.
        weights : ndarray, optional
            Importance weights for reweighting.
        color : str, optional
            Color of the corner plot
        y_true : ndarray, optional
            True parameters for crosshairs
        plot_datapoints : bool, optional
            If True, plot data points as scatter.
        plot_density : bool, optional
            If True, plot 2D densities.
        range_ : list, optional
            Percentile of data to plot in corner plot.
        truth_color : str, optional
            Color of crosshairs.
        n : int, optional
            Number of samples to draw from the surrogate posterior, if y is
            not specified.

        Returns
        -------
        """
        if y is None:
            y = self.sample(x, n=n)
        corner.corner(
            y,
            truths=y_true,
            color=color,
            plot_datapoints=plot_datapoints,
            range=range_,
            plot_density=plot_density,
            truth_color=truth_color,
            weights=weights,
            **self.corner_kwargs,
        )
        plt.show()