#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps. If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Authors: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#
import numpy as np
from ..Misc import \
required_kwargs, \
counted, cached, cached_tuple, get_sub_cache
from .MapBase import Map
# Names exported by ``from <module> import *``.
__all__ = ['ListStackedMap']

# Short alias for ``np.newaxis`` (which is ``None``).
nax = np.newaxis
class ListStackedMap(Map):
    r""" Defines the map :math:`T` obtained by stacking :math:`T_1, T_2, \ldots`.

    Each component map :math:`T_i` acts on the subset of input coordinates
    listed in ``active_vars[i]`` (denoted :math:`{\bf a}_i` below). The
    outputs of the component maps are concatenated in the order of
    ``map_list``:

    .. math::

       T({\bf x}) = \left[
       \begin{array}{c}
       T_1({\bf x}_{{\bf a}_1}) \\
       T_2({\bf x}_{{\bf a}_2}) \\
       \vdots
       \end{array}
       \right]
    """

    @required_kwargs('map_list', 'active_vars')
    def __init__(self, **kwargs):
        r"""
        Args:
          map_list (:class:`list` of :class:`Map`): list of transport maps :math:`T_i`
          active_vars (:class:`list` of :class:`list` of :class:`int`): active
            variables for each map. If ``None``, map :math:`T_i` acts on the
            first ``T_i.dim_in`` input coordinates.

        Raises:
          ValueError: if ``active_vars`` does not contain exactly one entry
            per map in ``map_list``.
        """
        map_list = kwargs['map_list']
        active_vars = kwargs['active_vars']
        if active_vars is None:
            # Default: every map reads the leading ``dim_in`` coordinates.
            dim_in = max(tm.dim_in for tm in map_list)
            self.active_vars = [list(range(tm.dim_in)) for tm in map_list]
        else:
            # ``zip`` in the evaluation routines would otherwise silently drop
            # maps, while ``dim_out`` below still counts them (zero rows).
            if len(active_vars) != len(map_list):
                raise ValueError(
                    "active_vars must contain one list of variables per map")
            dim_in = max(max(avars) for avars in active_vars) + 1
            self.active_vars = active_vars
        dim_out = sum(tm.dim_out for tm in map_list)
        self.map_list = map_list
        kwargs['dim_in'] = dim_in
        kwargs['dim_out'] = dim_out
        super(ListStackedMap, self).__init__(**kwargs)

    @property
    def map_list(self):
        r""" The list of stacked maps :math:`T_i`. """
        try:
            return self._map_list
        except AttributeError:
            # Backward compatibility v < 3.0, where the attribute was ``tm_list``.
            return self.tm_list

    @map_list.setter
    def map_list(self, map_list):
        self._map_list = map_list

    @property
    def active_vars(self):
        r""" List (one entry per map) of the input indices each map acts on. """
        return self._active_vars

    @active_vars.setter
    def active_vars(self, avars):
        self._active_vars = avars

    def get_ncalls_tree(self, indent=""):
        r""" Call-count tree of this map followed by those of its sub-maps. """
        out = Map.get_ncalls_tree(self, indent)
        for i, tm in enumerate(self.map_list):
            out += tm.get_ncalls_tree(indent + " T%d - " % i)
        return out

    def get_nevals_tree(self, indent=""):
        r""" Evaluation-count tree of this map followed by those of its sub-maps. """
        out = Map.get_nevals_tree(self, indent)
        for i, tm in enumerate(self.map_list):
            out += tm.get_nevals_tree(indent + " T%d - " % i)
        return out

    def get_teval_tree(self, indent=""):
        r""" Evaluation-time tree of this map followed by those of its sub-maps. """
        out = Map.get_teval_tree(self, indent)
        for i, tm in enumerate(self.map_list):
            out += tm.get_teval_tree(indent + " T%d - " % i)
        return out

    def update_ncalls_tree(self, obj):
        r""" Update call counters from ``obj``, recursing pairwise into sub-maps. """
        super(ListStackedMap, self).update_ncalls_tree(obj)
        for tm, obj_tm in zip(self.map_list, obj.map_list):
            tm.update_ncalls_tree(obj_tm)

    def update_nevals_tree(self, obj):
        r""" Update evaluation counters from ``obj``, recursing pairwise into sub-maps. """
        super(ListStackedMap, self).update_nevals_tree(obj)
        for tm, obj_tm in zip(self.map_list, obj.map_list):
            tm.update_nevals_tree(obj_tm)

    def update_teval_tree(self, obj):
        r""" Update timing counters from ``obj``, recursing pairwise into sub-maps. """
        super(ListStackedMap, self).update_teval_tree(obj)
        for tm, obj_tm in zip(self.map_list, obj.map_list):
            tm.update_teval_tree(obj_tm)

    def reset_counters(self):
        r""" Reset the counters of this map and of every sub-map. """
        super(ListStackedMap, self).reset_counters()
        for tm in self.map_list:
            tm.reset_counters()

    @property
    def n_maps(self):
        r""" Number of stacked maps. """
        return len(self.map_list)

    @cached([('map_list', "n_maps")], False)
    @counted
    def evaluate(self, x, precomp=None, idxs_slice=slice(None), cache=None):
        r""" Evaluate :math:`T` at ``x``.

        Args:
          x (:class:`ndarray<numpy.ndarray>` [:math:`m,d_x`]): evaluation points
          precomp: not used by this map
          idxs_slice (slice): forwarded to every sub-map
          cache (dict): cache; sub-map caches are stored under ``'map_list'``

        Returns:
          (:class:`ndarray<numpy.ndarray>` [:math:`m,d_y`]) -- stacked outputs

        Raises:
          ValueError: if ``x.shape[1] != dim_in``
        """
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        map_list_cache = get_sub_cache(cache, ('map_list', self.n_maps))
        out = np.zeros((x.shape[0], self.dim_out))
        start = 0
        for tm, avars, tm_cache in zip(self.map_list, self.active_vars, map_list_cache):
            stop = start + tm.dim_out
            out[:, start:stop] = tm.evaluate(
                x[:, avars], idxs_slice=idxs_slice, cache=tm_cache)
            start = stop
        return out

    @cached([('map_list', "n_maps")], False)
    @counted
    def grad_x(self, x, precomp=None, idxs_slice=slice(None), cache=None):
        r""" Evaluate the gradient :math:`\nabla_{\bf x} T` at ``x``.

        Columns corresponding to inputs a sub-map does not depend on
        (not in its active variables) are zero.

        Returns:
          (:class:`ndarray<numpy.ndarray>` [:math:`m,d_y,d_x`]) -- gradient

        Raises:
          ValueError: if ``x.shape[1] != dim_in``
        """
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        map_list_cache = get_sub_cache(cache, ('map_list', self.n_maps))
        out = np.zeros((x.shape[0], self.dim_out, self.dim_in))
        start = 0
        for tm, avars, tm_cache in zip(self.map_list, self.active_vars, map_list_cache):
            stop = start + tm.dim_out
            out[:, start:stop, avars] = tm.grad_x(
                x[:, avars], idxs_slice=idxs_slice, cache=tm_cache)
            start = stop
        return out

    @cached_tuple(['evaluate', 'grad_x'], [('map_list', "n_maps")], False)
    @counted
    def tuple_grad_x(self, x, *args, **kwargs):
        r""" Evaluate :math:`T` and :math:`\nabla_{\bf x} T` at ``x``.

        Returns:
          (:class:`tuple`) -- ``(self.evaluate(x, ...), self.grad_x(x, ...))``

        Raises:
          ValueError: if ``x.shape[1] != dim_in``
        """
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        ev = self.evaluate(x, *args, **kwargs)
        gx = self.grad_x(x, *args, **kwargs)
        return ev, gx

    @cached([('map_list', "n_maps")], False)
    @counted
    def hess_x(self, x, precomp=None, idxs_slice=slice(None), cache=None):
        r""" Evaluate the Hessian :math:`\nabla^2_{\bf x} T` at ``x``.

        Returns:
          (:class:`ndarray<numpy.ndarray>` [:math:`m,d_y,d_x,d_x`]) -- Hessian

        Raises:
          ValueError: if ``x.shape[1] != dim_in``
        """
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        map_list_cache = get_sub_cache(cache, ('map_list', self.n_maps))
        out = np.zeros((x.shape[0], self.dim_out, self.dim_in, self.dim_in))
        start = 0
        for tm, avars, tm_cache in zip(self.map_list, self.active_vars, map_list_cache):
            stop = start + tm.dim_out
            nvar = len(avars)
            # Open-mesh indices selecting the (avars x avars) sub-block of the
            # last two axes. ``np.meshgrid`` with its default ``indexing='xy'``
            # swaps the first two axes and scrambled the output/input ordering
            # of ``hx``; ``np.ix_`` keeps it.
            rows, cols = np.ix_(avars, avars)
            hx = tm.hess_x(x[:, avars], idxs_slice=idxs_slice, cache=tm_cache)
            if hx.ndim == 3:
                # No sample axis: broadcast the same Hessian to every point.
                hx = hx.reshape((stop - start, nvar, nvar))
            else:
                hx = hx.reshape((x.shape[0], stop - start, nvar, nvar))
            out[:, start:stop, rows, cols] = hx
            start = stop
        return out

    @cached([('map_list', "n_maps")], False)
    @counted
    def action_hess_x(self, x, dx, precomp=None, idxs_slice=slice(None), cache=None):
        r""" Evaluate the action of the Hessian of :math:`T` on ``dx``.

        Args:
          x (:class:`ndarray<numpy.ndarray>` [:math:`m,d_x`]): evaluation points
          dx (:class:`ndarray<numpy.ndarray>`): direction; its axis 1 is
            indexed by input coordinate (sliced with each map's active variables)

        Returns:
          (:class:`ndarray<numpy.ndarray>` [:math:`m,d_y,d_x`]) -- action of the Hessian

        Raises:
          ValueError: if ``x.shape[1] != dim_in``
        """
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        map_list_cache = get_sub_cache(cache, ('map_list', self.n_maps))
        out = np.zeros((x.shape[0], self.dim_out, self.dim_in))
        start = 0
        for tm, avars, tm_cache in zip(self.map_list, self.active_vars, map_list_cache):
            stop = start + tm.dim_out
            out[:, start:stop, avars] = tm.action_hess_x(
                x[:, avars], dx[:, avars], idxs_slice=idxs_slice, cache=tm_cache)
            start = stop
        return out