#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps. If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Authors: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#
import numpy as np
from ..Maps import TransportMap
from .DistributionBase import Distribution
# Public names exported by a star-import of this module.
__all__ = [
'TransportMapDistribution',
]
class TransportMapDistribution(Distribution):
    r""" Abstract class for distributions of the transport map type (:math:`T^\sharp \pi` or :math:`T_\sharp \pi`)

    The distribution is described by a transport map :math:`T` and a base
    distribution :math:`\pi` of the same dimension. Quadratures and samples
    are generated on the base distribution and then moved through
    :func:`map_samples_base_to_target`, which sub-classes must implement.

    .. seealso:: :class:`PushForwardTransportMapDistribution` and :class:`PullBackTransportMapDistribution`.
    """

    def __init__(
            self,
            transport_map: TransportMap,
            base_distribution: Distribution
    ):
        r"""
        Args:
          transport_map (:class:`TransportMap<TransportMaps.Maps.TransportMap>`): transport map :math:`T`
          base_distribution (:class:`Distribution`): distribution :math:`\pi`

        Raises:
          ValueError: if ``transport_map.dim`` and ``base_distribution.dim``
            differ.
        """
        if transport_map.dim != base_distribution.dim:
            raise ValueError(
                "The transport_map and the base_distribution should have "
                "the same dimension"
            )
        super().__init__(dim=transport_map.dim)
        self.transport_map = transport_map
        self.base_distribution = base_distribution

    def get_ncalls_tree(self, indent=""):
        r""" Concatenate the ``ncalls`` tree of this object with the ones of its components.

        Args:
          indent (str): prefix used for this object; the transport map and the
            base distribution are queried with ``indent`` plus one space.

        Returns:
          The parent class' tree followed by the tree of the transport map and
          the tree of the base distribution (presumably a :class:`str`).
        """
        out = super().get_ncalls_tree(indent)
        out += self.transport_map.get_ncalls_tree(indent + " ")
        out += self.base_distribution.get_ncalls_tree(indent + ' ')
        return out

    def get_nevals_tree(self, indent=""):
        r""" Concatenate the ``nevals`` tree of this object with the ones of its components.

        Args:
          indent (str): prefix used for this object; the transport map and the
            base distribution are queried with ``indent`` plus one space.

        Returns:
          The parent class' tree followed by the tree of the transport map and
          the tree of the base distribution (presumably a :class:`str`).
        """
        out = super().get_nevals_tree(indent)
        out += self.transport_map.get_nevals_tree(indent + " ")
        out += self.base_distribution.get_nevals_tree(indent + ' ')
        return out

    def get_teval_tree(self, indent=""):
        r""" Concatenate the ``teval`` tree of this object with the ones of its components.

        Args:
          indent (str): prefix used for this object; the transport map and the
            base distribution are queried with ``indent`` plus one space.

        Returns:
          The parent class' tree followed by the tree of the transport map and
          the tree of the base distribution (presumably a :class:`str`).
        """
        out = super().get_teval_tree(indent)
        out += self.transport_map.get_teval_tree(indent + " ")
        out += self.base_distribution.get_teval_tree(indent + ' ')
        return out

    def update_ncalls_tree(self, obj):
        r""" Update the ``ncalls`` counters of this object and of its components from ``obj``.

        Args:
          obj: object exposing ``transport_map`` and ``base_distribution``
            attributes, which are used to update the matching components
            of ``self``.
        """
        super().update_ncalls_tree(obj)
        self.transport_map.update_ncalls_tree(obj.transport_map)
        self.base_distribution.update_ncalls_tree(obj.base_distribution)

    def update_nevals_tree(self, obj):
        r""" Update the ``nevals`` counters of this object and of its components from ``obj``.

        Args:
          obj: object exposing ``transport_map`` and ``base_distribution``
            attributes, which are used to update the matching components
            of ``self``.
        """
        super().update_nevals_tree(obj)
        self.transport_map.update_nevals_tree(obj.transport_map)
        self.base_distribution.update_nevals_tree(obj.base_distribution)

    def update_teval_tree(self, obj):
        r""" Update the ``teval`` counters of this object and of its components from ``obj``.

        Args:
          obj: object exposing ``transport_map`` and ``base_distribution``
            attributes, which are used to update the matching components
            of ``self``.
        """
        super().update_teval_tree(obj)
        self.transport_map.update_teval_tree(obj.transport_map)
        self.base_distribution.update_teval_tree(obj.base_distribution)

    def reset_counters(self):
        r""" Reset the counters of this object, of the transport map and of the base distribution.
        """
        super().reset_counters()
        self.transport_map.reset_counters()
        self.base_distribution.reset_counters()

    def rvs(self, m, mpi_pool=None, batch_size=None):
        r""" Generate :math:`m` samples from the distribution.

        The samples are the points of the quadrature of type ``0`` with
        parameter ``m`` (see :func:`quadrature`); the weights are discarded.

        Args:
          m (int): number of samples to generate
          mpi_pool (:class:`mpi_map.MPI_Pool<mpi_map.MPI_Pool>`): pool of processes
          batch_size (int): whether to generate samples in batches
            (forwarded as a keyword argument to :func:`quadrature`)

        Returns:
          (:class:`ndarray<numpy.ndarray>` [:math:`m,d`]) -- :math:`m`
             :math:`d`-dimensional samples
        """
        x, _ = self.quadrature(0, m, mpi_pool=mpi_pool, batch_size=batch_size)
        return x

    def quadrature(self, qtype, qparams, mass=1.,
                   mpi_pool=None, **kwargs):
        r""" Generate quadrature points and weights.

        Type ``4`` is delegated to :func:`adaptive_quadrature`. For every
        other type the quadrature is generated by the base distribution and
        its points are mapped with :func:`map_samples_base_to_target`; the
        weights are returned unchanged.

        Args:
          qtype (int): quadrature type number. The different types are defined in
            the associated sub-classes.
          qparams (object): inputs necessary to the generation of the selected
            quadrature
          mass (float): total mass of the quadrature (1 for probability measures)
          mpi_pool (:class:`mpi_map.MPI_Pool<mpi_map.MPI_Pool>`): pool of processes
          **kwargs: additional keyword arguments forwarded to the underlying
            quadrature routine

        Return:
          (:class:`tuple` (:class:`ndarray<numpy.ndarray>` [:math:`m,d`],
            :class:`ndarray<numpy.ndarray>` [:math:`m`])) -- list of quadrature
            points and weights
        """
        if qtype in [4]:
            return self.adaptive_quadrature(
                qtype, qparams, mass=mass, mpi_pool=mpi_pool, **kwargs)
        (x, w) = self.base_distribution.quadrature(
            qtype, qparams, mass=mass, mpi_pool=mpi_pool, **kwargs)
        x = self.map_samples_base_to_target(x, mpi_pool=mpi_pool)
        return (x, w)

    def adaptive_quadrature(self, qtype, qparams, mass=1., mpi_pool=None, **kwargs):
        r""" Generate an adaptive quadrature (type ``4``) for the integrand ``f``.

        The integrand is first transformed with
        ``self.map_function_base_to_target`` (not defined in this class; it is
        expected to be provided by the sub-classes), then the quadrature is
        generated by the base distribution and its points are mapped with
        :func:`map_samples_base_to_target`.

        Args:
          qtype (int): quadrature type number. Only ``4`` is accepted.
          qparams (object): inputs necessary to the generation of the selected
            quadrature
          mass (float): total mass of the quadrature (1 for probability measures)
          mpi_pool (:class:`mpi_map.MPI_Pool<mpi_map.MPI_Pool>`): pool of processes
          **kwargs: must contain the integrand under the key ``f``; the
            remaining entries are forwarded to the base distribution

        Return:
          (:class:`tuple`) -- quadrature points mapped to the target and the
            weights returned by the base distribution

        Raises:
          ValueError: if ``qtype`` is not ``4`` or if ``f`` is not provided.
        """
        if qtype not in [4]:
            raise ValueError("Quadrature type not recognized")
        if 'f' not in kwargs:
            raise ValueError(
                "This kind of adaptive quadrature requires the argument "
                "integrand function to be provided as the argument f.")
        # Work on a copy so that the caller-provided integrand is left as is.
        base_kwargs = dict(kwargs)
        base_kwargs['f'] = self.map_function_base_to_target(kwargs['f'])
        x, w = self.base_distribution.quadrature(
            qtype, qparams, mass=mass, mpi_pool=mpi_pool, **base_kwargs)
        x = self.map_samples_base_to_target(x, mpi_pool=mpi_pool)
        return (x, w)

    def map_samples_base_to_target(self, x, mpi_pool=None):
        r""" [Abstract] Map input samples (assumed to be from :math:`\pi`) to the corresponding samples from :math:`T^\sharp \pi` or :math:`T_\sharp \pi`.

        Args:
          x (:class:`ndarray<numpy.ndarray>` [:math:`m,d`]): input samples
          mpi_pool (:class:`mpi_map.MPI_Pool<mpi_map.MPI_Pool>`): pool of processes

        Returns:
          (:class:`ndarray<numpy.ndarray>` [:math:`m,d`]) -- corresponding samples

        Raises:
          NotImplementedError: always; to be implemented in sub-classes.
        """
        raise NotImplementedError("Abstract method. Implement in sub-class.")

    def map_samples_target_to_base(self, x, mpi_pool=None):
        r""" [Abstract] Map input samples (assumed to be from :math:`T^\sharp \pi` or :math:`T_\sharp \pi`) to the corresponding samples from :math:`\pi`.

        Args:
          x (:class:`ndarray<numpy.ndarray>` [:math:`m,d`]): input samples
          mpi_pool (:class:`mpi_map.MPI_Pool<mpi_map.MPI_Pool>`): pool of processes

        Returns:
          (:class:`ndarray<numpy.ndarray>` [:math:`m,d`]) -- corresponding samples

        Raises:
          NotImplementedError: always; to be implemented in sub-classes.
        """
        raise NotImplementedError("Abstract method. Implement in sub-class.")

    @staticmethod
    def _evaluate_log_transport(lpdf, ldgx):
        r""" Return the sum ``lpdf + ldgx`` of the two log-terms.

        Args:
          lpdf: first term (presumably the log-density of the base
            distribution -- inferred from the name)
          ldgx: second term (presumably the log-determinant of the map
            gradient -- inferred from the name)

        Returns:
          ``lpdf + ldgx``
        """
        return lpdf + ldgx

    @staticmethod
    def _evaluate_grad_x_log_transport(gxlpdf, gx, gxldgx):
        r""" Return ``gxlpdf`` contracted with ``gx``, plus ``gxldgx``.

        The last axis of ``gxlpdf`` is contracted with the second-to-last axis
        of ``gx`` (``einsum`` signature ``'...i,...ij->...j'``); leading axes
        are broadcast.

        Args:
          gxlpdf (:class:`ndarray<numpy.ndarray>` [:math:`...,i`]): left factor
          gx (:class:`ndarray<numpy.ndarray>` [:math:`...,i,j`]): right factor
          gxldgx (:class:`ndarray<numpy.ndarray>` [:math:`...,j`]): additive term

        Returns:
          (:class:`ndarray<numpy.ndarray>` [:math:`...,j`]) -- the contraction
            plus ``gxldgx``
        """
        return np.einsum('...i,...ij->...j', gxlpdf, gx) + gxldgx