# Source code for TransportMaps.CLI.DeepLazyMapsConstructionScriptBase

# This file is part of TransportMaps.
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps.  If not, see <http://www.gnu.org/licenses/>.
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
# Author: Transport Map Team
# Website:
# Support:

import sys

from .ConstructionScriptBase import ConstructionScript

import TransportMaps as TM
import TransportMaps.Algorithms.DeepLazyMaps as ALGDEEP

__all__ = ['DeepLazyMapsConstructionScript']

class DeepLazyMapsConstructionScript( ConstructionScript ):
    """Command line construction script for deep lazy transport maps.

    Adds the ``--lazy-*`` options on top of :class:`ConstructionScript`.
    Their values are stored in ``self.stg``, except ``--lazy-plot``, which is
    kept on the script instance as ``self.LAZY_PLOT``. The script then builds
    a ``DeepLazyMapsAssembler`` that composes lazy (low-rank) maps to push
    the standard normal base distribution forward to the target distribution.
    """

    cmd_usage_str = "Usage: tmap-deep-lazy-tm "

    opts_usage_str = \
        ConstructionScript.opts_usage_str + \
        " [--lazy-eps=FLOAT --lazy-maxit=INT \n" + \
        " --lazy-random-rotations --lazy-random-rotations-step=INT \n" + \
        " --lazy-rank-max=INT --lazy-rank-eps=FLOAT \n" + \
        " --lazy-rank-qtype=INT --lazy-rank-qnum=INT,.. \n" + \
        " --lazy-hard-truncation \n" + \
        " --lazy-ht-qtype=INT --lazy-ht-qnum=INT,.. \n" + \
        " --lazy-var-diag-qtype=INT --lazy-var-diag-qnum=INT,.. \n" + \
        " --lazy-plot] \n"

    docs_descr_str = """DESCRIPTION
Given a file (--input) storing the target distribution, produce the
transport map that pushes forward the base distribution (standard normal)
to the target distribution, using the composition (deep) of lazy (low-rank)
maps. All files involved are stored and loaded using the python package
pickle."""

    # NOTE: "pi^\\star" is escaped on purpose; a bare "\s" is an invalid
    # escape sequence (SyntaxWarning on Python >= 3.12). The emitted text is
    # still "pi^\star".
    docs_options_str = \
        ConstructionScript.docs_options_str + \
        """
OPTIONS -- deep-lazy construction:
  --lazy-eps=FLOAT            target tolerance
  --lazy-maxit=INT            maximum number of iterations of the algorithm
  --lazy-random-rotations     apply random rotations instead of target
                              informed ones
  --lazy-random-rotations-step=INT
                              number of greedy steps between each random
                              rotation (default: 1)
  --lazy-rank-max=INT         maximum rank allowed for the lazy maps
  --lazy-rank-eps=FLOAT       cumulative power of the ignored sub-space,
                              or 'manual' for manual selection at each
                              iteration
  --lazy-rank-qtype=INT       quadrature type for computing the low-rank
                              sub-space
  --lazy-rank-qnum=INT,..     quadrature parameters for computing the
                              low-rank subspace
  --lazy-hard-truncation      whether to use the pi^\\star formulation of
                              the algorithm
  --lazy-ht-qtype=INT         quadrature type to use in the evaluation of
                              the conditional expectation
  --lazy-ht-qnum=INT,..       quadrature parameters in the conditional
                              expectation
  --lazy-var-diag-qtype=INT   quadrature type to estimate convergence
  --lazy-var-diag-qnum=INT,.. quadrature parameters to estimate convergence
  --lazy-plot                 whether to plot progress
"""

    @property
    def long_options(self):
        """Long option specifications accepted by this script.

        Extends the parent list. A trailing ``=`` marks the options that take
        a value (those whose argument is read in :meth:`_load_opts`).
        """
        # The parent's ``long_options`` is accessed as an attribute, so this
        # override must be a property as well for ``self.long_options`` to
        # keep returning a list.
        return super(DeepLazyMapsConstructionScript, self).long_options + [
            'lazy-eps=', 'lazy-maxit=',
            'lazy-random-rotations', 'lazy-random-rotations-step=',
            'lazy-rank-max=', 'lazy-rank-eps=',
            'lazy-rank-qtype=', 'lazy-rank-qnum=',
            'lazy-hard-truncation',
            'lazy-ht-qtype=', 'lazy-ht-qnum=',
            'lazy-var-diag-qtype=', 'lazy-var-diag-qnum=',
            'lazy-plot',
        ]

    def _load_opts(self, opts):
        """Parse the ``--lazy-*`` options into the script settings.

        Args:
          opts: iterable of ``(option, argument)`` pairs. The parent
            implementation processes them first; unrelated options are
            ignored here.

        Calls ``sys.exit(3)``, after printing usage and an error, if
        ``--lazy-rank-eps`` is neither a float nor ``'manual'``.
        """
        super(DeepLazyMapsConstructionScript, self)._load_opts( opts )
        for opt, arg in opts:
            if opt == '--lazy-eps':
                self.stg.LAZY_EPS = float(arg)
            elif opt == '--lazy-maxit':
                self.stg.LAZY_MAXIT = int(arg)
            elif opt == '--lazy-random-rotations':
                self.stg.LAZY_RANDOM_ROTATIONS = True
            elif opt == '--lazy-random-rotations-step':
                self.stg.LAZY_RANDOM_ROTATIONS_STEP = int(arg)
            elif opt == '--lazy-rank-max':
                self.stg.LAZY_RANK_MAX = int(arg)
            elif opt == '--lazy-rank-eps':
                try:
                    self.stg.LAZY_RANK_EPS = float(arg)
                except ValueError:
                    # The only accepted non-numeric value is 'manual'.
                    if arg != 'manual':
                        self.usage()
                        self.tstamp_print(
                            "ERROR: Unrecognized option for --lazy-rank-eps")
                        sys.exit(3)
                    self.stg.LAZY_RANK_EPS = arg
            elif opt == '--lazy-rank-qtype':
                self.stg.LAZY_RANK_QTYPE = int(arg)
            elif opt == '--lazy-rank-qnum':
                self.stg.LAZY_RANK_QNUM = [int(q) for q in arg.split(',')]
            elif opt == '--lazy-hard-truncation':
                self.stg.LAZY_HARD_TRUNCATION = True
            elif opt == '--lazy-ht-qtype':
                self.stg.LAZY_HT_QTYPE = int(arg)
            elif opt == '--lazy-ht-qnum':
                self.stg.LAZY_HT_QNUM = [int(q) for q in arg.split(',')]
            elif opt == '--lazy-var-diag-qtype':
                self.stg.LAZY_VAR_DIAG_QTYPE = int(arg)
            elif opt == '--lazy-var-diag-qnum':
                self.stg.LAZY_VAR_DIAG_QNUM = [int(q) for q in arg.split(',')]
            elif opt == '--lazy-plot':
                # Plotting is a per-run choice, kept on the script instance
                # rather than in ``self.stg``.
                self.LAZY_PLOT = True

    def _init_self_variables(self):
        """Set the default values of the deep-lazy settings."""
        super(DeepLazyMapsConstructionScript, self)._init_self_variables()
        self.stg.LAZY_EPS = 1e-2
        self.stg.LAZY_MAXIT = 20
        self.stg.LAZY_RANDOM_ROTATIONS = False
        self.stg.LAZY_RANDOM_ROTATIONS_STEP = 1
        self.stg.LAZY_RANK_MAX = 3
        self.stg.LAZY_RANK_EPS = 1e-2
        self.stg.LAZY_RANK_QTYPE = 0
        self.stg.LAZY_RANK_QNUM = [20]
        self.stg.LAZY_HARD_TRUNCATION = False
        self.stg.LAZY_HT_QTYPE = 0
        self.stg.LAZY_HT_QNUM = [100]
        self.stg.LAZY_VAR_DIAG_QTYPE = 0
        self.stg.LAZY_VAR_DIAG_QNUM = [100]
        self.LAZY_PLOT = False

    def _check_required_args(self):
        """Normalize the quadrature parameters after option parsing.

        For quadrature types below 3, only the first comma-separated value of
        the corresponding ``*_QNUM`` list is kept, and it is stored as a
        scalar. Otherwise the full list is left in place.
        """
        super(DeepLazyMapsConstructionScript, self)._check_required_args()
        if self.stg.LAZY_RANK_QTYPE < 3:
            self.stg.LAZY_RANK_QNUM = self.stg.LAZY_RANK_QNUM[0]
        if self.stg.LAZY_HT_QTYPE < 3:
            self.stg.LAZY_HT_QNUM = self.stg.LAZY_HT_QNUM[0]
        if self.stg.LAZY_VAR_DIAG_QTYPE < 3:
            self.stg.LAZY_VAR_DIAG_QNUM = self.stg.LAZY_VAR_DIAG_QNUM[0]

    def load(self):
        """Load the base state and set up the deep lazy maps assembler.

        On a reload, the stored assembler is reused and its callback is
        re-attached to this script. Otherwise a fresh assembler state and a
        new ``DeepLazyMapsAssembler`` are built from the current settings.

        Raises:
          ValueError: if hard truncation is disabled and
            ``self.stg.tm_factory`` is not a ``DeepLazyMapFactory``.
        """
        super(DeepLazyMapsConstructionScript, self).load()
        if not self.stg.LAZY_HARD_TRUNCATION and \
           not isinstance(self.stg.tm_factory, ALGDEEP.DeepLazyMapFactory):
            raise ValueError(
                "The provided map factory must be a subclass of DeepLazyMapFactory"
            )
        if self.RELOAD:
            # Point the stored assembler's callback at this script instance
            # (presumably the stored one refers to a previous run's script).
            self.stg.assembler.callback = self.safe_store
            self.stg.assembler.callback_kwargs = {}
        else:
            self.stg.assembler_state = TM.DataStorageObject()
            # Build the assembler
            self.stg.assembler = ALGDEEP.DeepLazyMapsAssembler(
                builder = self.stg.builder,
                map_factory = self.stg.tm_factory,
                eps = self.stg.LAZY_EPS,
                maxit = self.stg.LAZY_MAXIT,
                random_rotations = self.stg.LAZY_RANDOM_ROTATIONS,
                random_rotations_step = self.stg.LAZY_RANDOM_ROTATIONS_STEP,
                rank_max = self.stg.LAZY_RANK_MAX,
                rank_eps = self.stg.LAZY_RANK_EPS,
                rank_qtype = self.stg.LAZY_RANK_QTYPE,
                rank_qparams = self.stg.LAZY_RANK_QNUM,
                hard_truncation = self.stg.LAZY_HARD_TRUNCATION,
                ht_qtype = self.stg.LAZY_HT_QTYPE,
                ht_qparams = self.stg.LAZY_HT_QNUM,
                var_diag_qtype = self.stg.LAZY_VAR_DIAG_QTYPE,
                var_diag_qparams = self.stg.LAZY_VAR_DIAG_QNUM,
                callback = self.safe_store,
                callback_kwargs = {}
            )

    def _solve(self, mpi_pool=None):
        """Run (or resume) the deep lazy maps assembly.

        Args:
          mpi_pool: forwarded unchanged to the assembler's ``assemble``.

        Returns:
          The value returned by ``self.stg.assembler.assemble``.
        """
        if not self.RELOAD:
            # Assemble function kwargs. They are kept in ``self.stg`` so that
            # a reloaded run reuses them instead of rebuilding them.
            self.stg.assembler_assemble_kwargs = {
                'target_distribution': self.stg.preconditioned_target_distribution,
                'builder_solve_params': self.stg.solve_params
            }
        return self.stg.assembler.assemble(
            state = self.stg.assembler_state,
            mpi_pool = mpi_pool,
            plotting = self.LAZY_PLOT,
            **self.stg.assembler_assemble_kwargs
        )