Source code for TransportMaps.Maps.IdentityEmbeddedTransportMapBase

#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps.  If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Authors: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#

import numpy as np

from ..Misc import \
    required_kwargs, \
    counted, cached, get_sub_cache

from .TransportMapBase import TransportMap

__all__ = [
    'IdentityEmbeddedTransportMap',
]


class IdentityEmbeddedTransportMap(TransportMap):
    r"""Map embedding a lower dimensional :class:`TransportMap` into a
    higher dimensional space, acting as the identity on all the
    coordinates not listed in ``idxs``.
    """
    @required_kwargs('tm', 'idxs', 'dim')
    def __init__(self, **kwargs):
        tm = kwargs.pop('tm')
        idxs = kwargs.pop('idxs')
        if not isinstance(tm, TransportMap):
            raise AttributeError("tm must be a TransportMap.")
        if len(idxs) != tm.dim:
            raise ValueError(
                "The dimension of tm must match the number of idxs.")
        if kwargs['dim'] <= max(idxs):
            raise ValueError(
                "The dimension of the new map must be greater than the "
                "maximum index in idxs.")
        self.tm = tm
        self.idxs = idxs
        super(IdentityEmbeddedTransportMap, self).__init__(**kwargs)

    @cached([('tm', None)])
    @counted
    def evaluate(self, x, precomp=None, idxs_slice=slice(None), cache=None):
        # Apply tm to the selected coordinates; leave the rest unchanged.
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        tm_cache = get_sub_cache(cache, ('tm', None))
        out = x.copy()
        out[:, self.idxs] = self.tm.evaluate(
            x[:, self.idxs], precomp, idxs_slice, cache=tm_cache)
        return out
    @counted
    def inverse(self, x, precomp=None, idxs_slice=slice(None)):
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        out = x.copy()
        out[:, self.idxs] = self.tm.inverse(
            x[:, self.idxs], precomp, idxs_slice)
        return out
    @cached([('tm', None)], False)
    @counted
    def grad_x(self, x, precomp=None, idxs_slice=slice(None), cache=None):
        # The Jacobian is the identity matrix with the Jacobian of tm
        # placed in the block corresponding to idxs.
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        tm_cache = get_sub_cache(cache, ('tm', None))
        m = x.shape[0]
        out = np.zeros((m, self.dim, self.dim))
        out[:, range(self.dim), range(self.dim)] = 1.
        out[np.ix_(range(m), self.idxs, self.idxs)] = self.tm.grad_x(
            x[:, self.idxs], precomp, idxs_slice, cache=tm_cache)
        return out
    @cached([('tm', None)], False)
    @counted
    def action_grad_x(self, x, dx, precomp=None, idxs_slice=slice(None),
                      cache=None):
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        tm_cache = get_sub_cache(cache, ('tm', None))
        gx_tm = self.tm.grad_x(
            x[:, self.idxs], precomp, idxs_slice, cache=tm_cache)
        out = dx.copy()
        idxs = tuple([slice(None)] * (dx.ndim - 1) + [self.idxs])
        out[idxs] = np.einsum('...jk,...k->...j', gx_tm, dx[idxs])
        return out
    @cached([('tm', None)], False)
    @counted
    def action_adjoint_grad_x(self, x, dx, precomp=None,
                              idxs_slice=slice(None), cache=None):
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        tm_cache = get_sub_cache(cache, ('tm', None))
        gx_tm = self.tm.grad_x(
            x[:, self.idxs], precomp, idxs_slice, cache=tm_cache)
        out = dx.copy()
        idxs = tuple([slice(None)] * (dx.ndim - 1) + [self.idxs])
        if dx.ndim == 2:
            expr = '...j,...jk->...k'
        else:
            expr = '...ij,...jk->...ik'
        out[idxs] = np.einsum(expr, dx[idxs], gx_tm)
        return out
    @cached([('tm', None)], False)
    @counted
    def tuple_grad_x(self, x, precomp=None, idxs_slice=slice(None),
                     cache=None):
        ev = self.evaluate(
            x, precomp=precomp, idxs_slice=idxs_slice, cache=cache)
        gx = self.grad_x(
            x, precomp=precomp, idxs_slice=idxs_slice, cache=cache)
        return ev, gx
    @cached([('tm', None)], False)
    @counted
    def action_tuple_grad_x(self, x, dx, precomp=None,
                            idxs_slice=slice(None), cache=None):
        ev = self.evaluate(
            x, precomp=precomp, idxs_slice=idxs_slice, cache=cache)
        agx = self.action_grad_x(
            x, dx, precomp=precomp, idxs_slice=idxs_slice, cache=cache)
        return ev, agx
    @cached([('tm', None)], False)
    @counted
    def hess_x(self, x, precomp=None, idxs_slice=slice(None), cache=None):
        # The identity block has zero Hessian, so only the block
        # corresponding to tm is filled in.
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        tm_cache = get_sub_cache(cache, ('tm', None))
        m = x.shape[0]
        out = np.zeros((m, self.dim, self.dim, self.dim))
        out[np.ix_(range(m), self.idxs, self.idxs, self.idxs)] = \
            self.tm.hess_x(
                x[:, self.idxs], precomp, idxs_slice, cache=tm_cache)
        return out
    @cached([('tm', None)])
    @counted
    def log_det_grad_x(self, x, precomp=None, idxs_slice=slice(None),
                       cache=None):
        # The identity block contributes nothing to the log-determinant,
        # which therefore coincides with the log-determinant of tm.
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        tm_cache = get_sub_cache(cache, ('tm', None))
        return self.tm.log_det_grad_x(
            x[:, self.idxs], precomp=precomp, idxs_slice=idxs_slice,
            cache=tm_cache)
    @cached([('tm', None)])
    @counted
    def grad_x_log_det_grad_x(self, x, precomp=None,
                              idxs_slice=slice(None), cache=None):
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        tm_cache = get_sub_cache(cache, ('tm', None))
        out = np.zeros((x.shape[0], self.dim))
        out[:, self.idxs] = self.tm.grad_x_log_det_grad_x(
            x[:, self.idxs], precomp=precomp, idxs_slice=idxs_slice,
            cache=tm_cache)
        return out
    @cached([('tm', None)], False)
    @counted
    def hess_x_log_det_grad_x(self, x, precomp=None,
                              idxs_slice=slice(None), cache=None):
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        tm_cache = get_sub_cache(cache, ('tm', None))
        m = x.shape[0]
        out = np.zeros((m, self.dim, self.dim))
        out[np.ix_(range(m), self.idxs, self.idxs)] = \
            self.tm.hess_x_log_det_grad_x(
                x[:, self.idxs], precomp=precomp, idxs_slice=idxs_slice,
                cache=tm_cache)
        return out
    @counted
    def log_det_grad_x_inverse(self, x, *args, **kwargs):
        if x.shape[1] != self.dim_in:
            raise ValueError("dimension mismatch")
        return self.tm.log_det_grad_x_inverse(
            x[:, self.idxs], *args, **kwargs)
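

# ---------------------------------------------------------------------------
# A minimal, self-contained sketch (plain NumPy only; not part of the library
# API) illustrating the embedding pattern implemented above. The inner map
# here is a hypothetical stand-in for ``tm``: it acts on the coordinates
# ``idxs`` while every other coordinate passes through unchanged, np.ix_
# writes the inner Jacobian into the matching block of the identity, and the
# identity block contributes nothing to the log-determinant.
if __name__ == '__main__':
    dim = 5
    idxs = [1, 3]                      # embed a 2-d map into 5 dimensions
    x = np.random.randn(10, dim)       # batch of 10 points

    def inner_eval(z):
        # hypothetical inner map: componentwise exponential
        return np.exp(z)

    def inner_grad_x(z):
        # its Jacobian: diagonal matrices of shape (m, len(idxs), len(idxs))
        return np.exp(z)[:, :, None] * np.eye(z.shape[1])

    # evaluate(): copy the input and overwrite only the embedded coordinates
    out = x.copy()
    out[:, idxs] = inner_eval(x[:, idxs])

    # grad_x(): identity Jacobian with the inner block filled in via np.ix_
    m = x.shape[0]
    jac = np.zeros((m, dim, dim))
    jac[:, range(dim), range(dim)] = 1.
    jac[np.ix_(range(m), idxs, idxs)] = inner_grad_x(x[:, idxs])

    # log_det_grad_x(): the identity block has unit determinant, so the
    # log-determinant of the embedded map equals that of the inner map
    ld_embedded = np.log(np.abs(np.linalg.det(jac)))
    ld_inner = np.log(np.abs(np.linalg.det(inner_grad_x(x[:, idxs]))))
    assert np.allclose(ld_embedded, ld_inner)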