#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps. If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Author: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#
import sys
import os
import os.path
import fasteners
import pickle
import numpy as np
from TransportMaps import Maps
from . import AvailableOptions as AO
from .ScriptBase import Script
from TransportMaps.External import H5PY_SUPPORT, PLOT_SUPPORT
import TransportMaps.Diagnostics as DIAG
import TransportMaps.Samplers as SAMP
import TransportMaps.Distributions as DIST
import TransportMaps.Distributions.Inference as DISTINF
if H5PY_SUPPORT:
import h5py
if PLOT_SUPPORT:
import matplotlib.pyplot as plt
__all__ = ['Postprocess']
# NOTE(review): the "[docs]" tokens below (and on every method) are artifacts
# of an HTML/Sphinx extraction of this module, standing where the original
# indentation was; they are preserved untouched here.
[docs]class Postprocess(Script):
    # Command-line driver providing diagnostics and sampling routines for a
    # transport map computed by tmap-tm (see the DESCRIPTION string below).
    # Short usage banner plus the long option synopsis printed by usage().
[docs] cmd_usage_str = "Usage: tmap-postprocess "
[docs] opts_usage_str = """[-h -v -I]
  --input=INPUT --output=OUTPUT
  [--store-fig-dir=STR --store-fig-fmats=LIST
  --extra-tit=STR --no-plotting
  --aligned-conditionals=STR
  --alc-n-points-x-ax=INT
  --alc-n-tri-plots=INT
  --alc-anchor=LIST
  --alc-range=LIST
  --random-conditionals=STR
  --rndc-n-points-x-ax=INT
  --rndc-n-plots-x-ax=INT
  --rndc-anchor=LIST
  --rndc-range=LIST
  --var-diag=STR
  --var-diag-qtype=INT
  --var-diag-qnum=INT
  --aligned-marginals=STR
  --alm-n-points=INT
  --alm-n-tri-plots=INT
  --alm-scatter
  --quadrature=STR
  --quadrature-qtype=INT
  --quadrature-qnum=INT
  --quadrature-bsize=INT
  --importance-samples=INT
  --mcmc=STR
  --mcmc-samples=INT
  --mcmc-burnin=INT
  --mcmc-skip=INT
  --mcmc-ess
  --mcmc-ess-skip=INT
  --mcmc-ess-q=FLOAT
  --mcmc-ess-corr-plot
  --mcmc-ess-corr-plot-lag=LAG
  --mcmc-ess-hist-plot
  --mcmc-mh-eps=FLOAT
  --mcmc-mh-pcn
  --mcmc-hmc-eps=FLOAT
  --mcmc-hmc-nsteps=INT
  --chunk-size=INT --log=LOG --nprocs=INT]
"""
    # Full usage line: banner + option synopsis.
[docs] usage_str = cmd_usage_str + opts_usage_str
    # Pre-rendered help fragments built from AvailableOptions: the list of
    # selectable distributions, MCMC algorithms and logging levels.
[docs] docs_distributions_str = \
    AO.print_avail_options(AO.AVAIL_DISTRIBUTIONS,
    ' ', False)
[docs] docs_mcmc_str = \
    ' --mcmc=ALG algorithm to be used to generate Markov Chain.\n' + \
    AO.print_avail_options(AO.AVAIL_MCMC_ALGORITHMS,
    ' ')
[docs] docs_log_str = \
    ' --log=LOG log level (default=30). Uses package logging.\n' + \
    AO.print_avail_options(AO.AVAIL_LOGGING,' ')
    # NOTE(review): "diagnositic" below is a typo in a runtime help string;
    # left untouched here (fixing it would change program output).
[docs] docs_descr_str = """DESCRIPTION
Given a file (--input) storing the transport map pushing forward a base distribution
to a target distribution, provides a number of diagnositic routines.
All files involved are stored and loaded using the python package pickle and
an extra file OUTPUT.hdf5 is created to store big datasets in the hdf5 format.
In the following default values are shown in brackets."""
[docs] docs_options_str = """
OPTIONS - input/output:
  --input=INPUT path to the file containing the target distribution,
                the base distribution and the transport map pushing forward
                the base to the target.
  --output=OUTPUT path to the file storing all postprocess data.
                The additional file OUTPUT.hdf5 will be used to store
                the more memory consuming data.
  --store-fig-dir=DIR path to the directory where to store the figures.
  --store-fig-fmats=FMATS figure formats - see matplotlib for supported formats (svg)
  --extra-tit=TITLE additional title for the figures' file names.
  --no-plotting do not plot figures, but only store their data.
                (requires --output or --store-fig-dir)
OPTIONS - Diagnostics:
  --aligned-conditionals=DIST plot aligned slices of the selected DIST:
""" + docs_distributions_str + """ Optional arguments:
  --alc-n-points-x-ax=N number of discretization points per axis (40)
  --alc-n-tri-plots=N number of subplots (0)
  --alc-anchor=LIST list of floats "f1,f2,f3..." for the anchor point (0)
  --alc-range=LIST list of two floats "f1,f2" for the range (-5,5)
  --random-conditionals=DIST plot randomly chosen slices of the selected DIST:
""" + docs_distributions_str + """ Optional arguments:
  --rndc-n-points-x-ax=N number of discretization points per axis (40)
  --rndc-n-plots-x-ax=N number of subplots (0)
  --rndc-anchor=LIST list of floats "f1,f2,f3..." for the anchor point (0)
  --rndc-range=LIST list of two floats "f1,f2" for the range (-5,5)
  --var-diag=DIST compute variance diagostic using the sampling DIST:
""" + docs_distributions_str + """ Optional arguments:
  --var-diag-qtype=QTYPE quadrature type to be used (0)
  --var-diag-qnum=QNUM level of the quadrature (1000)
OPTIONS - Sampling:
  --aligned-marginals=DIST plot aligned marginals of the selected DIST:
""" + docs_distributions_str + """ Optional arguments:
  --alm-n-points=N number of samples to be used for the kernel density estimation
  --alm-n-tri-plots=N number of subplots (0)
  --alm-scatter produce scatter plots instead of contours
  --quadrature=DIST generate quadrature for the selected DIST:
""" + docs_distributions_str + """ Optional arguments:
  --quadrature-qtype=QTYPE generate quadrature of type QTYPE (0)
  --quadrature-qnum=QNUM level of the quadrature (int or list)
  --quadrature-bsize=SIZE quadratures are created in batches (to meet memory limitations)
  --importance-samples=NSAMP number of importance samples and weights for the approximation
                of estimators with respect to the target distribution
""" + docs_mcmc_str + \
""" --mcmc-samples=NSAMP length of the chain with invariant distribution the
                target distribution using Metropolis-Hastings with independent
                proposals
  --mcmc-burnin=BURNIN number of samples to be considered as burn-in
  --mcmc-skip=SKIP number of sample to be skipped (>=0) in storage (a NSAMP*SKIP chain is subsampled)
  --mcmc-ess=TYPE whether to compute the ess. Options are:
                acor: autocorrelation function and variance bars
                uw: Ulli Wolff effective sample size
  --mcmc-ess-skip=SKIP number of samples to be skipped in the effective sample size estimation
  --mcmc-ess-q=QUANTILE quantile used for the estimation of the sample size (0.99).
                This is estimated over the worst decaying autocorrelation rate.
  --mcmc-ess-corr-plot whether to plot the auto correlations
  --mcmc-ess-corr-plot-lag=LAG maximum lag to be plotted (100)
  --mcmc-ess-hist-plot whether to plot a summary of the sample size by dimension
  --mcmc-mh-eps=FLOAT variance of the Standard Normal proposal in Metropolis-Hasting
  --mcmc-mh-pcn Use the preconditioned Crank-Nicolson proposal
                N(sqrt(1-\eps**2)u,\eps**2 * \Sigma)
                where \Sigma is the covariance of the prior (has to be Gaussian).
  --mcmc-hmc-eps=FLOAT epsilon value in Hamiltonian Monte Carlo
  --mcmc-hmc-nsteps=INT number of steps per sample
OPTIONS - Computation:
  --chunk-size=SIZE chunk size to be used in the storage of data
""" + docs_log_str + """ --nprocs=NPROCS number of processors to be used (default=1)
OPTIONS - other:
  -v verbose output (not affecting --log)
  -I enter interactive mode after finishing
  -h print this help
"""
    # Full help text printed by description().
[docs] docs_str = docs_descr_str + docs_options_str
[docs] def usage(self):
print(Postprocess.usage_str)
[docs] def description(self):
print(Postprocess.docs_str)
[docs] def store_figure(self, fig, fname):
for fmat in self.STORE_FIG_FMATS:
fig.savefig(fname+'.'+fmat, format=fmat, bbox_inches='tight');
[docs] def store_postproc_data(self, fname):
self.safe_store(self.postproc_data, fname)
    @property
[docs] def long_options(self):
        # Long (getopt-style) command-line option names accepted by this
        # script, appended to the options handled by the base Script class.
        # A trailing '=' marks options that take an argument.
        return super(Postprocess, self).long_options + \
            [
                # I/O
                "store-fig-dir=", "store-fig-fmats=",
                "extra-tit=", "no-plotting",
                # Aligned conditionals
                "aligned-conditionals=",
                "alc-n-points-x-ax=", "alc-n-tri-plots=", "alc-anchor=", "alc-range=",
                # Random conditionals
                "random-conditionals=",
                "rndc-n-points-x-ax=", "rndc-anchor=", "rndc-range=", "rndc-n-plots-x-ax=",
                # Aligned marginals
                "aligned-marginals=",
                "alm-n-points=", "alm-n-tri-plots=", "alm-scatter",
                # Variance diagnostic
                "var-diag=", "var-diag-qtype=", "var-diag-qnum=",
                # Quadrature
                "quadrature=",
                "quadrature-qtype=", "quadrature-qnum=",
                "quadrature-bsize=",
                # Importance sampling
                "importance-samples=",
                # Markov Chain Monte Carlo
                "mcmc=",
                "mcmc-samples=", "mcmc-burnin=", "mcmc-skip=",
                "mcmc-mh-eps=", "mcmc-mh-pcn",
                "mcmc-hmc-eps=", "mcmc-hmc-nsteps=",
                "mcmc-ess=", "mcmc-ess-skip=",
                # NOTE(review): "mcmc-ess-xcorr" has no handler in _load_opts
                # and is not documented in the usage strings -- confirm
                # whether it is dead and can be removed.
                "mcmc-ess-q=", "mcmc-ess-xcorr",
                "mcmc-ess-corr-plot", "mcmc-ess-corr-plot-lag=",
                "mcmc-ess-hist-plot",
                # Computation
                "chunk-size="
            ]
[docs] def _load_opts(self, opts):
        """Parse the (opt, arg) pairs produced by getopt into attributes.

        Repeatable options (--aligned-conditionals, --random-conditionals,
        --aligned-marginals, --var-diag, --quadrature) append one entry to
        their parallel option lists; their sub-options (e.g. --alc-*)
        overwrite the entry of the MOST RECENT parent occurrence, so the
        order of arguments on the command line is significant.
        """
        super(Postprocess, self)._load_opts(opts)
        for opt, arg in opts:
            # I/O
            if opt == "--store-fig-dir":
                self.STORE_FIG_DIR = arg
            elif opt == "--store-fig-fmats":
                self.STORE_FIG_FMATS = arg.split(',')
            elif opt == "--extra-tit":
                self.EXTRA_TIT = "-" + arg
            elif opt == "--no-plotting":
                self.PLOTTING = False
            # Aligned conditionals
            elif opt == "--aligned-conditionals":
                self.ALIGNED_CONDITIONALS.append(arg)
                self.ALC_N_POINTS_X_AX.append( self.DFT_N_POINTS_X_AX )
                self.ALC_N_TRI_PLOTS.append( self.DFT_N_TRI_PLOTS )
                self.ALC_ANCHOR.append( self.DFT_ANCHOR )
                self.ALC_RANGE.append( self.DFT_RANGE )
            # Options
            elif opt == "--alc-n-points-x-ax":
                self.ALC_N_POINTS_X_AX[len(self.ALIGNED_CONDITIONALS)-1] = int(arg)
            elif opt == "--alc-n-tri-plots":
                # N is expanded to the list of the first N dimensions.
                self.ALC_N_TRI_PLOTS[len(self.ALIGNED_CONDITIONALS)-1] = list(range(int(arg)))
            elif opt == "--alc-anchor":
                self.ALC_ANCHOR[len(self.ALIGNED_CONDITIONALS)-1] = [float(s) for s in arg.split(',')]
            elif opt == "--alc-range":
                self.ALC_RANGE[len(self.ALIGNED_CONDITIONALS)-1] = [float(s) for s in arg.split(',')]
            # Random conditionals
            elif opt == "--random-conditionals":
                self.RANDOM_CONDITIONALS.append(arg)
                self.RNDC_N_POINTS_X_AX.append( self.DFT_N_POINTS_X_AX )
                self.RNDC_N_PLOTS_X_AX.append( self.DFT_N_PLOTS_X_AX )
                self.RNDC_ANCHOR.append( self.DFT_ANCHOR )
                self.RNDC_RANGE.append( self.DFT_RANGE )
            # Options
            elif opt == "--rndc-n-points-x-ax":
                self.RNDC_N_POINTS_X_AX[len(self.RANDOM_CONDITIONALS)-1] = int(arg)
            elif opt == "--rndc-n-plots-x-ax":
                self.RNDC_N_PLOTS_X_AX[len(self.RANDOM_CONDITIONALS)-1] = int(arg)
            elif opt == "--rndc-anchor":
                self.RNDC_ANCHOR[len(self.RANDOM_CONDITIONALS)-1] = [float(s) for s in arg.split(',')]
            elif opt == "--rndc-range":
                self.RNDC_RANGE[len(self.RANDOM_CONDITIONALS)-1] = [float(s) for s in arg.split(',')]
            # Aligned marginals
            elif opt == "--aligned-marginals":
                self.ALIGNED_MARGINALS.append(arg)
                self.ALM_N_POINTS.append(self.DFT_N_POINTS)
                self.ALM_N_TRI_PLOTS.append(self.DFT_N_TRI_PLOTS)
            # Options
            elif opt == "--alm-scatter":
                self.ALM_SCATTER = True
            elif opt == "--alm-n-points":
                self.ALM_N_POINTS[len(self.ALIGNED_MARGINALS)-1] = int(arg)
            elif opt == "--alm-n-tri-plots":
                self.ALM_N_TRI_PLOTS[len(self.ALIGNED_MARGINALS)-1] = list(range(int(arg)))
            # Variance diagnostic
            elif opt == "--var-diag":
                self.VAR_DIAG.append(arg)
                self.VD_QTYPE.append(self.DFT_VD_QTYPE)
                self.VD_QNUM.append(self.DFT_VD_QNUM)
            elif opt == "--var-diag-qtype":
                self.VD_QTYPE[len(self.VAR_DIAG)-1] = int(arg)
            elif opt == "--var-diag-qnum":
                # Parsed as a list of ints (one per quadrature dimension).
                self.VD_QNUM[len(self.VAR_DIAG)-1] = [int(q) for q in arg.split(',')]
            # Quadrature
            elif opt == "--quadrature":
                self.QUADRATURE.append( arg )
                self.QUAD_QTYPE.append( None )
                self.QUAD_QNUM.append( None )
                self.QUAD_BSIZE.append( float("inf") )
            elif opt == "--quadrature-qtype":
                self.QUAD_QTYPE[len(self.QUADRATURE)-1] = int(arg)
            elif opt == "--quadrature-qnum":
                self.QUAD_QNUM[len(self.QUADRATURE)-1] = [int(q) for q in arg.split(',')]
            elif opt == "--quadrature-bsize":
                self.QUAD_BSIZE[len(self.QUADRATURE)-1] = int(arg)
            # Importance sampling
            elif opt == "--importance-samples":
                self.IMP_SAMPLES = int(arg)
            # Metropolis Hastings
            elif opt == "--mcmc":
                self.MCMC_ALG = arg
            elif opt == "--mcmc-samples":
                self.MCMC_SAMPLES = int(arg)
            elif opt == "--mcmc-burnin":
                self.MCMC_BURNIN = int(arg)
            elif opt == "--mcmc-skip":
                # Negative skip values are clamped to 0 (no thinning).
                self.MCMC_SKIP = max(int(arg), 0)
            elif opt == "--mcmc-mh-eps":
                self.MCMC_MH_EPS = float(arg)
            elif opt == '--mcmc-mh-pcn':
                self.MCMC_MH_PCN = True
            elif opt == "--mcmc-hmc-eps":
                self.MCMC_HMC_EPS = float(arg)
            elif opt == "--mcmc-hmc-nsteps":
                self.MCMC_HMC_NSTEPS = int(arg)
            elif opt == "--mcmc-ess":
                self.MCMC_ESS = arg
            elif opt == "--mcmc-ess-skip":
                self.MCMC_ESS_SKIP = int(arg)
                if self.MCMC_ESS_SKIP < 1:
                    raise ValueError("SKIP must be > 0 in --mcmc-ess-skip=SKIP")
            elif opt == "--mcmc-ess-q":
                self.MCMC_ESS_Q = float(arg)
            elif opt == "--mcmc-ess-corr-plot":
                self.MCMC_ESS_CORR_PLOT = True
            elif opt == "--mcmc-ess-corr-plot-lag":
                self.MCMC_ESS_CORR_PLOT_LAG = int(arg)
            elif opt == "--mcmc-ess-hist-plot":
                self.MCMC_ESS_HIST_PLOT = True
            # Computation
            elif opt == "--chunk-size":
                self.CHUNK_SIZE = int(arg)
[docs] def _init_self_variables(self):
super(Postprocess, self)._init_self_variables()
# I/O
self.STORE_FIG_DIR = None
self.STORE_FIG_FMATS = ['svg','pdf']
self.EXTRA_TIT = ''
self.PLOTTING = True
# Aligned conditionals
self.ALIGNED_CONDITIONALS = []
self.ALC_N_POINTS_X_AX = []
self.ALC_N_TRI_PLOTS = []
self.ALC_ANCHOR = []
self.ALC_RANGE = []
# Random conditionals
self.RANDOM_CONDITIONALS = []
self.RNDC_N_POINTS_X_AX = []
self.RNDC_N_PLOTS_X_AX = []
self.RNDC_ANCHOR = []
self.RNDC_RANGE = []
# Aligned marginals
self.ALIGNED_MARGINALS = []
self.ALM_SCATTER = False
self.ALM_N_POINTS = []
self.ALM_N_TRI_PLOTS = []
# Default plotting options
self.DFT_N_POINTS = 1000
self.DFT_N_POINTS_X_AX = 40
self.DFT_N_TRI_PLOTS = 0
self.DFT_ANCHOR = None
self.DFT_RANGE = [-5.,5.]
self.DFT_N_PLOTS_X_AX = 6
# Variance diagnostic
self.VAR_DIAG = []
self.VD_QTYPE = []
self.VD_QNUM = []
# Defaults for variance diagnostic
self.DFT_VD_QTYPE = 0
self.DFT_VD_QNUM = 1000
# Samples
self.QUADRATURE = []
self.QUAD_QTYPE = []
self.QUAD_QNUM = []
self.QUAD_BSIZE = []
# Importance samples
self.IMP_SAMPLES = None
# MCMC samples
self.MCMC_ALG = None
self.MCMC_SAMPLES = None
self.MCMC_BURNIN = 0
self.MCMC_SKIP = 0
self.MCMC_ESS = None
self.MCMC_ESS_SKIP = 1
self.MCMC_ESS_Q = 0.99
self.MCMC_ESS_CORR_PLOT = False
self.MCMC_ESS_CORR_PLOT_LAG = 100
self.MCMC_ESS_HIST_PLOT = False
self.MCMC_MH_EPS = 0.1
self.MCMC_MH_PCN = False
self.MCMC_HMC_EPS = 0.2
self.MCMC_HMC_NSTEPS = 1
# hdf5 options
self.CHUNK_SIZE = 10000
[docs] def _check_required_args(self):
super(Postprocess, self)._check_required_args()
if not self.PLOTTING and self.STORE_FIG_DIR is None and self.OUTPUT is None:
self.usage()
self.tstamp_print(
"ERROR: Neither --output nor --store-fig-dir were " + \
"specified, while --no-plotting is active. " + \
"This would result on no data shown or stored.")
sys.exit(3)
[docs] def load(self):
self.h5_file = None
# Load data
with open(self.INPUT, 'rb') as in_stream:
self.stg = pickle.load(in_stream)
# Restore data
self.base_distribution = self.stg.base_distribution
self.target_distribution = self.stg.target_distribution
self.tmap = self.stg.tmap
self.approx_base_distribution = self.stg.approx_base_distribution
self.approx_target_distribution = self.stg.approx_target_distribution
self.dim = self.base_distribution.dim
# Load output (pickle file) if any
if not os.path.exists(self.OUTPUT):
self.postproc_data = {}
with open(self.OUTPUT, 'wb') as out_stream:
pickle.dump(self.postproc_data, out_stream)
with open(self.OUTPUT, 'rb') as in_stream:
self.postproc_data = pickle.load(in_stream)
self.postproc_root = self.postproc_data
# Load output (hdf5 file) if any
self.h5_lock = fasteners.InterProcessLock(self.OUTPUT + '.hdf5.lock')
if not self.h5_lock.acquire(blocking=False):
self.tstamp_print(
"ERROR: the hdf5 file is locked. " +
"Lock: " + self.h5_lock
)
sys.exit(4)
self.h5_file = h5py.File(self.OUTPUT + '.hdf5', 'a')
self.h5_root = self.h5_file
[docs] def close(self):
        # Close the hdf5 data file opened by load() before releasing the
        # inter-process lock, so no other process sees a half-written file.
        self.h5_file.close()
        self.h5_lock.release()
[docs] def aligned_conditionals(self, mpi_pool):
for aligned, n_tri_plots, n_points_x_ax, anchor, rng in \
zip(self.ALIGNED_CONDITIONALS, self.ALC_N_TRI_PLOTS,
self.ALC_N_POINTS_X_AX, self.ALC_ANCHOR, self.ALC_RANGE):
self.filter_tstamp_print("[Start] Aligned conditionals " + aligned)
if aligned == 'exact-target':
d = self.target_distribution
elif aligned == 'approx-target':
d = self.approx_target_distribution
elif aligned == 'exact-base':
d = self.base_distribution
elif aligned == 'approx-base':
d = self.approx_base_distribution
else:
self.full_usage()
self.tstamp_print("ERROR: DIST %s not recognized." % aligned)
sys.exit(3)
DATA_FIELD = 'aligned-' + aligned
data = self.postproc_root.get(DATA_FIELD, None)
if data is None:
data = DIAG.computeAlignedConditionals(
d, dimensions_vec=n_tri_plots,
numPointsXax = n_points_x_ax,
range_vec=rng,
mpi_pool=mpi_pool)
self.postproc_root[DATA_FIELD] = data
if self.OUTPUT is not None:
self.store_postproc_data(self.OUTPUT)
if self.PLOTTING:
fig = DIAG.plotAlignedConditionals(
data=data, show_flag=(self.STORE_FIG_DIR is None));
if self.STORE_FIG_DIR is not None:
self.store_figure(fig, self.STORE_FIG_DIR+'/'+ \
'aligned-conditionals-'+ aligned +\
self.EXTRA_TIT)
self.filter_tstamp_print("[Stop] Aligned conditionals " + aligned)
[docs] def random_conditionals(self, mpi_pool):
for random, n_plots_x_ax, n_points_x_ax, anchor, rng in \
zip(self.RANDOM_CONDITIONALS, self.RNDC_N_PLOTS_X_AX,
self.RNDC_N_POINTS_X_AX,
self.RNDC_ANCHOR, self.RNDC_RANGE):
self.filter_tstamp_print("[Start] Random conditionals " + random)
if random == 'exact-target':
d = self.target_distribution
elif random == 'approx-target':
d = self.approx_target_distribution
elif random == 'exact-base':
d = self.base_distribution
elif random == 'approx-base':
d = self.approx_base_distribution
else:
self.full_usage()
self.tstamp_print("ERROR: DIST %s not recognized." % random)
sys.exit(3)
DATA_FIELD = 'random-' + random
data = self.postproc_root.get(DATA_FIELD, None)
if data is None:
data = DIAG.computeRandomConditionals(
d, num_conditionalsXax=n_plots_x_ax,
numPointsXax=n_points_x_ax,
pointEval=anchor, range_vec=rng,
mpi_pool=mpi_pool)
self.postproc_root[DATA_FIELD] = data
if self.OUTPUT is not None:
self.store_postproc_data(self.OUTPUT)
if self.PLOTTING:
fig = DIAG.plotRandomConditionals(
data=data, show_flag=(self.STORE_FIG_DIR is None))
if self.STORE_FIG_DIR is not None:
self.store_figure(fig, self.STORE_FIG_DIR+'/'+\
'random-conditionals-' + random + \
self.EXTRA_TIT)
self.filter_tstamp_print("[Stop] Random conditionals " + random)
[docs] def variance_diagnostic(self, mpi_pool):
        """Compute the variance diagnostic between paired distributions.

        For each --var-diag request, evaluates the variance-of-KL
        diagnostic between the selected distribution (d1) and its
        exact/approximate counterpart (d2). Function evaluations are cached
        in the hdf5 group "vals_var_diag/<dist>/<qtype>" so repeated runs
        reuse (and, for Monte-Carlo, extend) previous work.
        """
        for dstr, qtype, qnum in zip(self.VAR_DIAG, self.VD_QTYPE, self.VD_QNUM):
            self.filter_tstamp_print("[Start] Variance diagnostic " + dstr)
            # d1 is the sampling distribution, d2 the one compared against.
            if dstr == 'exact-target':
                d1 = self.target_distribution
                d2 = self.approx_target_distribution
            elif dstr == 'approx-target':
                d1 = self.approx_target_distribution
                d2 = self.target_distribution
            elif dstr == 'exact-base':
                d1 = self.base_distribution
                d2 = self.approx_base_distribution
            elif dstr == 'approx-base':
                d1 = self.approx_base_distribution
                d2 = self.base_distribution
            else:
                self.full_usage()
                self.tstamp_print("ERROR: DIST %s not recognized." % dstr)
                sys.exit(3)
            # Load values if any
            GRP_NAME = "vals_var_diag/" + dstr
            if GRP_NAME not in self.h5_root:
                self.h5_root.create_group(GRP_NAME)
            grp = self.h5_root[GRP_NAME]
            QTYPE_NAME = str(qtype)
            if QTYPE_NAME not in grp:
                grp.create_group(QTYPE_NAME)
            qtype_grp = grp[QTYPE_NAME]
            V1_NAME = 'vals_d1'
            V2_NAME = 'vals_d2'
            # NOTE(review): qnum must be a list here (qnum[0] below); it is
            # parsed as a list in _load_opts -- confirm the default
            # DFT_VD_QNUM is list-valued as well.
            if qtype == 0: # Monte-Carlo
                if V1_NAME not in qtype_grp:
                    qtype_grp.create_dataset(
                        V1_NAME, (0,), maxshape=(None,), dtype='d', chunks=(self.CHUNK_SIZE,))
                if V2_NAME not in qtype_grp:
                    qtype_grp.create_dataset(
                        V2_NAME, (0,), maxshape=(None,), dtype='d', chunks=(self.CHUNK_SIZE,))
                loaded_vals_d1 = qtype_grp[V1_NAME]
                loaded_vals_d2 = qtype_grp[V2_NAME]
                if len(loaded_vals_d1) > 0 and qnum[0] <= len(loaded_vals_d1):
                    # Subselect already available data
                    vals_d1 = np.array( loaded_vals_d1[:qnum[0]] )
                    vals_d2 = np.array( loaded_vals_d2[:qnum[0]] )
                else:
                    old_len = len(loaded_vals_d1)
                    # Sample new points and evaluate
                    n_new_samps = qnum[0] - len(loaded_vals_d1)
                    x = d1.rvs(n_new_samps)
                    new_vals_d1, new_vals_d2 = DIAG.compute_vals_variance_approx_kl(
                        d1, d2, x=x, mpi_pool_tuple=(None, mpi_pool))
                    # Grow the cached datasets and append the new evaluations.
                    loaded_vals_d1.resize(qnum[0], axis=0)
                    loaded_vals_d2.resize(qnum[0], axis=0)
                    loaded_vals_d1[old_len:] = new_vals_d1
                    loaded_vals_d2[old_len:] = new_vals_d2
                    vals_d1 = np.array( loaded_vals_d1 )
                    vals_d2 = np.array( loaded_vals_d2 )
                # Uniform Monte-Carlo weights.
                w = np.ones(qnum[0])/float(qnum[0])
            elif qtype == 3: # Gauss quadrature
                QNUM_NAME = str(qnum)
                W_NAME = 'w'
                if QNUM_NAME not in qtype_grp:
                    qtype_grp.create_group(QNUM_NAME)
                    qnum_grp = qtype_grp[QNUM_NAME]
                    (x, w) = d1.quadrature(qtype, qnum, mpi_pool=mpi_pool)
                    vals_d1, vals_d2 = DIAG.compute_vals_variance_approx_kl(
                        d1, d2, x=x, mpi_pool_tuple=(None, mpi_pool))
                    csize = min(self.CHUNK_SIZE, w.shape[0])
                    qnum_grp.create_dataset(V1_NAME, data=vals_d1, chunks=(csize,))
                    qnum_grp.create_dataset(V2_NAME, data=vals_d2, chunks=(csize,))
                    qnum_grp.create_dataset(W_NAME, data=w, chunks=(csize,))
                else:
                    # Quadrature values were fully cached on a previous run.
                    qnum_grp = qtype_grp[QNUM_NAME]
                    vals_d1 = np.array( qnum_grp[V1_NAME] )
                    vals_d2 = np.array( qnum_grp[V2_NAME] )
                    w = np.array( qnum_grp[W_NAME] )
            # NOTE(review): a qtype other than 0 or 3 leaves vals_d1/vals_d2/w
            # undefined and raises NameError below -- confirm only 0 and 3
            # are reachable via --var-diag-qtype.
            var_diag_tm = DIAG.variance_approx_kl(d1, d2,
                vals_d1=vals_d1, vals_d2=vals_d2, w=w)
            self.filter_tstamp_print("[Stop] Variance diagnostic %s: %e" % (dstr, var_diag_tm))
[docs] def aligned_marginals(self, mpi_pool):
        """Sample the requested distributions and plot aligned marginals.

        Samples are cached in the hdf5 group "quadrature/<dist>/0" (dataset
        '0' = Monte-Carlo, shared with the quadratures() method) and only
        the missing number of samples is generated on repeated runs.
        """
        for dstr, n_points, n_tri_plots in zip(
                self.ALIGNED_MARGINALS, self.ALM_N_POINTS, self.ALM_N_TRI_PLOTS):
            self.filter_tstamp_print("[Start] Aligned marginals %s " % dstr + \
                                     "- Sample generation")
            if dstr == 'exact-target':
                d = self.target_distribution
            elif dstr == 'approx-target':
                d = self.approx_target_distribution
            elif dstr == 'exact-base':
                d = self.base_distribution
            elif dstr == 'approx-base':
                d = self.approx_base_distribution
            else:
                self.full_usage()
                self.tstamp_print("ERROR: DIST %s not recognized." % dstr)
                sys.exit(3)
            # Load values if any
            Q_GRP_NAME = "quadrature"
            if Q_GRP_NAME not in self.h5_root:
                self.h5_root.create_group(Q_GRP_NAME)
            qgrp = self.h5_root[Q_GRP_NAME]
            D_GRP_NAME = dstr
            if D_GRP_NAME not in qgrp:
                qgrp.create_group(D_GRP_NAME)
            dgrp = qgrp[D_GRP_NAME]
            DSET_NAME = '0'
            if DSET_NAME not in dgrp:
                dgrp.create_dataset(
                    DSET_NAME, (0,self.dim), maxshape=(None,self.dim), dtype='d',
                    chunks=(self.CHUNK_SIZE,1))
            loaded_samp = dgrp[DSET_NAME]
            # Extend the cached sample set only by the missing amount.
            if n_points > loaded_samp.shape[0]:
                nold = loaded_samp.shape[0]
                new_nsamp = n_points - nold
                x = d.rvs(new_nsamp, mpi_pool=mpi_pool)
                loaded_samp.resize(n_points, axis=0)
                loaded_samp[nold:,:] = x
            self.filter_tstamp_print(" Aligned marginals %s " % dstr + \
                                     "- Plotting")
            if self.PLOTTING:
                fig = DIAG.plotAlignedMarginals(
                    loaded_samp[:n_points,:], dimensions_vec=n_tri_plots,
                    scatter=self.ALM_SCATTER,
                    show_flag=(self.STORE_FIG_DIR is None))
                if self.STORE_FIG_DIR is not None:
                    self.store_figure(fig, self.STORE_FIG_DIR+'/'+\
                                      'aligned-marginals-'+ dstr +\
                                      self.EXTRA_TIT)
            self.filter_tstamp_print("[Stop] Aligned marginals %s" % dstr)
[docs] def quadratures(self, mpi_pool):
        """Generate and cache quadratures for the requested distributions.

        Monte-Carlo quadratures (qtype 0) are generated in batches of at
        most --quadrature-bsize samples and stored incrementally in
        "quadrature/<dist>/0"; Gauss quadratures (qtype 3) are stored under
        "quadrature/<dist>/3/<qnum>".
        """
        for dstr, qtype, qnum, bsize \
                in zip(self.QUADRATURE, self.QUAD_QTYPE,
                       self.QUAD_QNUM, self.QUAD_BSIZE):
            self.filter_tstamp_print("[Start] Quadrature " + str(qtype))
            if dstr == 'exact-target':
                d = self.target_distribution
            elif dstr == 'approx-target':
                d = self.approx_target_distribution
            elif dstr == 'exact-base':
                d = self.base_distribution
            elif dstr == 'approx-base':
                d = self.approx_base_distribution
            else:
                self.full_usage()
                self.tstamp_print("ERROR: DIST %s not recognized." % dstr)
                sys.exit(3)
            # Load values if any
            GRP_NAME = "quadrature"
            if GRP_NAME not in self.h5_root:
                self.h5_root.create_group(GRP_NAME)
            qgrp = self.h5_root[GRP_NAME]
            D_GRP_NAME = dstr
            if D_GRP_NAME not in qgrp:
                qgrp.create_group(D_GRP_NAME)
            dgrp = qgrp[D_GRP_NAME]
            # NOTE(review): a qtype other than 0 or 3 (including the None
            # default when --quadrature-qtype is omitted) silently does
            # nothing -- confirm this is intended.
            if qtype == 0: # Monte-Carlo
                DSET_NAME = str(qtype)
                if DSET_NAME not in dgrp:
                    dgrp.create_dataset(
                        DSET_NAME, (0,self.dim), maxshape=(None,self.dim), dtype='d',
                        chunks=(self.CHUNK_SIZE,1))
                loaded_samp = dgrp[DSET_NAME]
                # Extend the cached sample set in batches of at most bsize.
                while qnum[0] > loaded_samp.shape[0]:
                    nold = loaded_samp.shape[0]
                    new_nsamp = min(qnum[0], nold+bsize) - nold
                    self.filter_tstamp_print(
                        " Generating batch %d-%d" % (nold, nold+new_nsamp))
                    x = d.rvs(new_nsamp, mpi_pool=mpi_pool)
                    loaded_samp.resize(nold+new_nsamp, axis=0)
                    loaded_samp[nold:,:] = x
            elif qtype == 3: # Gauss quadrature
                QTYPE_NAME = str(qtype)
                X_NAME = 'x'
                W_NAME = 'w'
                if QTYPE_NAME not in dgrp:
                    dgrp.create_group(QTYPE_NAME)
                qtp_grp = dgrp[QTYPE_NAME]
                QNUM_NAME = str(qnum)
                if QNUM_NAME not in qtp_grp:
                    qtp_grp.create_group(QNUM_NAME)
                    qngrp = qtp_grp[QNUM_NAME]
                    (x, w) = d.quadrature(qtype, qnum, mpi_pool=mpi_pool)
                    qngrp.create_dataset(X_NAME, data=x, chunks=(self.CHUNK_SIZE,1))
                    qngrp.create_dataset(W_NAME, data=w, chunks=(self.CHUNK_SIZE,))
            self.filter_tstamp_print("[Stop] Quadrature")
[docs] def importance_sampling(self, mpi_pool):
if self.IMP_SAMPLES is not None:
self.filter_tstamp_print("[Start] Importance sampling")
# Load values if any
GRP_NAME = "importance-samples"
if GRP_NAME not in self.h5_root:
self.h5_root.create_group(GRP_NAME)
is_grp = self.h5_root[GRP_NAME]
X_NAME = 'x'
W_NAME = 'w'
if X_NAME not in is_grp:
is_grp.create_dataset(
X_NAME, (0,self.dim), maxshape=(None,self.dim),
dtype='d', chunks=(self.CHUNK_SIZE,1))
is_grp.create_dataset(
W_NAME, (0,), maxshape=(None,),
dtype='d', chunks=(self.CHUNK_SIZE,))
loaded_x = is_grp[X_NAME]
loaded_w = is_grp[W_NAME]
if self.IMP_SAMPLES > loaded_x.shape[0]:
nold = loaded_x.shape[0]
new_nsamp = self.IMP_SAMPLES - nold
sampler = SAMP.ImportanceSampler( self.approx_base_distribution, self.base_distribution )
(x, w) = sampler.rvs(new_nsamp, mpi_pool_tuple=(mpi_pool, None))
x = self.approx_target_distribution.map_samples_base_to_target(
x, mpi_pool=mpi_pool)
loaded_x.resize(self.IMP_SAMPLES, axis=0)
loaded_x[nold:,:] = x
loaded_w.resize(self.IMP_SAMPLES, axis=0)
loaded_w[nold:] = w
loaded_w /= np.sum(loaded_w)
self.filter_tstamp_print("[Stop] Importance sampling")
[docs] def mcmc(self, mpi_pool):
if self.MCMC_ALG is not None:
self.filter_tstamp_print("[Start] Markov Chain Monte Carlo")
if self.MCMC_ALG == 'mhind':
self.filter_tstamp_print(" Metropolis-Hastings with Independent Proposals")
GRP_NAME = "metropolis-independent-proposal"
sampler = SAMP.MetropolisHastingsIndependentProposalsSampler(
self.approx_base_distribution, self.base_distribution )
# Load values if any
if GRP_NAME not in self.h5_root:
self.h5_root.create_group(GRP_NAME)
is_grp = self.h5_root[GRP_NAME]
elif self.MCMC_ALG == 'hmc':
self.filter_tstamp_print(" Hamiltonian Monte Carlo")
GRP_NAME = "hamiltonian-monte-carlo-samples"
if not isinstance(self.base_distribution, DIST.StandardNormalDistribution):
self.logger.warn("The HMC algorithm uses a Standard Normal distribution " +\
"as default proposal")
sampler = SAMP.HamiltonianMonteCarloSampler( self.approx_base_distribution )
# Load values if any
if GRP_NAME not in self.h5_root:
self.h5_root.create_group(GRP_NAME)
is_grp = self.h5_root[GRP_NAME]
if str(self.MCMC_HMC_EPS) not in is_grp:
is_grp.create_group(str(self.MCMC_HMC_EPS))
is_grp = is_grp[str(self.MCMC_HMC_EPS)]
if str(self.MCMC_HMC_NSTEPS) not in is_grp:
is_grp.create_group(str(self.MCMC_HMC_NSTEPS))
is_grp = is_grp[str(self.MCMC_HMC_NSTEPS)]
elif self.MCMC_ALG == 'mh':
self.filter_tstamp_print(" Metropolis-Hastings")
GRP_NAME = "metropolis-hastings"
if self.MCMC_MH_PCN:
if isinstance(self.target_distribution, DISTINF.BayesPosteriorDistribution) and \
(
isinstance(self.target_distribution.prior, DIST.NormalDistribution) or \
isinstance(self.target_distribution.prior, DIST.StandardNormalDistribution)
):
GRP_NAME += "/pcn"
prop_distribution = DIST.MeanConditionallyGaussianDistribution(
Maps.PreconditionedCrankNicolsonMap(
self.base_distribution.dim, self.MCMC_MH_EPS),
self.MCMC_MH_EPS**2 * self.target_distribution.prior.covariance)
else:
self.tstamp_print(
"In order to use preconditioned Crank-Nicolson " + \
"the target distribution must be a Bayesian posterior with " + \
"normal prior"
)
sys.exit(3)
else:
prop_distribution = DIST.MeanConditionallyNormalDistribution(
Maps.IdentityTransportMap(self.base_distribution.dim),
self.MCMC_MH_EPS * np.eye(self.base_distribution.dim) )
sampler = SAMP.MetropolisHastingsSampler(
self.approx_base_distribution, prop_distribution )
# Load values if any
if GRP_NAME not in self.h5_root:
self.h5_root.create_group(GRP_NAME)
is_grp = self.h5_root[GRP_NAME]
if str(self.MCMC_MH_EPS) not in is_grp:
is_grp.create_group(str(self.MCMC_MH_EPS))
is_grp = is_grp[str(self.MCMC_MH_EPS)]
else:
self.full_usage()
self.tstamp_print("ERROR: ALG %s not recognized." % self.MCMC_ALG)
sys.exit(3)
# Create group for each skipping value
SKIP_NAME = "skip-%d" % self.MCMC_SKIP
if SKIP_NAME not in is_grp:
is_grp.create_group(SKIP_NAME)
is_grp = is_grp[SKIP_NAME]
X_NAME = 'x' # Samples in pushforward space
S_NAME = 's' # Samples in pullback space
if X_NAME not in is_grp:
is_grp.create_dataset(
X_NAME, (0,self.dim), maxshape=(None,self.dim),
dtype='d', chunks=(self.CHUNK_SIZE,1))
if S_NAME not in is_grp:
is_grp.create_dataset(
S_NAME, (0,self.dim), maxshape=(None,self.dim),
dtype='d', chunks=(self.CHUNK_SIZE,1))
loaded_x = is_grp[X_NAME]
loaded_s = is_grp[S_NAME]
if self.MCMC_SAMPLES > loaded_x.shape[0]:
nold = loaded_x.shape[0]
new_nsamp = self.MCMC_SAMPLES - nold
s0 = None
if nold > 0:
s0 = loaded_s[-1,:]
self.filter_tstamp_print(" Restarting chain from stored data (length: %d)" % nold)
if self.MCMC_ALG in ['mh','mhind']:
(s, _) = sampler.rvs(new_nsamp*(self.MCMC_SKIP+1), x0=s0,
mpi_pool_tuple=(mpi_pool, None))
elif self.MCMC_ALG == 'hmc':
(s, _) = sampler.rvs(
new_nsamp*(self.MCMC_SKIP+1), x0=s0,
epsilon=self.MCMC_HMC_EPS, n_steps=self.MCMC_HMC_NSTEPS)
s = s[::(self.MCMC_SKIP+1),:]
self.filter_tstamp_print(" Pushing forward samples")
x = self.approx_target_distribution.map_samples_base_to_target(
s, mpi_pool=mpi_pool)
loaded_x.resize(self.MCMC_SAMPLES, axis=0)
loaded_s.resize(self.MCMC_SAMPLES, axis=0)
loaded_x[nold:,:] = x
loaded_s[nold:,:] = s
# Compute effective sample size
if self.MCMC_ESS is not None:
self.filter_tstamp_print(" Estimating ESS")
ess_list = []
for d in range(loaded_s.shape[1]):
fig = None if not (self.PLOTTING and self.MCMC_ESS_CORR_PLOT) else plt.figure()
if self.MCMC_ESS == 'acor':
ess = SAMP.ess(
loaded_s[self.MCMC_BURNIN:self.MCMC_SAMPLES:self.MCMC_ESS_SKIP,[d]],
quantile=self.MCMC_ESS_Q,
plotting=(self.PLOTTING and self.MCMC_ESS_CORR_PLOT),
plot_lag=self.MCMC_ESS_CORR_PLOT_LAG, fig=fig)
elif self.MCMC_ESS == 'uw':
ess = SAMP.uwerr(
loaded_s[self.MCMC_BURNIN:self.MCMC_SAMPLES:self.MCMC_ESS_SKIP,[d]],
plotting=(self.PLOTTING and self.MCMC_ESS_CORR_PLOT),
)
# self.filter_tstamp_print(
# " ESS dimension %d: %d/%d " % (
# d,ess,(self.MCMC_SAMPLES-self.MCMC_BURNIN)//self.MCMC_ESS_SKIP) + \
# "(%2.3f%%)" % (
# ess/float((self.MCMC_SAMPLES-self.MCMC_BURNIN)//self.MCMC_ESS_SKIP)*100.)
# )
ess_list.append( ess )
if self.PLOTTING and self.MCMC_ESS_CORR_PLOT:
if self.STORE_FIG_DIR is None:
plt.show(False)
else:
self.store_figure(fig, self.STORE_FIG_DIR+'/'+\
'metropolis-ess-corr-d%d' %d + \
self.EXTRA_TIT)
if self.PLOTTING and self.MCMC_ESS_HIST_PLOT:
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(ess, drawstyle='steps-mid')
ax.set_ylabel("ESS")
ax.set_xlable("Dimension")
if self.STORE_FIG_DIR is None:
plt.show(False)
else:
self.store_figure(fig, self.STORE_FIG_DIR+'/'+\
'metropolis-ess-hist' + \
self.EXTRA_TIT)
amin_ess = np.argmin(ess_list)
amax_ess = np.argmax(ess_list)
min_ess = ess_list[amin_ess]
max_ess = ess_list[amax_ess]
mean_ess = np.mean(ess_list)
tot_samps = (self.MCMC_SAMPLES-self.MCMC_BURNIN)//self.MCMC_ESS_SKIP
self.filter_tstamp_print(
"[Stop] Markov Chain Monte Carlo " + \
"- ESS: %d/%d " % (ess,tot_samps) + \
"-- worst %2.3f%% - d: %d " % (min_ess/float(tot_samps)*100, amin_ess) + \
"-- best %2.3f%% - d: %d " % (max_ess/float(tot_samps)*100, amax_ess) + \
"-- avg %2.3f%%" % (mean_ess/float(tot_samps)*100)
)
else:
self.filter_tstamp_print("[Stop] Markov Chain Monte Carlo")
[docs] def run(self, mpi_pool):
self.aligned_conditionals(mpi_pool)
self.random_conditionals(mpi_pool)
self.variance_diagnostic(mpi_pool)
self.aligned_marginals(mpi_pool)
self.quadratures(mpi_pool)
self.importance_sampling(mpi_pool)
self.mcmc(mpi_pool)