#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps. If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Author: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#
import sys
import os
import os.path
import pickle
import numpy as np
from . import AvailableOptions as AO
from .PostprocessBase import Postprocess
from TransportMaps.External import H5PY_SUPPORT
import TransportMaps.Diagnostics as DIAG
import TransportMaps.Samplers as SAMP
import TransportMaps.Distributions as DIST
if H5PY_SUPPORT:
import h5py
__all__ = ['SequentialPostprocess']
class SequentialPostprocess(Postprocess):
    r"""Post-processing driver for sequential (Hidden Markov) transport map problems.

    Extends :class:`Postprocess` with routines acting on the list of filtering
    maps produced by a sequential integrator: variance and regression
    diagnostics of the sequence of maps, plots of aligned conditionals and
    marginals of the filtering distributions, and generation of quadratures of
    the filtering distributions. Large datasets are cached in the ``hdf5``
    side-file under the group ``trim-<NSTEPS>``.
    """

    def usage(self):
        """Print the command line usage banner."""
        usage_str = """
Usage: tmap-sequential-postprocess [-h -v -I]
  --input=INPUT --output=OUTPUT
  [--trim=NSTEPS
   --store-fig-dir=DIR --store-fig-fmats=FMATS
   --extra-tit=TITLE --no-plotting
   --sequential-var-diag
   --sequential-reg-diag
   --filtering-conditionals
   --filt-alc-n-points-x-ax=N --filt-alc-n-tri-plots=N
   --filt-alc-anchor=LIST --filt-alc-range=LIST
   --filtering-marginals
   --filt-alm-n-points=N --filt-alm-n-tri-plots=N
   --filtering-quadrature
   --filt-quad-qtype=QTYPE
   --filt-quad-qnum=QNUM
   --log=LOG --batch=BATCH --nprocs=NPROCS]
"""
        print(usage_str)

    def description(self):
        """Print the extended command line help text."""
        # NOTE(review): docs_distributions_str is built but never interpolated
        # into docs_str below — kept because AO.print_avail_options may have
        # side effects; confirm whether it was meant to appear in the help.
        docs_distributions_str = \
            AO.print_avail_options(AO.AVAIL_DISTRIBUTIONS,
                                   '  ', False)
        docs_log_str = \
            '  --log=LOG log level (default=30). Uses package logging.\n' + \
            AO.print_avail_options(AO.AVAIL_LOGGING, '  ')
        docs_str = """DESCRIPTION
Given a file (--input) storing the transport map pushing forward a base distribution
to a sequential Hidden Markov target distribution,
provides a number of postrprocessing routines.
All files involved are stored and loaded using the python package pickle and
an extra file OUTPUT.hdf5 is created to store big datasets in the hdf5 format.
In the following default values are shown in brackets.
All the options available for tmaps-postprocess are also available here.
OPTIONS - input/output:
  --input=INPUT path to the file containing the target distribution,
                the base distribution and the transport map pushing forward
                the base to the target.
  --output=OUTPUT path to the file storing all postprocess data.
                The additional file OUTPUT.hdf5 will be used to store
                the more memory consuming data.
  --trim=NSTEPS trim the solution to NSTEPS and run diagnostics on this distribution
  --store-fig-dir=DIR path to the directory where to store the figures.
  --store-fig-fmats=FMATS figure formats - see matplotlib for supported formats (svg)
  --extra-tit=TITLE additional title for the figures' file names.
  --no-plotting do not plot figures, but only store their data.
                (requires --output or --store-fig-dir)
OPTIONS - Diagnostics:
  --sequential-var-diag plot value of variance diagnostic for the sequence of maps
  --sequential-reg-diag plot value of all regression residuals
  --filtering-conditionals plot aligned slices of the filtering distribution:
    Optional arguments:
    --filt-alc-n-points-x-ax=N number of discretization points per axis (30)
    --filt-alc-n-tri-plots=N number of subplots (0)
    --filt-alc-anchor=LIST list of floats "f1,f2,f3..." for the anchor point (0)
    --filt-alc-range=LIST list of two floats "f1,f2" for the range (-5,5)
OPTIONS - Sampling:
  --filtering-marginals plot aligned marginals of the filtering distribution:
    Optional arguments:
    --filt-alm-n-points=N number of samples to be used for the kernel density estimation
    --filt-alm-n-tri-plots=N number of subplots (0)
  --filtering-quadrature generate quadrature of the filtering distribution:
    Optional arguments:
    --filt-quad-qtype=QTYPE generate quadrature of type QTYPE (0)
    --filt-quad-qnum=QNUM level of the quadrature (int or list)
OPTIONS - Computation:
""" + docs_log_str + """  --nprocs=NPROCS number of processors to be used (default=1)
  --batch=BATCH list of batch sizes for function evaluation, gradient
                evaluation and Hessian evaluation
OPTIONS - other:
  -v verbose output (not affecting --log)
  -I enter interactive mode after finishing
  -h print this help
"""
        print(docs_str)

    @property
    def long_options(self):
        """Long command line options accepted, extending the parent's list."""
        return super(SequentialPostprocess, self).long_options + \
            [
                "trim=",
                # Sequential diagnostics
                "sequential-var-diag", "sequential-reg-diag",
                # Aligned conditionals
                "filtering-conditionals",
                "filt-alc-n-points-x-ax=", "filt-alc-n-tri-plots=",
                "filt-alc-anchor=", "filt-alc-range=",
                # Aligned marginals
                "filtering-marginals",
                "filt-alm-n-points=", "filt-alm-n-tri-plots=",
                # Quadrature
                "filtering-quadrature",
                "filt-quad-qtype=", "filt-quad-qnum=",
            ]

    def _load_opts(self, opts):
        """Parse the command line options.

        :param opts: list of ``(option, argument)`` pairs as returned by
          :func:`getopt.getopt`

        .. note:: the original code used ``opt in ("--flag")`` which is a
          substring test on a *string* (not tuple membership); fixed to
          exact equality.
        """
        super(SequentialPostprocess, self)._load_opts(opts)
        for opt, arg in opts:
            if opt == "--trim":
                self.TRIM = int(arg)
            # Sequential diagnostics
            elif opt == "--sequential-var-diag":
                self.SEQUENTIAL_VAR_DIAG = True
            elif opt == "--sequential-reg-diag":
                self.SEQUENTIAL_REG_DIAG = True
            # Aligned conditionals
            elif opt == "--filtering-conditionals":
                self.FLT_ALIGNED_CONDITIONALS = True
            # Options
            elif opt == "--filt-alc-n-points-x-ax":
                self.FLT_ALC_N_POINTS_X_AX = int(arg)
            elif opt == "--filt-alc-n-tri-plots":
                self.FLT_ALC_N_TRI_PLOTS = list(range(int(arg)))
            elif opt == "--filt-alc-anchor":
                self.FLT_ALC_ANCHOR = [float(s) for s in arg.split(',')]
            elif opt == "--filt-alc-range":
                self.FLT_ALC_RANGE = [float(s) for s in arg.split(',')]
            # Aligned marginals
            elif opt == "--filtering-marginals":
                self.FLT_ALIGNED_MARGINALS = True
            # Options
            elif opt == "--filt-alm-n-points":
                self.FLT_ALM_N_POINTS = int(arg)
            elif opt == "--filt-alm-n-tri-plots":
                self.FLT_ALM_N_TRI_PLOTS = int(arg)
            # Quadrature: each --filtering-quadrature opens a new request slot;
            # the following qtype/qnum options fill the most recent slot.
            elif opt == "--filtering-quadrature":
                self.FLT_QUADRATURE.append(True)
                self.FLT_QUAD_QTYPE.append(None)
                self.FLT_QUAD_QNUM.append(None)
            elif opt == "--filt-quad-qtype":
                self.FLT_QUAD_QTYPE[len(self.FLT_QUADRATURE)-1] = int(arg)
            elif opt == "--filt-quad-qnum":
                self.FLT_QUAD_QNUM[len(self.FLT_QUADRATURE)-1] = \
                    [int(q) for q in arg.split(',')]

    def _init_self_variables(self):
        """Initialize option attributes to their default values.

        .. note:: fixed ``self(SequentialPostprocess, self)`` (a crashing call
          of the instance) to the intended ``super(...)`` call.
        """
        super(SequentialPostprocess, self)._init_self_variables()
        self.TRIM = None
        # Sequential diagnostics
        self.SEQUENTIAL_VAR_DIAG = False
        self.SEQUENTIAL_REG_DIAG = False
        # Filtering aligned conditionals
        self.FLT_ALIGNED_CONDITIONALS = False
        self.FLT_ALC_N_POINTS_X_AX = 30
        self.FLT_ALC_N_TRI_PLOTS = 0
        self.FLT_ALC_ANCHOR = None
        self.FLT_ALC_RANGE = [-5., 5.]
        # Filtering aligned marginals
        self.FLT_ALIGNED_MARGINALS = False
        self.FLT_ALM_N_POINTS = 1000
        self.FLT_ALM_N_TRI_PLOTS = 0
        # Quadrature requests (parallel lists, one entry per
        # --filtering-quadrature on the command line)
        self.FLT_QUADRATURE = []
        self.FLT_QUAD_QTYPE = []
        self.FLT_QUAD_QNUM = []

    def load(self):
        """Load the problem, optionally trim it, and set up the hdf5 root.

        :raises ValueError: if ``--trim`` exceeds the number of steps.
        """
        super(SequentialPostprocess, self).load()
        if self.TRIM is None:
            self.TRIM = self.target_distribution.nsteps
        if self.TRIM == self.target_distribution.nsteps:
            self.filt_tmap_list = self.stg.integrator.filtering_map_list
        elif self.TRIM < self.target_distribution.nsteps:
            integrator = self.stg.integrator.trim(self.TRIM)
            self.target_distribution = self.target_distribution.trim(self.TRIM)
            self.dim = self.target_distribution.dim
            self.base_distribution = DIST.StandardNormalDistribution(self.dim)
            self.tmap = integrator.smoothing_map
            self.approx_base_distribution = DIST.PullBackTransportMapDistribution(
                self.tmap, self.target_distribution)
            self.approx_target_distribution = DIST.PushForwardTransportMapDistribution(
                self.tmap, self.base_distribution)
            # Fix: the filtering maps of the trimmed integrator were never
            # stored, leaving self.filt_tmap_list undefined in this branch.
            self.filt_tmap_list = integrator.filtering_map_list
        else:
            raise ValueError("The value --trim exceed the total number of steps.")
        # Set hdf5 root: all cached data for this trim level lives under it,
        # so different trims do not clobber each other's datasets.
        ROOT_NAME = "trim-%d" % self.TRIM
        if ROOT_NAME not in self.h5_file:
            self.h5_file.create_group(ROOT_NAME)
        self.h5_root = self.h5_file[ROOT_NAME]

    def _filtering_quadrature_group(self, step):
        """Return (creating it if missing) the hdf5 group
        ``/filtering/step-<step>/quadrature`` under the current root.
        """
        F_GRP_NAME = "/filtering"
        if F_GRP_NAME not in self.h5_root:
            self.h5_root.create_group(F_GRP_NAME)
        fgrp = self.h5_root[F_GRP_NAME]
        S_GRP_NAME = "step-%d" % step
        if S_GRP_NAME not in fgrp:
            fgrp.create_group(S_GRP_NAME)
        sgrp = fgrp[S_GRP_NAME]
        Q_GRP_NAME = "quadrature"
        if Q_GRP_NAME not in sgrp:
            sgrp.create_group(Q_GRP_NAME)
        return sgrp[Q_GRP_NAME]

    def _extend_mc_samples(self, qgrp, dset_name, filt_tmap, nsamp, mpi_pool):
        """Ensure the resizable dataset ``dset_name`` in ``qgrp`` holds at
        least ``nsamp`` Monte Carlo samples of the push-forward of
        ``filt_tmap``, drawing only the missing ones. Returns the dataset.
        """
        dim = filt_tmap.dim
        if dset_name not in qgrp:
            qgrp.create_dataset(
                dset_name, (0, dim), maxshape=(None, dim), dtype='d')
        loaded_samp = qgrp[dset_name]
        if nsamp > loaded_samp.shape[0]:
            nold = loaded_samp.shape[0]
            new_nsamp = nsamp - nold
            base_density = DIST.StandardNormalDistribution(dim)
            d = DIST.PushForwardTransportMapDistribution(
                filt_tmap, base_density)
            x = d.rvs(new_nsamp, mpi_pool=mpi_pool)
            loaded_samp.resize(nsamp, axis=0)
            loaded_samp[nold:, :] = x
        return loaded_samp

    def sequential_variance_diagnostic(self, mpi_pool):
        """Plot the variance diagnostic of the sequence of maps together with
        the number of coefficients of each map (twin y-axis).

        :param mpi_pool: pool of processes (unused here, kept for the uniform
          postprocessing-routine signature)
        """
        if self.SEQUENTIAL_VAR_DIAG:
            self.filter_tstamp_print("[Start] Sequential variance diagnostic")
            if self.PLOTTING:
                import matplotlib.pyplot as plt
                fig = plt.figure()
                ax = fig.add_subplot(111)
                ax2 = ax.twinx()
                ax.semilogy(self.stg.integrator.var_diag_convergence)
                ax.set_xlabel("Step")
                ax.set_ylabel(r"$\mathbb{V}[\log\rho/T_i^\sharp\pi_i]$")
                n_coeffs = [tm.n_coeffs for tm in self.stg.integrator.R_list]
                ax2.semilogy(n_coeffs, 'k')
                ax2.set_ylabel("number of coefficients")
                if self.STORE_FIG_DIR is not None:
                    self.store_figure(
                        fig, self.STORE_FIG_DIR + "/" + self.TITLE + \
                        '-sequential-var-diag' + self.EXTRA_TIT)
                else:
                    # Non-blocking so the remaining routines keep running.
                    plt.show(block=False)
            self.filter_tstamp_print("[Stop] Sequential variance diagnostic")

    def sequential_regression_diagnostic(self, mpi_pool):
        """Plot the regression residuals accumulated by the integrator.

        :param mpi_pool: pool of processes (unused here, kept for the uniform
          postprocessing-routine signature)
        """
        if self.SEQUENTIAL_REG_DIAG:
            self.filter_tstamp_print("[Start] Sequential regression diagnostic")
            if self.PLOTTING:
                import matplotlib.pyplot as plt
                fig = plt.figure()
                ax = fig.add_subplot(111)
                ax.semilogy(self.stg.integrator.regression_convergence)
                ax.set_xlabel("Iteration")
                ax.set_ylabel(r"$\Vert H_i - H_{i-1}\circ\tilde{H}_i\Vert_2$")
                if self.STORE_FIG_DIR is not None:
                    self.store_figure(
                        fig, self.STORE_FIG_DIR + "/" + self.TITLE + \
                        '-sequential-reg-diag' + self.EXTRA_TIT)
                else:
                    plt.show(block=False)
            self.filter_tstamp_print("[Stop] Sequential regression diagnostic")

    def filtering_aligned_conditionals(self, mpi_pool):
        """Compute and plot aligned conditionals of each filtering
        distribution; results are cached in ``self.postproc_data``.

        :param mpi_pool: pool of processes to distribute the evaluations
        """
        if self.FLT_ALIGNED_CONDITIONALS:
            self.filter_tstamp_print("[Start] Filtering conditionals")
            for n, filt_tmap in enumerate(self.filt_tmap_list):
                self.filter_tstamp_print(" Filtering conditionals " + \
                                         "- Step %d" % n)
                DATA_FIELD = 'filtering-conditionals-%d' % n
                data = self.postproc_data.get(DATA_FIELD, None)
                if data is None:
                    # Not cached yet: evaluate and store for future runs.
                    base_density = DIST.StandardNormalDistribution(filt_tmap.dim)
                    d = DIST.PushForwardTransportMapDistribution(
                        filt_tmap, base_density)
                    data = DIAG.computeAlignedConditionals(
                        d, dimensions_vec=self.FLT_ALC_N_TRI_PLOTS,
                        numPointsXax=self.FLT_ALC_N_POINTS_X_AX,
                        pointEval=self.FLT_ALC_ANCHOR, range_vec=self.FLT_ALC_RANGE,
                        mpi_pool=mpi_pool)
                    self.postproc_data[DATA_FIELD] = data
                    if self.OUTPUT is not None:
                        self.safe_store(self.postproc_data, self.OUTPUT)
                if self.PLOTTING:
                    fig = DIAG.plotAlignedConditionals(
                        data=data, show_flag=(self.STORE_FIG_DIR is None))
                    if self.STORE_FIG_DIR is not None:
                        self.store_figure(fig, self.STORE_FIG_DIR + '/' + \
                                          self.TITLE + \
                                          '-filtering-conditionals-%d' % n + \
                                          self.EXTRA_TIT)
            self.filter_tstamp_print("[Stop] Filtering conditionals")

    def filtering_aligned_marginals(self, mpi_pool):
        """Sample each filtering distribution (samples cached in the hdf5
        file) and plot the aligned marginals.

        :param mpi_pool: pool of processes to distribute the sampling
        """
        if self.FLT_ALIGNED_MARGINALS:
            self.filter_tstamp_print("[Start] Filtering marginals")
            for n, filt_tmap in enumerate(self.filt_tmap_list):
                self.filter_tstamp_print(" Filtering marginals " + \
                                         "- Step %d - Sample generation" % n)
                # Monte Carlo samples live in dataset '0' (qtype 0 convention)
                qgrp = self._filtering_quadrature_group(n)
                loaded_samp = self._extend_mc_samples(
                    qgrp, '0', filt_tmap, self.FLT_ALM_N_POINTS, mpi_pool)
                self.filter_tstamp_print(" Filtering marginals " + \
                                         "- Step %d - Plotting" % n)
                if self.PLOTTING:
                    fig = DIAG.plotAlignedMarginals(
                        loaded_samp[:self.FLT_ALM_N_POINTS, :],
                        self.FLT_ALM_N_TRI_PLOTS,
                        show_flag=(self.STORE_FIG_DIR is None))
                    if self.STORE_FIG_DIR is not None:
                        self.store_figure(
                            fig, self.STORE_FIG_DIR + '/' + self.TITLE + \
                            '-filtering-marginals-%d' % n + \
                            self.EXTRA_TIT)
            self.filter_tstamp_print("[Stop] Filtering marginals")

    def filtering_quadratures(self, mpi_pool):
        """Generate the requested quadratures of the filtering distributions
        and store them in the hdf5 file.

        Supported quadrature types: 0 (Monte Carlo, extendable dataset) and
        3 (Gauss quadrature, one ``x``/``w`` pair per quadrature order).

        :param mpi_pool: pool of processes to distribute the computations
        """
        for _, qtype, qnum in zip(
                self.FLT_QUADRATURE, self.FLT_QUAD_QTYPE, self.FLT_QUAD_QNUM):
            self.filter_tstamp_print("[Start] Quadrature " + str(qtype))
            for n, filt_tmap in enumerate(self.filt_tmap_list):
                self.filter_tstamp_print(" Quadrature " + str(qtype) + \
                                         "- Step %d - Sample generation" % n)
                qgrp = self._filtering_quadrature_group(n)
                if qtype == 0:  # Monte-Carlo
                    # qnum[0] is the requested number of samples; reuse any
                    # already stored and draw only the missing ones.
                    self._extend_mc_samples(
                        qgrp, str(qtype), filt_tmap, qnum[0], mpi_pool)
                elif qtype == 3:  # Gauss quadrature
                    QTYPE_NAME = str(qtype)
                    X_NAME = 'x'
                    W_NAME = 'w'
                    if QTYPE_NAME not in qgrp:
                        qgrp.create_group(QTYPE_NAME)
                    qtp_grp = qgrp[QTYPE_NAME]
                    QNUM_NAME = str(qnum)
                    if QNUM_NAME not in qtp_grp:
                        qtp_grp.create_group(QNUM_NAME)
                    qngrp = qtp_grp[QNUM_NAME]
                    # Fix: only compute/store if not already cached, otherwise
                    # create_dataset would raise on a second run.
                    if X_NAME not in qngrp:
                        dim = filt_tmap.dim
                        base_density = DIST.StandardNormalDistribution(dim)
                        d = DIST.PushForwardTransportMapDistribution(
                            filt_tmap, base_density)
                        (x, w) = d.quadrature(qtype, qnum, mpi_pool=mpi_pool)
                        qngrp.create_dataset(X_NAME, data=x)
                        qngrp.create_dataset(W_NAME, data=w)
            self.filter_tstamp_print("[Stop] Quadrature")

    def run(self, mpi_pool):
        """Run the parent's postprocessing routines followed by the
        sequential-specific ones.

        :param mpi_pool: pool of processes to distribute the computations
        """
        super(SequentialPostprocess, self).run(mpi_pool)
        self.sequential_variance_diagnostic(mpi_pool)
        self.sequential_regression_diagnostic(mpi_pool)
        self.filtering_aligned_conditionals(mpi_pool)
        self.filtering_aligned_marginals(mpi_pool)
        self.filtering_quadratures(mpi_pool)