# Source code for TransportMaps.CLI.cli_tmap_sampling

#!/usr/bin/env python

#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps.  If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Author: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#

import logging
import sys
from pathlib import Path
from typing import List
import numpy as np
import click

from TransportMaps import MPIPoolContext, Samplers, Distributions, Maps
from TransportMaps.Distributions import Inference

from . import AvailableOptions as AO
from ._utils import _general_options, _lambda_str_to_list_argument, _load_input, _select_dist, _ask_overwrite, H5, \
    logged

__all__ = [
    'tmap_sampling'
]


@click.group(
    name='tmap-sampling',
    help="""
    Given a file (--input) storing the transport map pushing forward a base distribution
    to a target distribution, provides a number of sampling routines.
    All the generated outputs are stored in a hdf5 file.
    """
)
def tmap_sampling():
    """Click group collecting the ``tmap-sampling`` sub-commands.

    The group callback does no work itself: the sub-commands registered
    through ``@tmap_sampling.command`` (``quadrature``,
    ``importance-sampling``, ``mcmc``) perform the sampling. The CLI help
    text comes from the ``help`` argument of ``click.group`` above, not from
    this docstring.
    """
def _exit_if_output_protected(path_output_h5: Path, overwrite: bool) -> None:
    """Terminate the program unless ``path_output_h5`` may be (re)written.

    The user is asked for confirmation (via ``_ask_overwrite``) only when the
    file already exists and ``overwrite`` is not set. A negative answer logs
    'Terminating' and exits with status 0.
    """
    if not overwrite and path_output_h5.is_file() and not _ask_overwrite():
        logging.info('Terminating')
        sys.exit(0)


def _store_h5(path_output_h5: Path, **datasets) -> None:
    """Write every keyword argument as a chunked dataset of a new hdf5 file.

    The file is opened in ``'w'`` mode, so any previous content is replaced.
    Datasets are created in keyword order.
    """
    with H5(path_output_h5, 'w') as h5_root:
        for name, data in datasets.items():
            h5_root.create_dataset(name, data=data, chunks=True)


@tmap_sampling.command(
    name='quadrature',
    help='compute quadrature points using the sampling distribution --dist'
)
@_general_options
@click.option(
    '--dist', type=click.Choice(AO.AVAIL_DISTRIBUTIONS), required=True,
    help='distribution from which the quadrature is generated'
)
@click.option(
    '--output-h5', 'path_output_h5', required=True, type=Path,
    help='path to the hdf5 file storing big size postprocess data.'
)
@click.option(
    '--qtype', required=True, type=click.IntRange(0, 3),
    help=f'quadrature type for the discretization of the KL-divergence {AO.AVAIL_QTYPE}'
)
@click.option(
    '--qnum', required=True, type=_lambda_str_to_list_argument(int),
    help='quadrature level (must be a comma separated list if qtype requires it)'
)
@logged
def quadrature(
        path_input: Path, path_output_h5: Path, overwrite: bool, nprocs: int,
        dist: str, qtype: int, qnum: List[int], log: int
):
    """Compute a quadrature ``(x, w)`` of the distribution selected by ``dist``.

    The points ``x`` and weights ``w`` are stored as the datasets ``'x'`` and
    ``'w'`` of the hdf5 file ``path_output_h5``.
    """
    stg = _load_input(path_input)
    d = _select_dist(stg, dist)
    _exit_if_output_protected(path_output_h5, overwrite)
    with MPIPoolContext(nprocs) as mpi_pool:
        (x, w) = d.quadrature(qtype, qnum, mpi_pool=mpi_pool)
    _store_h5(path_output_h5, x=x, w=w)
    logging.info('Quadrature generated and stored')


@tmap_sampling.command(
    name='importance-sampling',
    help='compute quadrature points of the target distribution '
         'using importance sampling from approximate base distribution '
         'using the corresponding exact base distribution as bias.'
)
@_general_options
@click.option(
    '--output-h5', 'path_output_h5', required=True, type=Path,
    help='path to the hdf5 file storing big size postprocess data.'
)
@click.option(
    '--n-samples', type=click.IntRange(min=1), required=True,
    help='Number of samples to generate'
)
@logged
def importance_sampling(
        path_input: Path, n_samples: int, path_output_h5: Path,
        overwrite: bool, nprocs: int, log: int
):
    """Importance-sample the approximate base distribution.

    The exact base distribution is used as the biasing distribution. The
    samples are then pushed forward to the target space through
    ``stg.approx_target_distribution``. The mapped points ``'x'`` and the
    importance weights ``'w'`` are stored in ``path_output_h5``.
    """
    stg = _load_input(path_input)
    _exit_if_output_protected(path_output_h5, overwrite)
    with MPIPoolContext(nprocs) as mpi_pool:
        sampler = Samplers.ImportanceSampler(
            stg.approx_base_distribution, stg.base_distribution
        )
        (x, w) = sampler.rvs(n_samples, mpi_pool_tuple=(mpi_pool, None))
        x = stg.approx_target_distribution.map_samples_base_to_target(
            x, mpi_pool=mpi_pool
        )
    _store_h5(path_output_h5, x=x, w=w)
    logging.info('Importance sampling quadrature generated and stored')


def _build_mcmc_sampler(stg, method: str, mh_eps: float, mh_pcn: bool):
    """Return the MCMC sampler targeting ``stg.approx_base_distribution``.

    ``method`` selects the algorithm:

    - ``'mh'``: Metropolis-Hastings. The proposal is either a random walk with
      covariance ``mh_eps * I``, or (with ``mh_pcn``) a preconditioned
      Crank-Nicolson proposal with covariance ``mh_eps**2 * prior.covariance``.
    - ``'mhind'``: Metropolis-Hastings with independent proposals drawn from
      ``stg.base_distribution``.
    - ``'hmc'``: Hamiltonian Monte Carlo.

    Raises:
      ValueError: ``mh_pcn`` is requested but the target distribution is not
        a Bayesian posterior with a (standard) normal prior.
      NotImplementedError: ``method`` is not one of the values above.
    """
    if method == 'mh':
        dim = stg.base_distribution.dim
        if mh_pcn:
            target = stg.target_distribution
            has_normal_prior = (
                isinstance(target, Inference.BayesPosteriorDistribution)
                and isinstance(
                    target.prior,
                    (Distributions.NormalDistribution,
                     Distributions.StandardNormalDistribution)
                )
            )
            if not has_normal_prior:
                raise ValueError(
                    "In order to use preconditioned Crank-Nicolson "
                    "the target distribution must be a Bayesian posterior with "
                    "normal prior"
                )
            # Same conditional-normal class as the random-walk branch below.
            prop_distribution = Distributions.MeanConditionallyNormalDistribution(
                Maps.PreconditionedCrankNicolsonMap(dim, mh_eps),
                mh_eps ** 2 * target.prior.covariance
            )
        else:
            prop_distribution = Distributions.MeanConditionallyNormalDistribution(
                Maps.IdentityTransportMap(dim),
                mh_eps * np.eye(dim)
            )
        return Samplers.MetropolisHastingsSampler(
            stg.approx_base_distribution, prop_distribution)
    if method == 'mhind':
        return Samplers.MetropolisHastingsIndependentProposalsSampler(
            stg.approx_base_distribution, stg.base_distribution)
    if method == 'hmc':
        if not isinstance(stg.base_distribution,
                          Distributions.StandardNormalDistribution):
            logging.warning(
                "The HMC algorithm uses a Standard Normal distribution "
                "as default proposal"
            )
        return Samplers.HamiltonianMonteCarloSampler(stg.approx_base_distribution)
    raise NotImplementedError(f'Method {method} not implemented')


@tmap_sampling.command(
    name='mcmc',
    help='compute quadrature points of the target distribution '
         'using Markov Chain Monte Carlo from approximate base distribution '
         'using the corresponding exact base distribution as bias.'
)
@_general_options
@click.option(
    '--output-h5', 'path_output_h5', required=True, type=Path,
    help='path to the hdf5 file storing big size postprocess data.'
)
@click.option(
    '--method', type=click.Choice(list(AO.AVAIL_MCMC_ALGORITHMS.keys())),
    default='mh',
    help=f'MCMC method to be used for sampling. Options are {AO.AVAIL_MCMC_ALGORITHMS}'
)
@click.option(
    '--n-samples', type=click.IntRange(min=1), required=True,
    help='Number of samples to generate'
)
@click.option(
    '--burnin', type=click.IntRange(min=0), default=0,
    help='Number of samples to be used as burn-in'
)
@click.option(
    '--skip', type=click.IntRange(min=0), default=0,
    help='number of samples to be skipped (>=0) between stored states '
         '(a chain of BURNIN + NSAMP*(SKIP+1) states is generated and subsampled)'
)
@click.option(
    '--mh-eps', type=float, default=0.1,
    help='variance of the Standard Normal proposal in Metropolis-Hasting'
)
@click.option(
    '--mh-pcn', type=bool, is_flag=True,
    help='Use the preconditioned Crank-Nicolson proposal '
         'N(sqrt(1-eps**2)u,eps**2 * Sigma) '
         'where Sigma is the covariance of the prior (has to be Gaussian).'
)
@click.option(
    '--hmc-eps', type=float, default=0.2,
    help='epsilon value in Hamiltonian Monte Carlo'
)
@click.option(
    '--hmc-nsteps', type=int, default=1,
    help='number of steps per sample in Hamiltonian Monte Carlo'
)
@logged
def mcmc(
        path_input: Path, path_output_h5: Path, method: str, n_samples: int,
        burnin: int, skip: int, mh_eps: float, mh_pcn: bool,
        hmc_eps: float, hmc_nsteps: int, overwrite: bool, nprocs: int, log: int
):
    """Run an MCMC chain on the approximate base distribution.

    After discarding ``burnin`` states and keeping one state every
    ``skip + 1``, exactly ``n_samples`` chain states ``'s'`` remain. They are
    stored together with their images ``'x'`` under
    ``stg.approx_target_distribution``.
    """
    stg = _load_input(path_input)
    _exit_if_output_protected(path_output_h5, overwrite)
    sampler = _build_mcmc_sampler(stg, method, mh_eps, mh_pcn)
    # The burn-in must be generated on top of the retained states; otherwise
    # s[burnin::skip + 1] would yield fewer than n_samples states.
    chain_length = burnin + n_samples * (skip + 1)
    with MPIPoolContext(nprocs) as mpi_pool:
        if method == 'hmc':
            (s, _) = sampler.rvs(
                chain_length, x0=None, epsilon=hmc_eps, n_steps=hmc_nsteps)
        else:  # 'mh' / 'mhind'
            (s, _) = sampler.rvs(
                chain_length, x0=None, mpi_pool_tuple=(mpi_pool, None))
        s = s[burnin::(skip + 1), :]  # Skip burnin and subsample
        x = stg.approx_target_distribution.map_samples_base_to_target(
            s, mpi_pool=mpi_pool)
    _store_h5(path_output_h5, s=s, x=x)
    logging.info('Markov chain generated and stored')


if __name__ == '__main__':
    tmap_sampling()