Source code for TransportMaps.Algorithms.Adaptivity.KullbackLeiblerAdaptivity

#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps.  If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Authors: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#

import copy
from typing import List

import pickle
import numpy as np

import semilattices as SL

from .RegressionAdaptivity import L2RegressionBuilder

from ...Misc import \
    cmdinput, \
    read_and_cast_input, \
    argsort, \
    state_loader
from ...External import PLOT_SUPPORT
from ...Builders import \
    KullbackLeiblerBuilder
from ...Distributions import \
    PullBackTransportMapDistribution, Distribution
from ...KL import \
    grad_t_kl_divergence
from ...LaplaceApproximationRoutines import laplace_approximation
from ...MPI import mpi_map, mpi_bcast_dmem, mpi_scatter_dmem, mpi_map_alloc_dmem
from ...Maps import \
    assemble_LinearSpanTriangularMap, ParametricTransportMap
from ...Diagnostics.Routines import variance_approx_kl
from ...Maps import \
    AffineTransportMap, \
    AffineTriangularMap
from ...Maps.Functionals import \
    LinearSpanTensorizedParametricFunctional

__all__ = [
    'SequentialKullbackLeiblerBuilder',
    'ToleranceSequentialKullbackLeiblerBuilder',
    'FirstVariationKullbackLeiblerBuilder'
]

nax = np.newaxis

class SequentialKullbackLeiblerBuilder(KullbackLeiblerBuilder):
    r""" Solve over a list of maps, using each map to warm start the next one

    Given distributions :math:`\nu_\rho` and :math:`\nu_\pi`, and the list of
    parametric transport maps :math:`[T_1[{\bf a}_1],\ldots,T_n[{\bf a}_n]]`,
    provides the functionalities to solve the problems

    .. math::

       \arg\min_{{\bf a}_i}\mathcal{D}_{\rm KL}\left(
           T_i[{\bf a}_i]_\sharp\rho \Vert \pi\right)

    up to a chosen tolerance, where the numerical solution for map
    :math:`T_{i+1}` is started at :math:`T_i`.
    """
    def __init__(
            self,
            validator=None,
            regression_params_list=None,
            callback=None,
            callback_kwargs={},
            verbosity=0):
        r"""
        Args:
          validator (:class:`Validator<TransportMaps.Diagnostics.Validator>`):
            validator to be used to check stability of the solution
          regression_params_list (:class:`list` of :class:`dict`):
            list of dictionaries of parameters for the regression
            between :math:`T_i` and :math:`T_{i+1}`
          callback (function): function taking a map and optional additional
            arguments, called whenever it is deemed necessary by the chosen
            algorithm (e.g. for storing purposes)
          callback_kwargs (dict): additional arguments to be provided to the
            function ``callback``
          verbosity (int): level of verbosity of the builder
        """
        self.solve_counter = 0
        self.regression_params_list = regression_params_list
        super(SequentialKullbackLeiblerBuilder, self).__init__(
            validator,
            callback=callback, callback_kwargs=callback_kwargs,
            verbosity=verbosity)

    @state_loader(
        keys = [
            'transport_map',
            'base_distribution',
            'target_distribution',
            'solve_params'
        ]
    )
    def solve(
            self,
            transport_map: List[ParametricTransportMap] = None,
            base_distribution: Distribution = None,
            target_distribution: Distribution = None,
            solve_params: List[dict] = None,
            state=None,
            mpi_pool=None
    ):
        r"""
        Args:
          transport_map (:class:`list` of :class:`TransportMap<TransportMaps.Maps.TransportMap>`):
            transport maps :math:`T`
          base_distribution (:class:`Distribution<TransportMaps.Distributions.Distribution>`):
            distribution :math:`\nu_\rho`
          target_distribution (:class:`Distribution<TransportMaps.Distributions.Distribution>`):
            distribution :math:`\nu_\pi`
          solve_params (:class:`list` of :class:`dict`):
            list of dictionaries of parameters for solution
          state (:class:`TransportMaps.DataStorageObject`): if provided, it must
            contain all the information needed for reloading, or a handle to an
            empty storage object which can be externally stored.
            If ``state`` contains the keys corresponding to arguments to this
            function, they will be used instead of the inputs themselves.

        Returns:
          (:class:`TransportMaps.Maps.TransportMap`) -- the transport map fitted.
        """
        # Internal states
        state.solve_counter = getattr(state, 'solve_counter', 0)
        if len(state.transport_map_list) != len(state.solve_params_list):
            raise ValueError(
                "Inconsistent number of transport maps."
            )
        if state.solve_counter == 0:
            transport_map = state.transport_map_list[0]
            solve_params = state.solve_params_list[0]
            tm, log = super(SequentialKullbackLeiblerBuilder, self).solve(
                transport_map,
                state.base_distribution,
                state.target_distribution,
                solve_params,
                mpi_pool=mpi_pool
            )
            if not log['success']:
                x0 = solve_params.get('x0')
                if x0 is not None:
                    tm.coeffs = x0  # revert to the starting coefficients
                return tm, log
            state.solve_counter += 1
        tm_old = state.transport_map_list[state.solve_counter-1]
        for transport_map, solve_params in zip(
                state.transport_map_list[state.solve_counter:],
                state.solve_params_list[state.solve_counter:]):
            # Here we are assuming nested bases: copy the coefficients of
            # matching multi-indices from the old map into the new one.
            for c1, c2 in zip(tm_old.approx_list, transport_map.approx_list):
                # Constant part
                for i1, midx1 in enumerate(c1.c.multi_idxs):
                    for i2, midx2 in enumerate(c2.c.multi_idxs):
                        if midx1 == midx2:
                            break
                    c2.c.coeffs[i2] = c1.c.coeffs[i1]
                # Integrated part
                for i1, midx1 in enumerate(c1.h.multi_idxs):
                    for i2, midx2 in enumerate(c2.h.multi_idxs):
                        if midx1 == midx2:
                            break
                    c2.h.coeffs[i2] = c1.h.coeffs[i1]
            # Solve for the new map using the warm-started coefficients
            solve_params['x0'] = transport_map.coeffs
            tm, log = super(SequentialKullbackLeiblerBuilder, self).solve(
                transport_map,
                state.base_distribution,
                state.target_distribution,
                solve_params,
                mpi_pool=mpi_pool
            )
            if not log['success']:
                return tm_old, log
            tm_old = tm
            state.solve_counter += 1
        return tm, log
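# Example (illustrative sketch, not part of the library): warm-starting a
# sequence of maps of increasing complexity. The construction of the map list
# and the exact content of the solve parameters are assumptions -- check the
# installed version of TransportMaps for the available map constructors.
#
#   import TransportMaps as TM
#   import TransportMaps.Distributions as DIST
#
#   rho = DIST.StandardNormalDistribution(2)   # base \nu_\rho
#   pi  = ...                                  # target \nu_\pi (a Distribution)
#   tm_list = [ ... ]                          # ParametricTransportMaps with nested bases
#   params_list = [ {'qtype': 3, 'qparams': [10, 10]} for _ in tm_list ]
#
#   builder = SequentialKullbackLeiblerBuilder()
#   tm, log = builder.solve(tm_list, rho, pi, params_list,
#                           state=TM.DataStorageObject())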
class ToleranceSequentialKullbackLeiblerBuilder(KullbackLeiblerBuilder):
    r""" Solve over a list of maps, using each map to warm start the next one,
    until a target tolerance is met

    Given distributions :math:`\nu_\rho` and :math:`\nu_\pi`, and the list of
    parametric transport maps :math:`[T_1[{\bf a}_1],\ldots,T_n[{\bf a}_n]]`,
    provides the functionalities to solve the problems

    .. math::

       \arg\min_{{\bf a}_i}\mathcal{D}_{\rm KL}\left(
           T_i[{\bf a}_i]_\sharp\rho \Vert \pi\right)

    up to a chosen tolerance, where the numerical solution for map
    :math:`T_{i+1}` is started at :math:`T_i`.
    """
    def __init__(
            self,
            validator=None,
            tol=1e-2,
            laplace_pull=False,
            callback=None,
            callback_kwargs={},
            verbosity=0):
        r"""
        Args:
          validator (:class:`Validator<TransportMaps.Diagnostics.Validator>`):
            validator to be used to check stability of the solution
          tol (float): target variance diagnostic tolerance
          laplace_pull (bool): whether to center the initial conditions on the
            Laplace approximation of the target distribution
          callback (function): function taking a map and optional additional
            arguments, called whenever it is deemed necessary by the chosen
            algorithm (e.g. for storing purposes)
          callback_kwargs (dict): additional arguments to be provided to the
            function ``callback``
          verbosity (int): level of verbosity of the builder
        """
        self.solve_counter = 0
        self.tol = tol
        self.laplace_pull = laplace_pull
        super(ToleranceSequentialKullbackLeiblerBuilder, self).__init__(
            validator,
            callback=callback, callback_kwargs=callback_kwargs,
            verbosity=verbosity)

    @state_loader(
        keys = [
            'transport_map_list',
            'base_distribution',
            'target_distribution',
            'solve_params_list',
            'var_diag_params'
        ]
    )
    def solve(
            self,
            transport_map: List[ParametricTransportMap],
            base_distribution: Distribution,
            target_distribution: Distribution,
            solve_params: List[dict],
            var_diag_params: dict,
            state=None,
            mpi_pool=None,
    ):
        r"""
        Args:
          transport_map (:class:`list` of :class:`TransportMap<TransportMaps.Maps.TransportMap>`):
            transport maps :math:`T`
          base_distribution (:class:`Distribution<TransportMaps.Distributions.Distribution>`):
            distribution :math:`\nu_\rho`
          target_distribution (:class:`Distribution<TransportMaps.Distributions.Distribution>`):
            distribution :math:`\nu_\pi`
          solve_params (:class:`list` of :class:`dict`):
            list of dictionaries of parameters for solution
          var_diag_params (dict): parameters to be used in the variance
            diagnostic approximation
          state (:class:`TransportMaps.DataStorageObject`): if provided, it must
            contain all the information needed for reloading, or a handle to an
            empty storage object which can be externally stored.
            If ``state`` contains the keys corresponding to arguments to this
            function, they will be used instead of the inputs themselves.

        Returns:
          (:class:`TransportMaps.Maps.TransportMap`) -- the transport map fitted.
        """
        # Internal states
        state.solve_counter = getattr(state, 'solve_counter', 0)
        transport_map_list = transport_map
        solve_params_list = solve_params
        if len(state.transport_map_list) != len(state.solve_params_list):
            raise ValueError(
                "Inconsistent number of transport maps."
            )
        if state.solve_counter == 0:
            if state.var_diag_params is None:
                state.var_diag_params = {
                    'qtype': solve_params_list[-1]['qtype'],
                    'qparams': solve_params_list[-1]['qparams']}
            transport_map = state.transport_map_list[0]
            solve_params = state.solve_params_list[0]
            if self.laplace_pull:
                # First find Laplace point and center to it
                lap = laplace_approximation(state.target_distribution)
                lap_map = AffineTransportMap.build_from_Gaussian(lap)
                # Set initial conditions to Laplace approximation
                transport_map.regression(
                    lap_map, d=state.base_distribution,
                    qtype=3, qparams=[3]*state.base_distribution.dim,
                    regularization={'alpha': 1e-4, 'type': 'L2'})
                solve_params['x0'] = transport_map.coeffs
            tm, log = super(ToleranceSequentialKullbackLeiblerBuilder, self).solve(
                transport_map,
                state.base_distribution,
                state.target_distribution,
                solve_params,
                mpi_pool=mpi_pool
            )
            if not log['success']:
                x0 = solve_params.get('x0')
                if x0 is not None:
                    tm.coeffs = x0  # revert to the starting coefficients
                return tm, log
            pull_tar = PullBackTransportMapDistribution(
                tm, state.target_distribution)
            var = variance_approx_kl(
                state.base_distribution, pull_tar,
                **state.var_diag_params)
            self.logger.info("Variance diagnostic: %e" % var)
            if var <= self.tol:
                return tm, log
            state.solve_counter += 1
        tm_old = state.transport_map_list[state.solve_counter-1]
        for transport_map, solve_params in zip(
                state.transport_map_list[state.solve_counter:],
                state.solve_params_list[state.solve_counter:]):
            # Here we are assuming nested bases: copy the coefficients of
            # matching multi-indices from the old map into the new one.
            for c1, c2 in zip(tm_old.approx_list, transport_map.approx_list):
                # Constant part
                for i1, midx1 in enumerate(c1.c.multi_idxs):
                    for i2, midx2 in enumerate(c2.c.multi_idxs):
                        if midx1 == midx2:
                            break
                    c2.c.coeffs[i2] = c1.c.coeffs[i1]
                # Integrated part
                for i1, midx1 in enumerate(c1.h.multi_idxs):
                    for i2, midx2 in enumerate(c2.h.multi_idxs):
                        if midx1 == midx2:
                            break
                    c2.h.coeffs[i2] = c1.h.coeffs[i1]
            # Solve for the new map using the warm-started coefficients
            solve_params['x0'] = transport_map.coeffs
            tm, log = super(ToleranceSequentialKullbackLeiblerBuilder, self).solve(
                transport_map,
                state.base_distribution,
                state.target_distribution,
                solve_params,
                mpi_pool=mpi_pool
            )
            if not log['success']:
                return tm_old, log
            pull_tar = PullBackTransportMapDistribution(
                tm, state.target_distribution)
            var = variance_approx_kl(
                state.base_distribution, pull_tar,
                **state.var_diag_params)
            self.logger.info("Variance diagnostic: %e" % var)
            if var <= self.tol:
                return tm, log
            tm_old = tm
            state.solve_counter += 1
        # The variance tolerance was not met by any map in the list
        log['success'] = False
        log['msg'] = "Desired tolerance was not met by the map adaptivity. " + \
                     "Target variance: %e - Variance: %e " % (self.tol, var)
        return tm, log
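# Example (illustrative sketch, same assumptions as above): the tolerance
# variant stops as soon as the variance diagnostic drops below ``tol``, and can
# optionally center the initial map on a Laplace approximation of the target.
#
#   builder = ToleranceSequentialKullbackLeiblerBuilder(
#       tol=1e-3, laplace_pull=True)
#   tm, log = builder.solve(tm_list, rho, pi, params_list,
#                           var_diag_params={'qtype': 0, 'qparams': 5000},
#                           state=TM.DataStorageObject())
#   # log['success'] is False if even the last map missed the tolerance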
class FirstVariationKullbackLeiblerBuilder(KullbackLeiblerBuilder):
    r""" Adaptive builder based on the first variation of the KL divergence

    Given distributions :math:`\nu_\rho` and :math:`\nu_\pi`, and the
    parametric transport map :math:`T[{\bf a}]`, provides the
    functionalities to solve the problem

    .. math::

       \arg\min_{\bf a}\mathcal{D}_{\rm KL}\left(
           T[{\bf a}]_\sharp\rho \Vert \pi\right) =
       \arg\min_{\bf a}\underbrace{\mathbb{E}_\rho\left[
           -\log T[{\bf a}]^\sharp\pi \right]}_{\mathcal{J}[T]}

    up to a chosen tolerance, by enriching the map using information
    from the first variation

    .. math::

       \nabla\mathcal{J}[T]({\bf x}) =
           (\nabla_{\bf x}T)^{-\top}
           \nabla_{\bf x}\left(\log\frac{\rho({\bf x})}{T^\sharp\pi({\bf x})}\right)
    """
    def __init__(
            self,
            validator,
            eps_bull,
            regression_builder=L2RegressionBuilder(),
            line_search_params={},
            max_it=20,
            prune_trunc={'type': 'manual', 'val': None},
            avar_trunc={'type': 'manual', 'val': None},
            coeff_trunc={'type': 'manual', 'val': None},
            callback=None,
            callback_kwargs={},
            verbosity=0,
            interactive=False):
        r"""
        Args:
          validator (:class:`Validator<TransportMaps.Diagnostics.Validator>`):
            validator to be used to check stability of the solution
          eps_bull (float): target tolerance of variance diagnostic
          regression_builder (:class:`L2RegressionBuilder`): builder used for
            the regression steps of the adaptivity
          line_search_params (dict): parameters for the line search along
            the first variation
          max_it (int): maximum number of adaptivity iterations
          prune_trunc (dict): truncation criterion used when pruning the map
            after validation
          avar_trunc (dict): truncation criterion used to select new active
            variables from the first variation
          coeff_trunc (dict): truncation criterion used to remove unnecessary
            coefficients from the refined map
          callback (function): function taking a map and optional additional
            arguments, called whenever it is deemed necessary by the chosen
            algorithm (e.g. for storing purposes)
          callback_kwargs (dict): additional arguments to be provided to the
            function ``callback``
          verbosity (int): level of verbosity of the builder
          interactive (bool): whether to ask the user for permission to proceed
        """
        self.regression_builder = regression_builder
        self.eps_bull = eps_bull
        self.line_search_params = line_search_params
        self.max_it = max_it
        self.prune_trunc = prune_trunc
        self.avar_trunc = avar_trunc
        self.coeff_trunc = coeff_trunc
        super(FirstVariationKullbackLeiblerBuilder, self).__init__(
            validator,
            callback=callback, callback_kwargs=callback_kwargs,
            verbosity=verbosity, interactive=interactive)
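    # Illustrative check of the first-variation formula (not part of the
    # library; plain NumPy): in d = 1, with rho = N(0,1), pi = N(mu, sigma^2)
    # and a linear map T(x) = a + b*x, one has
    #   G(x) = (dT/dx)^{-1} * d/dx log( rho(x) / T^sharp pi(x) )
    #        = ( -x + b*(a + b*x - mu)/sigma**2 ) / b ,
    # which vanishes identically at the exact map (a, b) = (mu, sigma):
    #
    #   import numpy as np
    #   def first_variation_1d(a, b, mu, sigma, x):
    #       return (-x + b * (a + b * x - mu) / sigma**2) / b
    #   x = np.linspace(-3., 3., 7)
    #   first_variation_1d(1., 2., mu=1., sigma=2., x=x)   # ~ 0: exact map
    #   first_variation_1d(.5, 1.5, mu=1., sigma=2., x=x)  # nonzero: refine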
    def _validation(self, state, mpi_pool=None):
        spmet = False
        while not spmet and \
                not state.validation_log.get('validator_cost_exceeded', False) and \
                not state.validation_log.get('validator_fcast_cost_exceeded', False):
            self.logger.info("Validation...")
            _, state.validation_log = super(
                FirstVariationKullbackLeiblerBuilder, self).solve(
                    state.transport_map,
                    state.base_distribution,
                    state.target_distribution,
                    state.solve_params,
                    mpi_pool=mpi_pool
                )
            # Separate cache from log, so not to store it.
            cache = state.validation_log.pop('cache')
            if not state.validation_log.get('success', False) or \
                    state.validation_log.get('validator_cost_exceeded', False) or \
                    state.validation_log.get('validator_fcast_cost_exceeded', False):
                if not state.validation_log.get('success', True):
                    state.fv_adapt_status = 'Failed to converge'
                    self.logger.warning(
                        "KL-minimization failed to converge. " + \
                        "Reverting to the last available map.")
                else:
                    state.fv_adapt_status = 'Cost exceeded'
                    self.logger.warning(
                        "Maximum cost exceeded. Reverting to the last available map.")
                if len(state.transport_map_list) > 0:
                    state.transport_map = state.transport_map_list[-1]
                return False, cache
            state.transport_map_list.append( state.transport_map )
            state.qparams_list.append(
                {'n_samps': state.solve_params['x'].shape[0]} )
            state.target_ncalls_list.append(
                state.target_distribution.get_ncalls_tree() )
            state.target_nevals_list.append(
                state.target_distribution.get_nevals_tree() )
            if self.callback is not None:
                self.callback( state.transport_map, **self.callback_kwargs )
            state.validator_error_list.append(
                state.validation_log.get('validator_error', 0.) )
            spmet = state.validator_error_list[-1] < state.validation_log.get(
                'validator_target_error', np.inf)
            state.spmet_list.append(spmet)
            if not spmet:
                self.logger.info("Pruning...")
                # Prune
                tm_new, flag, prune_params = \
                    FirstVariationKullbackLeiblerBuilder._prune_map(
                        state.transport_map,
                        state.validation_log['validator_prune_params'],
                        self.prune_trunc,
                        method='active')
                state.prune_trunc_params_list.append( prune_params )
                self.logger.info(
                    "Map pruning. Map structure:\n" + \
                    map_structure_str(tm_new, indent=' ', verbosity=self.verbosity) + \
                    map_sparsity_str(tm_new, indent=' ', verbosity=self.verbosity))
                if tm_new.n_coeffs == state.transport_map.n_coeffs:
                    # The minimum number of coefficients has already been reached
                    state.fv_adapt_status = \
                        "The pruning of the map did not lead to the removal " + \
                        "of any degree of freedom."
                    self.logger.warning( state.fv_adapt_status )
                    state.spmet = False
                    state.tolmet = False
                    if len(state.transport_map_list) > 0:
                        state.transport_map = state.transport_map_list[-1]
                    return False, cache
                else:
                    state.transport_map = tm_new
                    state.solve_params['x0'] = state.transport_map.coeffs
        return True, cache
    def _diagnostic(self, state, mpi_pool=None):
        self.logger.info("Computing variance diagnostic...")
        pb_distribution = PullBackTransportMapDistribution(
            state.transport_map, state.target_distribution)
        var_diag = variance_approx_kl(
            state.base_distribution, pb_distribution,
            qtype=state.solve_params['qtype'],
            qparams=state.solve_params.get('qparams', {'eps_bull': self.eps_bull}),
            mpi_pool_tuple=(None, mpi_pool)
        )
        self.logger.info("Variance diagnostic: %.3e (target %.3e)" % (
            var_diag, self.eps_bull))
        state.variance_diagnostic_list.append( var_diag )
        if len(state.variance_diagnostic_list) > 1 and \
                var_diag > state.variance_diagnostic_list[-2]:
            state.fv_adapt_status = \
                "The variance diagnostic is not decreasing. This can be due to " + \
                "several reasons:\n" + \
                " 1) the number of quadrature points is insufficient and/or\n" + \
                " 2) the validation tolerance is too low with respect to the " + \
                "target adaptivity tolerance and/or\n" + \
                " 3) a pruning step with too strict tolerance has occurred"
            if self.interactive:
                self.logger.warning(state.fv_adapt_status)
                instr = None
                while instr not in ['c', 'q']:
                    instr = cmdinput(
                        "Please specify whether to (c)ontinue or to (q)uit: ")
                if instr == 'q':
                    self.logger.warning(
                        "The algorithm has been manually terminated.")
                    return False  # Terminate
            else:
                if var_diag > state.variance_diagnostic_list[-2] + \
                        state.validator_error_list[-1]:
                    state.fv_adapt_status += "\n" + \
                        "The algorithm is automatically terminating."
                    self.logger.warning(state.fv_adapt_status)
                    return False  # Terminate
                else:
                    keep_going_msg = "\n" + \
                        "Even though not decreasing, the variance diagnostic " + \
                        "is still within the validation error."
                    state.fv_adapt_status += keep_going_msg
                    self.logger.warning(keep_going_msg)
        state.tolmet = var_diag <= self.eps_bull
        return True
    def _refinement(self, state, cache=None, mpi_pool=None):
        self.logger.info("Computing first variation...")
        if state.solve_params['qtype'] == 4:
            self.logger.warning(
                "Using MC for first variation. " + \
                "We should be able to use other quadratures as well.")
            (x, w) = state.base_distribution.quadrature(0, 10000)
        else:
            try:
                x = state.solve_params['x']
                w = state.solve_params['w']
            except KeyError:
                (x, w) = state.base_distribution.quadrature(
                    state.solve_params['qtype'],
                    state.solve_params['qparams'])
        # Compute first variation (here we need to make use of the caching)
        pb_distribution = PullBackTransportMapDistribution(
            state.transport_map, state.target_distribution)
        gt = FirstVariationKullbackLeiblerBuilder._compute_first_variation(
            x, w, state.base_distribution, pb_distribution,
            cache=cache,
            batch_size=state.solve_params.get('batch_size'),
            mpi_pool=mpi_pool)
        # Generate candidate transport map for regression of first variation
        self.logger.info("Projection of first variation...")
        # Project first variation on linear map (to extract active variables)
        fv_tri_tm = AffineTriangularMap(dim=state.transport_map.dim)
        fv_tri_tm, log_reg1 = self.regression_builder.solve(
            fv_tri_tm, gt, x=x, w=w)
        abs_exp_gx_fv_tri_tm = np.abs(fv_tri_tm.L)
        # Construct first variation candidate map (and prune unnecessary variables)
        fv_tri_tm, flag, trunc_params = \
            FirstVariationKullbackLeiblerBuilder._first_variation_candidate_triangular_map(
                state.transport_map, abs_exp_gx_fv_tri_tm, self.avar_trunc)
        state.avars_trunc_params_list.append( trunc_params )
        if flag == 'quit':
            self.logger.info("Terminating.")
            state.fv_adapt_status = \
                "Simulation aborted by the user during " + \
                "the pruning of the first variation in the refinement step."
            return False
        # Apply regression again to learn only important coefficients
        fv_tri_tm, log_reg2 = self.regression_builder.solve(
            fv_tri_tm, gt, x=x, w=w)
        for log_entry in log_reg2:
            if not log_entry['success']:
                state.fv_adapt_status = \
                    "Some of the optimizations during the second regression " + \
                    "did not converge."
                self.logger.warning("Terminating: " + state.fv_adapt_status)
                return False
        # Line search and evaluation of improved map
        self.logger.info("Line search...")
        tm_ev = state.transport_map.evaluate(x)
        fv_tri_tm_ev = fv_tri_tm.evaluate(x)
        tm_pxd = state.transport_map.partial_xd(x)
        fv_tri_tm_pxd = fv_tri_tm.partial_xd(x)
        delta, ls_success = \
            FirstVariationKullbackLeiblerBuilder._kl_divergence_fv_line_search(
                state.target_distribution, w,
                tm_ev, fv_tri_tm_ev, tm_pxd, fv_tri_tm_pxd,
                self.line_search_params,
                mpi_pool=mpi_pool, interactive=self.interactive)
        if not ls_success:
            state.fv_adapt_status = \
                "Line search did not converge (delta: %e). " % delta + \
                "This may be due to several causes:\n" + \
                " 1) the maximum number of line search iterations is too low\n" + \
                " 2) the validation tolerance is too low to " + \
                "be able to detect improving directions\n" + \
                " 3) the validation tolerance is too low to detect that " + \
                "there is no other improving direction\n" + \
                " 4) the validation tolerance is too low w.r.t. the target " + \
                "adaptivity tolerance"
            self.logger.warning(state.fv_adapt_status)
            if self.interactive:
                instr = None
                while instr not in ['c', 'q']:
                    instr = cmdinput(
                        "Specify whether to " + \
                        "(c)ontinue with stricter validation tolerances, " + \
                        "or to (q)uit: ")
                if instr == 'q':
                    self.logger.warning(
                        "The algorithm has been manually terminated.")
                    return False
                elif instr == 'c':
                    flag = self.validator.update_tolerances()
                    return flag
            else:
                self.logger.warning("Terminating: " + state.fv_adapt_status)
                return False
        self.logger.info("Line search - delta: %e" % delta)
        # Generate candidate transport map for regression on improved map
        self.logger.info("Generating new candidate map...")
        tm_new = FirstVariationKullbackLeiblerBuilder._improved_candidate_map(
            state.transport_map, fv_tri_tm)
        x0 = tm_new.coeffs
        tm_new, log_reg3 = self.regression_builder.solve(
            tm_new, tm_ev - delta * fv_tri_tm_ev, x=x, w=w, x0=x0)
        for log_entry in log_reg3:
            if not log_entry['success']:
                state.fv_adapt_status = \
                    "Some of the optimizations during the third regression " + \
                    "step did not converge."
                self.logger.warning("Terminating. " + state.fv_adapt_status)
                return False
        # Remove unnecessary coefficients
        tm_new, flag, trunc_params = \
            FirstVariationKullbackLeiblerBuilder._prune_map(
                tm_new, np.abs(tm_new.coeffs),
                self.coeff_trunc, method='childless')
        if flag == 'quit':
            self.logger.info("Terminating.")
            state.fv_adapt_status = \
                "Simulation aborted by the user during the prune " + \
                "in the refinement step."
            return False
        if is_equal_map(state.transport_map, tm_new):
            state.fv_adapt_status = \
                "The refinement step did not change the original map. " + \
                "This may be due to several factors:\n" + \
                " 1) the truncation tolerances are too strict w.r.t. " + \
                "the validation tolerances\n" + \
                " 2) the validation tolerance is too low w.r.t. the target " + \
                "adaptivity tolerance"
            self.logger.warning("Terminating. " + state.fv_adapt_status)
            return False
        # Set the values of the new coefficients in the approximation to zero
        # and the values of the old coefficients to the ones obtained in the
        # latest optimization cycle.
        # If changing dimension we need to rescale the coefficients
        # with respect to the normalization constant.
        for comp, avars, avars_old in zip(
                tm_new.approx_list, tm_new.active_vars,
                state.transport_map.active_vars):
            sl = comp.c.semilattice
            factor = np.sqrt(
                np.prod([ comp.c.full_basis_list[var].Gamma(0)
                          for var in avars_old[:-1] ]) /
                np.prod([ comp.c.full_basis_list[var].Gamma(0)
                          for var in avars[:-1] ]) )
            for v in sl:
                if v.data['is_new']:
                    v.coeff = 0.
                else:
                    v.coeff = v.data['old_coeff'] / factor
                    del v.data['old_coeff']
                del v.data['is_new']
            sl = comp.h.semilattice
            factor = np.sqrt(
                np.prod([ comp.h.full_basis_list[var].Gamma(0)
                          for var in avars_old[:-1] ]) /
                np.prod([ comp.h.full_basis_list[var].Gamma(0)
                          for var in avars[:-1] ]) )
            for v in sl:
                if v.data['is_new']:
                    v.coeff = 0.
                else:
                    v.coeff = v.data['old_coeff'] / factor
                    del v.data['old_coeff']
                del v.data['is_new']
        self.logger.info(
            "Map refinement. Map structure:\n" + \
            map_structure_str(tm_new, indent=' ', verbosity=self.verbosity) + \
            map_sparsity_str(tm_new, indent=' ', verbosity=self.verbosity))
        # Set up as the new transport map and new initial conditions
        state.transport_map = tm_new
        state.solve_params['x0'] = state.transport_map.coeffs
        return True
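    # Sketch of the line-search step above (an assumption about what the
    # private helper `_kl_divergence_fv_line_search` does, based on how its
    # return value `delta` is consumed): starting from some delta0, the step
    # is backtracked until the perturbed evaluations improve the sample KL
    # objective,
    #
    #   delta = delta0
    #   while it < max_it:
    #       J_new = kl_objective(tm_ev - delta * fv_tri_tm_ev, w, ...)
    #       if J_new < J_old:
    #           break            # improving direction found
    #       delta *= 0.5         # backtrack
    #
    # The improved map is then recovered by regressing onto
    # tm_ev - delta * fv_tri_tm_ev (the "third regression" above).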
@state_loader( keys = [ 'transport_map', 'base_distribution', 'target_distribution', 'solve_params'] )
    def solve(
            self,
            transport_map=None,
            base_distribution=None,
            target_distribution=None,
            solve_params=None,
            state=None,
            mpi_pool=None
    ):
        r"""
        Args:
          transport_map (:class:`TransportMap<TransportMaps.Maps.TransportMap>`):
            starting transport map :math:`T`
          base_distribution (:class:`Distribution<TransportMaps.Distributions.Distribution>`):
            distribution :math:`\nu_\rho`
          target_distribution (:class:`Distribution<TransportMaps.Distributions.Distribution>`):
            distribution :math:`\nu_\pi`
          solve_params (dict): dictionary of parameters for the solution
          state (:class:`TransportMaps.DataStorageObject`): if provided, it must
            contain all the information needed for reloading, or a handle to an
            empty storage object which can be externally stored.
            If ``state`` contains the keys corresponding to arguments to this
            function, they will be used instead of the inputs themselves.

        Returns:
          (:class:`TransportMaps.Maps.TransportMap`) -- the transport map fitted.
        """
        # Internal solve states
        state.iter_counter = getattr(state, 'iter_counter', 0)
        state.adapt_stage = getattr(state, 'adapt_stage', 'validation')
        state.tolmet = getattr(state, 'tolmet', False)
        state.validation_log = getattr(state, 'validation_log', {})
        state.refinement_log = getattr(state, 'refinement_log', {})
        state.diagnostic_log = getattr(state, 'diagnostic_log', {})
        state.fv_adapt_status = getattr(state, 'fv_adapt_status', 'success')
        state.transport_map_list = getattr(state, 'transport_map_list', [])
        state.validator_error_list = getattr(state, 'validator_error_list', [])
        state.spmet_list = getattr(state, 'spmet_list', [])
        state.variance_diagnostic_list = getattr(state, 'variance_diagnostic_list', [])
        state.qparams_list = getattr(state, 'qparams_list', [])
        state.target_ncalls_list = getattr(state, 'target_ncalls_list', [])
        state.target_nevals_list = getattr(state, 'target_nevals_list', [])
        state.avars_trunc_params_list = getattr(state, 'avars_trunc_params_list', [])
        state.coeffs_trunc_params_list = getattr(state, 'coeffs_trunc_params_list', [])
        state.prune_trunc_params_list = getattr(state, 'prune_trunc_params_list', [])
        continue_flag = True  # Continue
        self.logger.info(
            "Starting. Map structure:\n" + \
            map_structure_str(
                state.transport_map, indent=' ', verbosity=self.verbosity) + \
            map_sparsity_str(
                state.transport_map, indent=' ', verbosity=self.verbosity))
        if state.solve_params.get('x0') is None:
            state.transport_map.coeffs = state.transport_map.get_identity_coeffs()
            state.solve_params['x0'] = state.transport_map.coeffs
        while state.iter_counter < self.max_it and not state.tolmet:
            self.logger.info("Iteration %d" % state.iter_counter)
            if state.adapt_stage == 'validation':
                continue_flag, cache = self._validation(state, mpi_pool)
                if not continue_flag:
                    break
                state.adapt_stage = 'diagnostic'
                if self.callback is not None:
                    self.callback( state.transport_map, **self.callback_kwargs )
            if state.adapt_stage == 'diagnostic':
                # Once the coefficients are determined we check whether the
                # variance diagnostic tolerance is met
                continue_flag = self._diagnostic(state, mpi_pool)
                if not continue_flag or state.tolmet:
                    break
                state.adapt_stage = 'refinement'
                if self.callback is not None:
                    self.callback( state.transport_map, **self.callback_kwargs )
            if state.adapt_stage == 'refinement':
                try:
                    cache
                except NameError:
                    cache = None
                if not state.tolmet:
                    # Refinement
                    continue_flag = self._refinement(state, cache, mpi_pool)
                    if not continue_flag:
                        break
                    state.adapt_stage = 'validation'
                del cache  # Free some memory...
            state.iter_counter += 1
            if self.callback is not None:
                self.callback( state.transport_map, **self.callback_kwargs )
        if state.iter_counter == self.max_it:
            state.fv_adapt_status = 'Maximum number of iterations exceeded'
        if self.callback is not None:
            self.callback( state.transport_map, **self.callback_kwargs )
        log = {'fv_adapt_status': state.fv_adapt_status,
               'fv_adapt_tolmet': state.tolmet,
               'fv_adapt_it': state.iter_counter}
        return state.transport_map, log
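    # The solve loop above is a resumable three-stage state machine,
    # checkpointed through `state.adapt_stage`:
    #
    #   validation -> diagnostic -> refinement
    #        ^                          |
    #        +--------------------------+
    #
    # It terminates when the variance diagnostic tolerance is met, when a
    # stage reports failure, or when `max_it` iterations are exceeded.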
@staticmethod
    def _compute_first_variation(
            x, w, d1, d2, cache=None, batch_size=None, mpi_pool=None):
        # Distribute objects
        d2_distr = pickle.loads( pickle.dumps(d2) )
        d2_distr.reset_counters()
        mpi_bcast_dmem(d1=d1, d2=d2_distr, mpi_pool=mpi_pool)
        # Link tm to d2.transport_map
        def link_tm_d2(d2):
            return (d2.transport_map,)
        (tm,) = mpi_map_alloc_dmem(
            link_tm_d2,
            dmem_key_in_list=['d2'],
            dmem_arg_in_list=['d2'],
            dmem_val_in_list=[d2],
            dmem_key_out_list=['tm'],
            mpi_pool=mpi_pool)
        # Prepare cache
        if mpi_pool:
            if cache is not None:
                cache2 = [ cc['pi_cache'] for cc in cache ]
            else:
                cache2 = [ None ] * mpi_pool.nprocs
            mpi_scatter_dmem(cache2=cache2, mpi_pool=mpi_pool)
        else:
            cache2 = cache['pi_cache'] if cache else None
        # Prepare batch size
        if batch_size is None:
            bsize = x.shape[0]
        else:
            bsize = batch_size[1]
        # # Split data
        # if mpi_pool is None:
        #     x_list = [x]
        # else:
        #     split_dict = mpi_pool.split_data([x],['x'])
        #     x_list = [sd['x'] for sd in split_dict]
        # max_len = x_list[0].shape[0]
        # # Compute the number of iterations necessary for batching
        # niter = max_len // batch_size + (1 if max_len % batch_size > 0 else 0)
        # # Iterate
        # idx0_list = [ 0 ] * len(x_list)
        # grad_t = np.zeros((x.shape[0], d2.dim))
        # for it in range(niter):
        #     # Prepare batch-slicing for each chunk
        #     idxs_slice_list = []
        #     for i, (xs, idx0) in enumerate(zip(x_list, idx0_list)):
        #         incr = min(batch_size, xs.shape[0] - idx0)
        #         idxs_slice_list.append( slice(idx0, idx0+incr, None) )
        #         idx0_list[i] += incr
        #     # Prepare input x
        #     x_in = [ xs[idxs_slice,:]
        #              for xs, idxs_slice in zip(x_list, idxs_slice_list) ]
        #     # Compute grad_x and store in distributed memory
        #     (grad_x_tm,) = mpi_map_alloc_dmem(
        #         'grad_x',
        #         scatter_tuple=scatter_tuple,
        #         dmem_key_out_list=['grad_x_tm'],
        #         obj='tm', obj_val=tm,
        #         mpi_pool=mpi_pool)
        # Compute grad_x and store in distributed memory
        scatter_tuple = (['x'], [x])
        (grad_x_tm,) = mpi_map_alloc_dmem(
            'grad_x',
            scatter_tuple=scatter_tuple,
            dmem_key_out_list=['grad_x_tm'],
            obj='tm', obj_val=tm,
            mpi_pool=mpi_pool)
        # Compute first variation
        grad_t = mpi_map(
            grad_t_kl_divergence,
            scatter_tuple=scatter_tuple,
            bcast_tuple=(['batch_size'], [bsize]),
            dmem_key_in_list=['d1', 'd2', 'grad_x_tm', 'cache2'],
            dmem_arg_in_list=['d1', 'd2', 'grad_x_tm', 'cache2'],
            dmem_val_in_list=[d1, d2, grad_x_tm, cache2],
            mpi_pool=mpi_pool, concatenate=True )
        # start = 0
        # grad_t = np.zeros((x.shape[0], d2.dim))
        # while start < x.shape[0]:
        #     stop = min(x.shape[0], start + bsize)
        #     # Compute grad_x and store in distributed memory
        #     scatter_tuple = (['x'], [x[start:stop,:]])
        #     (grad_x_tm,) = mpi_map_alloc_dmem(
        #         'grad_x',
        #         scatter_tuple=scatter_tuple,
        #         dmem_key_out_list=['grad_x_tm'],
        #         obj='tm', obj_val=tm,
        #         mpi_pool=mpi_pool)
        #     # Compute first variation
        #     (grad_t_batch,) = mpi_map(
        #         grad_t_kl_divergence,
        #         scatter_tuple=scatter_tuple,
        #         dmem_key_in_list=['d1', 'd2', 'grad_x_tm', 'cache2'],
        #         dmem_arg_in_list=['d1', 'd2', 'grad_x_tm', 'cache2'],
        #         dmem_val_in_list=[d1, d2, grad_x_tm, cache2],
        #         mpi_pool=mpi_pool, concatenate=True )
        #     grad_t[start:stop,:] = grad_t_batch
        #     start = stop
        # Update counters
        if mpi_pool is not None:
            d2_child_list = mpi_pool.get_dmem('d2')
            d2.update_ncalls_tree( d2_child_list[0][0] )
            for (d2_child,) in d2_child_list:
                d2.update_nevals_tree(d2_child)
                d2.update_teval_tree(d2_child)
        # Clear mpi_pool
        if mpi_pool is not None:
            mpi_pool.clear_dmem()
        return grad_t
    @staticmethod
    def _first_variation_candidate_triangular_map(
            transport_map, sensitivities, avar_trunc):
        r""" Construct the candidate map to be used in the regression of the
        first variation.

        It takes the multi-indices in ``transport_map`` and increases them by
        one, adding also active variables if needed. The active variables to
        add are detected using the information contained in ``sensitivities``.
        """
        active_vars = []
        sl_list = []
        scsens = sensitivities / np.max(sensitivities)
        if isinstance(avar_trunc, dict):
            if avar_trunc['type'] == 'manual':
                if not PLOT_SUPPORT:
                    raise ImportError(
                        "The 'manual' truncation type requires plotting, but " + \
                        "plotting is not supported on this machine. " + \
                        "The program will be terminated.")
                import matplotlib.pyplot as plt
                plt.figure()
                for d in range(transport_map.dim):
                    plt.semilogy(range(d+1), scsens[d,:d+1], 'o-')
                # plt.imshow(np.log10(scsens))
                # plt.colorbar()
                plt.grid(True)
                plt.show(False)
                trunc_val = None
                while trunc_val is None:
                    try:
                        istr = cmdinput(
                            "Provide an active variable truncation level [q to quit]: ")
                        trunc_val = float( istr )
                    except ValueError:
                        if istr == 'q':
                            return transport_map, 'quit', None
                        print("The value entered cannot be cast to a floating point number.")
            elif avar_trunc['type'] == 'constant':
                trunc_val = avar_trunc['val']
        else:
            trunc_val = avar_trunc
        log_params = {
            'avar_trunc': copy.deepcopy(avar_trunc),
            'trunc_val': trunc_val
        }
        for d, (tm_comp, tm_comp_avars) in enumerate(zip(
                transport_map.approx_list, transport_map.active_vars)):
            # Update active variables using second order information
            fv_avars = list(np.where(scsens[d,:d] > trunc_val)[0])
            add_avars = [ var for var in fv_avars if not var in tm_comp_avars ]
            avars = tm_comp_avars + add_avars
            perm_avars = argsort( avars )
            avars = [ avars[i] for i in perm_avars ]
            # Extract semilattices from constant and integrated squared part
            c_sl = tm_comp.c.semilattice
            h_sl = tm_comp.h.semilattice
            # Expand integrated squared semilattice by doubling the orders
            # and adding one to the trailing dimension
            h_sl_2 = h_sl * h_sl  # double the orders
            h_dim = h_sl.dims
            lst = [ v for v in SL.BreadthFirstSemilatticeIterable( h_sl_2 ) ]
            for v in lst:
                # add one to the last dimension
                if h_dim-1 not in v.children:
                    h_sl_2.new_vertex( parent=v, edge=h_dim-1 )
            # Merge constant and integrated squared semilattices
            sl = c_sl | h_sl_2
            if len(add_avars) > 0:
                # Expand dimension to match new active variables
                sl.modify_dims(add_dims=len(add_avars))
            # Permute semilattice according to the order of the active variables
            sl = SL.permute(sl, perm_avars)
            # Expand along admissible frontier
            lst = [ v for v in sl.admissible_frontier ]
            for v in lst:
                edges = sl.potential_children_edges( v )
                for edge in edges:
                    sl.new_vertex(parent=v, edge=edge)
            # Update
            active_vars.append(avars)
            sl_list.append(sl)
        fv_approx = assemble_LinearSpanTriangularMap(
            transport_map.dim, sl_list, active_vars)
        return fv_approx, 'success', log_params
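    # Worked example of the integrated-squared expansion above (written with
    # plain multi-indices rather than the semilattices API): for a component
    # with h-multi-indices {(0,0), (1,0)} in 2 variables,
    #
    #   h_sl * h_sl (doubling the orders):  {(0,0), (1,0), (2,0)}
    #   +1 in the trailing dimension:       adds (0,1), (1,1), (2,1)
    #
    # so the candidate starts from {(0,0), (1,0), (2,0), (0,1), (1,1), (2,1)}
    # before merging with the constant part and expanding the admissible frontier.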
@staticmethod
    def _improved_candidate_map(transport_map, fv_map):
        active_vars = []
        approx_list = []
        for d, (tm_avars, tm_comp, fv_avars, fv_comp) in enumerate(zip(
                transport_map.active_vars, transport_map.approx_list,
                fv_map.active_vars, fv_map.approx_list)):
            tm_full_c_blist = tm_comp.c.full_basis_list
            tm_full_h_blist = tm_comp.h.full_basis_list
            # Retrieve semilattices
            tmc_sl = tm_comp.c.semilattice.copy()
            tmh_sl = tm_comp.h.semilattice.copy()
            # Find missing active variables (tm_avars is a subset of fv_avars)
            add_avars = []
            j = 0
            for var in fv_avars:
                if var != tm_avars[j]:
                    add_avars.append( var )
                else:
                    j += 1
            # Add the active variables and find appropriate permutation
            avars = tm_avars + add_avars
            perm_avars = argsort( avars )
            avars = [ avars[i] for i in perm_avars ]
            # Expand dimensions (append) to match new active variables
            tmc_sl.modify_dims(add_dims=len(add_avars))
            tmh_sl.modify_dims(add_dims=len(add_avars))
            # Permute semilattices according to the order of the active variables
            tmc_sl = SL.permute(tmc_sl, perm_avars)
            tmh_sl = SL.permute(tmh_sl, perm_avars)
            # Mark all old vertices as old vertices and store coefficient values
            for v in tmc_sl:
                v.data['is_new'] = False
                v.data['old_coeff'] = v.coeff
            for v in tmh_sl:
                v.data['is_new'] = False
                v.data['old_coeff'] = v.coeff
            # Expand along admissible frontier (constant part)
            # and flag new vertices (using the data dictionary in vertices)
            lst = [ v for v in tmc_sl.admissible_frontier ]
            for v in lst:
                edges = tmc_sl.potential_children_edges( v )
                for edge in edges:
                    if edge != tmc_sl.dims - 1:
                        # Not allowed to increase in the last dimension
                        # for the constant part
                        new_v = tmc_sl.new_vertex( parent=v, edge=edge )
                        new_v.data['is_new'] = True
            # Expand along admissible frontier (integrated squared part)
            # and flag new vertices (using the data dictionary in vertices)
            lst = [ v for v in tmh_sl.admissible_frontier ]
            for v in lst:
                edges = tmh_sl.potential_children_edges( v )
                for edge in edges:
                    new_v = tmh_sl.new_vertex( parent=v, edge=edge )
                    new_v.data['is_new'] = True
            # Build basis (using the full basis set provided by the transport map)
            c_basis = [ tm_full_c_blist[a] for a in avars ]
            h_basis = [ tm_full_h_blist[a] for a in avars ]
            # Build constant and integrated linear span functions
            c = LinearSpanTensorizedParametricFunctional(
                c_basis, semilattice=tmc_sl, full_basis_list=tm_full_c_blist)
            h = LinearSpanTensorizedParametricFunctional(
                h_basis, semilattice=tmh_sl, full_basis_list=tm_full_h_blist)
            # Assemble component
            comp = type(tm_comp)(c, h)
            # Append to list of components and active variables
            approx_list.append(comp)
            active_vars.append(avars)
        new_map = type(transport_map)(
            active_vars=active_vars,
            approx_list=approx_list,
            full_c_basis_list=transport_map.full_c_basis_list,
            full_h_basis_list=transport_map.full_h_basis_list)
        return new_map
@staticmethod
    def _prune_map(
            tm, coeffs_weights, coeff_trunc, method='active'):
        r""" Prune the degrees of freedom of the map ``tm``.

        With the option ``method=='active'`` all the active vertices will be
        considered for removal.
        With the option ``method=='childless'`` only the active childless
        vertices of the semilattices will be considered for removal.
        In both cases the roots will never be removed/inactivated.

        Args:
          coeffs_weights (list): weights for each degree of freedom.
            Must be ``len(coeffs_weights)==tm.n_coeffs``.
            Degrees of freedom with lower weights will be removed.
          coeff_trunc (float or dict): truncation specification. A float is
            interpreted as ``{'type': 'value', 'val': coeff_trunc}``.
          method (str): method for pruning. Available options are
            ``childless``, ``active``.
        """
        tm = copy.deepcopy(tm)
        # Identify coefficients that are allowed to be removed
        removable_flag_list = []  # Booleans marking whether a dof may be removed
        for comp in tm.approx_list:
            # Constant part
            sl = comp.c.semilattice
            for v in sl.dof:
                if v is sl.root:
                    removable_flag_list.append( False )
                else:
                    if method == 'childless':
                        removable_flag_list.append( v in sl.childless )
                    elif method == 'active':
                        removable_flag_list.append( True )
                    else:
                        raise ValueError("Unrecognized pruning method")
            # Integrated squared part
            sl = comp.h.semilattice
            for v in sl.dof:
                if v is sl.root:
                    removable_flag_list.append( False )
                else:
                    if method == 'childless':
                        removable_flag_list.append( v in sl.childless )
                    elif method == 'active':
                        removable_flag_list.append( True )
                    else:
                        raise ValueError("Unrecognized pruning method")
        # Identify the truncation value
        rem_coeffs_weights = sorted(
            [ cw for flag, cw in zip(removable_flag_list, coeffs_weights) if flag ] )
        if isinstance(coeff_trunc, float):
            coeff_trunc = {'type': 'value', 'val': coeff_trunc}
        elif not isinstance(coeff_trunc, dict):
            raise ValueError(
                "coeff_trunc must be either a float or a dictionary.")
        if coeff_trunc['type'] == 'constant':
            idx = coeff_trunc['val'] if coeff_trunc['val'] < len(rem_coeffs_weights) else -1
            trunc_val = rem_coeffs_weights[idx]
        elif coeff_trunc['type'] == 'percentage':
            idx = int( np.ceil( len(rem_coeffs_weights) * coeff_trunc['val'] ) )
            trunc_val = rem_coeffs_weights[idx]
        elif coeff_trunc['type'] == 'manual':
            if not PLOT_SUPPORT:
                raise ImportError(
                    "The 'manual' truncation type requires plotting, but " + \
                    "plotting is not supported on this machine. " + \
                    "The program will be terminated.")
            import matplotlib.pyplot as plt
            plt.figure()
            plt.semilogy( rem_coeffs_weights[::-1], 'o-' )
            plt.grid(which='major', linewidth=1.)
            plt.grid(which='minor', linewidth=.3)
            plt.show(False)
            trunc_val = None
            while trunc_val is None:
                try:
                    istr = cmdinput("Provide a truncation level [q to quit]: ")
                    trunc_val = float( istr )
                except ValueError:
                    if istr == 'q':
                        return tm, 'quit', None
                    print("The value entered cannot be cast to a floating point number.")
        elif coeff_trunc['type'] == 'value':
            trunc_val = coeff_trunc['val']
        else:
            raise ValueError(
                "Unrecognized truncation type. Available options are: " + \
                "constant, percentage, manual, value.")
        log_params = {
            'coeff_trunc': copy.deepcopy(coeff_trunc),
            'trunc_val': trunc_val
        }
        # Run through the degrees of freedom and inactivate them
        j = 0
        for comp in tm.approx_list:
            # Constant part
            inactive_list = []
            sl = comp.c.semilattice
            for v in sl.dof:
                removable = removable_flag_list[j]
                if removable and abs(v.coeff) < trunc_val:
                    inactive_list.append( v )
                j += 1
            for v in inactive_list:
                sl.set_inactive( v )
            # Integrated squared part
            inactive_list = []
            sl = comp.h.semilattice
            for v in sl.dof:
                removable = removable_flag_list[j]
                if removable and abs(v.coeff) < trunc_val:
                    inactive_list.append( v )
                j += 1
            for v in inactive_list:
                sl.set_inactive( v )
        # Clean up the semilattices:
        # run through the childless vertices and, if inactive, remove them.
        nrem = 1
        # Keep removing until nothing can be removed (i.e. every childless is active)
        while nrem > 0:
            nrem = 0
            for comp in tm.approx_list:
                # Constant part
                rm_lst = []
                sl = comp.c.semilattice
                for v in sl.childless:
                    if v not in sl.dof:
                        rm_lst.append( v )
                for v in rm_lst:
                    sl.delete_vertex( v )
                nrem += len(rm_lst)
                # Integrated squared part
                rm_lst = []
                sl = comp.h.semilattice
                for v in sl.childless:
                    if v not in sl.dof:
                        rm_lst.append( v )
                for v in rm_lst:
                    sl.delete_vertex( v )
                nrem += len(rm_lst)
        # Remove dimensions that have been inactivated by the pruning
        for icomp, comp in enumerate(tm.approx_list):
            avars_old = tm.active_vars[icomp]
            c = comp.c
            h = comp.h.h
            # Figure out which active variables are not active anymore.
            # One just needs to look at the children that the root has left.
            avars_idxs = set([ comp.dim_in-1 ])
            sl = c.semilattice   # Constant part
            avars_idxs |= sl.root.children.keys()
            sl = h.semilattice   # Integrated squared part
            avars_idxs |= sl.root.children.keys()
            avars_idxs_compl = set(
                [ idx for idx in range(comp.dim_in)
                  if idx not in avars_idxs ] )  # Inactive variables indices
            avars_new = sorted([ avars_old[i] for i in avars_idxs ])
            # Re-sort the input dimensions of the semilattices to have
            # first the active indices (sorted) and then the inactive ones.
            # Then we can just trim the dimension of the semilattice
            # as all the trailing dimensions are inactive.
            perm_idxs = sorted(list(avars_idxs)) + list(avars_idxs_compl)
            c.semilattice = SL.permute(c.semilattice, perm_idxs)
            h.semilattice = SL.permute(h.semilattice, perm_idxs)
            c.semilattice.modify_dims(subtract_dims=len(avars_idxs_compl))
            h.semilattice.modify_dims(subtract_dims=len(avars_idxs_compl))
            # Modify basis list for the tensorized parametric functionals
            c.basis_list = [ c.basis_list[i] for i in sorted(list(avars_idxs)) ]
            h.basis_list = [ h.basis_list[i] for i in sorted(list(avars_idxs)) ]
            # Modify dimensions
            c.dim_in = len(avars_idxs)
            h.dim_in = len(avars_idxs)
            comp.dim_in = len(avars_idxs)
            # Modify active variables for the component
            tm.active_vars[icomp] = avars_new
        # Return
        return tm, 'success', log_params
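# Example truncation specifications accepted by
# FirstVariationKullbackLeiblerBuilder._prune_map (cf. the branches above):
#
#   coeff_trunc = 1e-3                               # same as {'type': 'value', 'val': 1e-3}
#   coeff_trunc = {'type': 'value', 'val': 1e-3}     # drop removable dofs with weight < 1e-3
#   coeff_trunc = {'type': 'percentage', 'val': 0.1} # drop roughly the 10% smallest
#   coeff_trunc = {'type': 'constant', 'val': 5}     # drop roughly the 5 smallest
#   coeff_trunc = {'type': 'manual', 'val': None}    # pick the threshold interactively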
def map_structure_str(tm, indent, verbosity=0):
    out = indent + "Number of degrees of freedom: %d\n" % tm.n_coeffs
    cmp_str = "Component %%%dd" % len(str(tm.dim))
    if verbosity > 0:
        for d, (comp, avars) in enumerate(zip(tm.approx_list, tm.active_vars)):
            out += indent + cmp_str % d + " #D.o.F: %d - " % comp.n_coeffs + \
                   "Active variables: %s\n" % str(avars)
            out += indent + "  Const part - midxs: %s\n" % (comp.c.multi_idxs)
            if verbosity > 1:
                out += indent + "  Const part - coeffs: %s\n" % (
                    [ "%.2e" % c for c in comp.c.coeffs ])
            out += indent + "  Integ part - midxs: %s\n" % (comp.h.multi_idxs)
            if verbosity > 1:
                out += indent + "  Integ part - coeffs: %s\n" % (
                    [ "%.2e" % c for c in comp.h.coeffs ])
    return out

def map_sparsity_str(tm, indent, verbosity=0):
    navar = sum([ len(avars) for avars in tm.active_vars ])
    totvar = (tm.dim+1) * tm.dim / 2
    out = indent + "Map sparsity: %d/%d (%.4f%%)\n" % (
        navar, totvar, float(navar)/float(totvar)*100)
    if verbosity == 1:
        for d, (comp, avars) in enumerate(zip(tm.approx_list, tm.active_vars)):
            out += indent + \
                   ' comp %d - avars: %s ' % (d, str(avars)) + \
                   '- maxord: %d' % (max(
                       np.max(comp.c.multi_idxs),
                       np.max(comp.h.multi_idxs))) + '\n'
    if verbosity > 1:
        for d, avars in enumerate(tm.active_vars):
            str_list = [' '] * tm.dim
            for var in avars:
                str_list[var] = 'x'
            out += indent + ' |' + ''.join(str_list) + '|' + '\n'
    return out

def is_equal_map(tm1, tm2):
    # TODO: update with semilattices
    match = True
    for d, (c1, a1, c2, a2) in enumerate(zip(
            tm1.approx_list, tm1.active_vars,
            tm2.approx_list, tm2.active_vars)):
        match = set(a1) == set(a2)
        if not match:
            break
        match = c1.c.semilattice == c2.c.semilattice
        if not match:
            break
        match = c1.h.semilattice == c2.h.semilattice
        if not match:
            break
    return match