Source code for TransportMaps.Maps.ConditionalTriangularTransportMapBase

#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps.  If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Authors: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#

import numpy as np

from .TriangularComponentwiseTransportMapBase import TriangularComponentwiseTransportMap

from ..Misc import \
    required_kwargs,\
    counted, \
    cached

from .SlicedTransportMapBase import SlicedTransportMap

__all__ = [
    'ConditionalTriangularTransportMap'
]

nax = np.newaxis


[docs]class ConditionalTriangularTransportMap(SlicedTransportMap): r""" Takes the transport map :math:`T({\bf x})` and construct the transport map :math:`S_{{\bf y}}({\bf x}) := [T({\bf y}, {\bf x})]_{d_y:}`. """ @required_kwargs('base_map', 'y') def __init__(self, **kwargs): r""" Args: base_map (:class:`TriangularComponentwiseTransportMap`): map :math:`T` y (:class:`ndarray<numpy.ndarray>` [:math:`d_y`]): values of :math:`{\bf y}` """ base_map = kwargs['base_map'] y = kwargs['y'] if not isinstance(base_map, TriangularComponentwiseTransportMap): raise ValueError( "The base_map must be a TriangularComponentwiseTransportMap" ) dy = y.shape[0] kwargs['idxs_fix'] = list(range(dy)) kwargs['idxs_out'] = list(range(dy, base_map.dim)) super(ConditionalTriangularTransportMap, self).__init__(**kwargs) @cached() @counted
[docs] def log_det_grad_x(self, x, *args, **kwargs): xin = self._xin(x) xout = np.zeros(x.shape[0]) # FIXME: sliced_approx_list and sliced_active_vars seem not to be defined anywhere for i,(a, avar) in enumerate(zip( self.sliced_approx_list, self.sliced_active_vars)): xout += a.partial_xd( xin[:,avar] ) return xout
@cached() @counted
[docs] def grad_x_log_det_grad_x(self, x, *args, **kwargs): xin = self._xin(x) xout = np.zeros(x.shape) leny = len(self.idxs_fix) for i,(a, avar) in enumerate(zip( self.sliced_approx_list, self.sliced_active_vars)): avarout = [ v for v in avar if v >= leny ] xout[:,avarout] += a.grad_x_partial_xd( xin[:,avar] )[:,leny:] / \ a.partial_xd( xin[:,avar] )[:, nax] return xout
@cached() @counted
[docs] def hess_x_log_det_grad_x(self, x, *args, **kwargs): xin = self._xin(x) xout = np.zeros((x.shape[0],self.dim,self.dim)) leny = len(self.idxs_fix) for i,(a, avar) in enumerate(zip( self.sliced_approx_list, self.sliced_active_vars)): avarout = [ v for v in avar if v >= leny ] # 2d numpy advanced indexing nvar = len(avarout) rr,cc = np.meshgrid(avarout,avarout) rr = list( rr.flatten() ) cc = list( cc.flatten() ) idxs = (slice(None), rr, cc) # Evaluate dxk = a.partial_xd(x[:,avar], p) xout[idxs] += (a.hess_x_partial_xd( xin[:,avar] )[:,leny:,leny:] / \ dxk[:,nax,nax]).reshape((x.shape[0], nvar**2)) dxdxkT = a.grad_x_partial_xd(x[:,avar], p)[:,leny:] dxdxkT2 = dxdxkT[:,:,nax] * dxdxkT[:,nax,:] xout[idxs] -= (dxdxkT2 / (dxk**2.)[:,nax,nax]).reshape((x.shape[0],nvar**2)) return xout
@counted
[docs] def inverse(self, x, *args, **kwargs): xin = self._xin(x) xout = np.zeros(xin.shape) skip_dim = len(self.idxs_fix) xout[:,:skip_dim] = xin[:,:skip_dim] for i in range(x.shape[0]): for k, (a,avar) in enumerate(zip( self.sliced_approx_list, self.sliced_active_vars)): xout[i,skip_dim+k] = a.inverse(xout[i,avar[:-1]], xin[i,skip_dim+k]) return xout[:,skip_dim:]
@counted
[docs] def log_det_grad_x_inverse(self, x, precomp=None, idxs_slice=slice(None), **kwargs): try: xinv = precomp['xinv'] except (TypeError, KeyError): xinv = self.inverse(x, precomp) return - self.log_det_grad_x( xinv )