Source code for TransportMaps.Maps.SlicedMapBase

#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps.  If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Authors: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#

import numpy as np

from ..Misc import \
    required_kwargs, \
    counted, cached
from .MapBase import Map

# Names exported by ``from <this module> import *``.
__all__ = [
    'SlicedMap'
]

# Short alias for ``numpy.newaxis``.
nax = np.newaxis


class SlicedMap(Map):
    r""" Takes the map :math:`T({\bf x})` and construct the map :math:`S_{\bf y}({\bf x}) := [T({\bf y}_{\bf i} \cup {\bf x}_{\neg{\bf i}})]_{\bf j}`,
    where :math:`S_{\bf y}:\mathbb{R}^{\sharp(\neg{\bf i})}\rightarrow\mathbb{R}^{\sharp{\bf j}}`.

    The input coordinates listed in ``idxs_fix`` are frozen to the values
    ``y``; the remaining input coordinates (``idxs_var``) are the inputs of
    the sliced map, and only the output components ``idxs_out`` are returned.
    """
    @required_kwargs('base_map', 'y', 'idxs_fix', 'idxs_out')
    def __init__(self, **kwargs):
        r"""
        Args:
          base_map (:class:`Map`): map :math:`T`
          y (:class:`ndarray<numpy.ndarray>` [:math:`d_y`]): values of
            :math:`{\bf y}_{\bf i}`
          idxs_fix (:class:`list`): list of indices :math:`{\bf i}`
          idxs_out (:class:`list`): list of indices :math:`{\bf j}`

        Raises:
          ValueError: if ``y`` and ``idxs_fix`` have different lengths, if
            ``idxs_fix`` contains repeated entries, or if any entry of
            ``idxs_fix`` is not a valid input index of ``base_map``.
        """
        base_map = kwargs['base_map']
        y = kwargs['y']
        idxs_fix = kwargs['idxs_fix']
        idxs_out = kwargs['idxs_out']
        if len(y) != len(idxs_fix):
            raise ValueError("The length of y and idxs_fix must be the same")
        if len(set(idxs_fix)) != len(idxs_fix):
            raise ValueError("idxs_fix must contain unique values")
        # Every fixed index must address an actual input coordinate of the
        # base map. Together with uniqueness this also bounds the length.
        # Negative indices are rejected because they would not be removed
        # from idxs_var below, leaving dim_in inconsistent.
        if any(not 0 <= i < base_map.dim_in for i in idxs_fix):
            raise ValueError(
                "idxs_fix must be a subset of the input dimensions of base_map")
        self.base_map = base_map
        self.y = y
        self.idxs_fix = idxs_fix
        # Free (non-fixed) input coordinates, in increasing order.
        fixed = set(idxs_fix)
        self.idxs_var = [i for i in range(base_map.dim_in) if i not in fixed]
        self.idxs_out = idxs_out
        kwargs['dim_in'] = len(self.idxs_var)
        kwargs['dim_out'] = len(self.idxs_out)
        super(SlicedMap, self).__init__(**kwargs)

    def _xin(self, x):
        r""" Embed the free coordinates ``x`` into a full input of ``base_map``.

        Args:
          x (:class:`ndarray<numpy.ndarray>`): 2-d array whose rows are points
            and whose columns are the free coordinates ``idxs_var``

        Returns:
          (:class:`ndarray<numpy.ndarray>`) -- array with ``x.shape[0]`` rows
            and ``base_map.dim_in`` columns, where columns ``idxs_fix`` hold
            ``y`` and columns ``idxs_var`` hold ``x``
        """
        # Use dim_in (not dim): the base map need not be square.
        xin = np.zeros((x.shape[0], self.base_map.dim_in))
        xin[:, self.idxs_fix] = np.asarray(self.y)[nax, :]
        xin[:, self.idxs_var] = x
        return xin

    @cached()
    @counted
    def evaluate(self, x, **kwargs):
        r""" Evaluate the sliced map at the points ``x``.

        Args:
          x (:class:`ndarray<numpy.ndarray>`): free coordinates, one point
            per row
          kwargs: forwarded unchanged to ``base_map.evaluate``

        Returns:
          The output of ``base_map.evaluate`` on the embedded points,
          restricted along its second axis to ``idxs_out``.
        """
        return self.base_map.evaluate(
            self._xin(x), **kwargs)[:, self.idxs_out]

    @cached()
    @counted
    def grad_x(self, x, **kwargs):
        r""" Gradient of the sliced map with respect to the free coordinates.

        Args:
          x (:class:`ndarray<numpy.ndarray>`): free coordinates, one point
            per row
          kwargs: forwarded unchanged to ``base_map.grad_x``

        Returns:
          The output of ``base_map.grad_x`` restricted to ``idxs_out`` along
          axis 1 and to ``idxs_var`` along axis 2.
        """
        gx = self.base_map.grad_x(self._xin(x), **kwargs)
        # Slice one axis at a time: passing both index lists in a single
        # subscript would pair them element-wise instead of selecting the
        # rectangular sub-block.
        return gx[:, self.idxs_out, :][:, :, self.idxs_var]

    @cached(caching=False)
    @counted
    def hess_x(self, x, **kwargs):
        r""" Hessian of the sliced map with respect to the free coordinates.

        Args:
          x (:class:`ndarray<numpy.ndarray>`): free coordinates, one point
            per row
          kwargs: forwarded unchanged to ``base_map.hess_x``

        Returns:
          The output of ``base_map.hess_x`` restricted to ``idxs_out`` along
          axis 1 and to ``idxs_var`` along axes 2 and 3.
        """
        hx = self.base_map.hess_x(self._xin(x), **kwargs)
        # Axis-by-axis selection, see grad_x.
        hx = hx[:, self.idxs_out, :, :]
        hx = hx[:, :, self.idxs_var, :]
        return hx[:, :, :, self.idxs_var]

    @cached(caching=False)
    @counted
    def action_hess_x(self, x, dx, **kwargs):
        r""" Action of the Hessian of the sliced map on the direction ``dx``.

        Args:
          x (:class:`ndarray<numpy.ndarray>`): free coordinates, one point
            per row
          dx (:class:`ndarray<numpy.ndarray>`): direction expressed in the
            free coordinates, one row per point
          kwargs: forwarded unchanged to ``base_map.action_hess_x``

        Returns:
          The output of ``base_map.action_hess_x`` restricted to ``idxs_out``
          along axis 1 and to ``idxs_var`` along axis 2.
        """
        # The base map expects a direction in its full input space. The
        # fixed coordinates do not vary, so their components are zero.
        dxin = np.zeros((dx.shape[0], self.base_map.dim_in))
        dxin[:, self.idxs_var] = dx
        ahx = self.base_map.action_hess_x(self._xin(x), dxin, **kwargs)
        # Axis-by-axis selection, see grad_x.
        return ahx[:, self.idxs_out, :][:, :, self.idxs_var]