Source code for TransportMaps.Distributions.FactorizedDistributionBase

#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps.  If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Authors: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#

import numpy as np

from ..Misc import counted, cached, cached_tuple

from .DistributionBase import Distribution
from .ConditionalDistributionBase import ConditionalDistribution

__all__ = [
    'FactorizedDistribution'
]

[docs]class FactorizedDistribution(Distribution): r""" Distribution :math:`\nu_\pi` defiened by its conditional factors. The density of the distribution :math:`\nu_\pi` is defined by .. math:: \pi({\bf x}) = \prod_{({\bf i},{\bf k}) \in \mathcal{I}} \pi({\bf x}_{\bf i},{\bf x}_{\bf k})` Args: factors (:class:`list` of :class:`tuple`): each tuple contains a factor (:class:`ConditionalDistribution` and/or :class:`Distribution`), and two lists containing the list of marginal variables and conditioning variables Example ------- Let :math:`\pi(x_0,x_1,x_2) = \pi_1(x_2|x_1,x_0) \pi_2(x_0|x_1) \pi_3(x_1)`. >>> factors = [(p1, [2], [1,0] ), >>> (p2, [0], [1] ), >>> (p3, [1], [] )] >>> pi = FactorizedDistribution(factors) """ def __init__(self, factors): input_dim = [] self.factors = [] for i, (pi, xidxs, yidxs) in enumerate(factors): if issubclass(type(pi), ConditionalDistribution): is_cond = True elif issubclass(type(pi), Distribution): is_cond = False else: raise TypeError("The %d-th factor is nor " % i + \ "Distribution or ConditionalDistribution") # pi has right number of inputs if len(set(xidxs)) != pi.dim: raise TypeError("The dimension of the %d-th " % i + \ "distribution does not match the input variables.") if is_cond and len(set(yidxs)) != pi.dim_y: raise TypeError("The conditioning dimension of the %d-th " % i + \ "distribution does not match the input variables.") if not is_cond and len(yidxs) != 0: raise TypeError("The conditioning dimension of the %d-th " % i + \ "distribution (Distribution) must be zero.") self.factors.append( (is_cond, pi, xidxs, yidxs) ) input_dim.extend( xidxs ) # Check total dimension and that all the marginals are available if len(input_dim) != len(set(input_dim)) or \ (len(input_dim) != 0 and len(input_dim) != max(input_dim)+1): raise TypeError("Some marginals are not defined or multiply defined") super(FactorizedDistribution, self).__init__(len(input_dim)) self.update_var_to_factor_map() @property
[docs] def var_to_factor_map(self): return self._var_to_factor_map
[docs] def get_ncalls_tree(self, indent=""): out = super(FactorizedDistribution, self).get_ncalls_tree(indent) for _,pi,_,_ in self.factors: out += pi.get_ncalls_tree(indent + ' ') return out
[docs] def get_nevals_tree(self, indent=""): out = super(FactorizedDistribution, self).get_nevals_tree(indent) for _,pi,_,_ in self.factors: out += pi.get_nevals_tree(indent + ' ') return out
[docs] def get_teval_tree(self, indent=""): out = super(FactorizedDistribution, self).get_teval_tree(indent) for _,pi,_,_ in self.factors: out += pi.get_teval_tree(indent + ' ') return out
[docs] def update_ncalls_tree(self, obj): super(FactorizedDistribution, self).update_ncalls_tree(obj) for (_,pi,_,_),(_,obj_pi,_,_) in zip(self.factors,obj.factors): pi.update_ncalls_tree(obj_pi)
[docs] def update_nevals_tree(self, obj): super(FactorizedDistribution, self).update_nevals_tree(obj) for (_,pi,_,_),(_,obj_pi,_,_) in zip(self.factors,obj.factors): pi.update_nevals_tree(obj_pi)
[docs] def update_teval_tree(self, obj): super(FactorizedDistribution, self).update_teval_tree(obj) for (_,pi,_,_),(_,obj_pi,_,_) in zip(self.factors,obj.factors): pi.update_teval_tree(obj_pi)
[docs] def reset_counters(self): super(FactorizedDistribution, self).reset_counters() for (_,pi,_,_) in self.factors: pi.reset_counters()
[docs] def append(self, factor): r""" Add a new factor to the distribution Args: factor (:class:`tuple`): tuple containing a factor (:class:`ConditionalDistribution` and/or :class:`Distribution`), and two tuples with the list of marginal variables and conditioning variables Example ------- Let :math:`\pi(x_0,x_1,x_2) = \pi_1(x_2|x_1,x_0) \pi_2(x_0|x_1) \pi_3(x_1)` and let's add the factor :math:`\pi_4(x_3|x_0,x_1,x_2)`, obtaining: .. math:: \pi(x_0,x_1,x_2,x_3) = \pi_4(x_3|x_0,x_1,x_2)\pi_1(x_2|x_1,x_0) \pi_2(x_0|x_1) \pi_3(x_1) >>> factor = (pi4, [3], [0,1,2]) >>> pi.append(factor) """ pi, xidxs, yidxs = factor if issubclass(type(pi), ConditionalDistribution): is_cond = True elif issubclass(type(pi), Distribution): is_cond = False else: raise TypeError("The factor is nor " + \ "Distribution or ConditionalDistribution") # pi has right number of inputs if len(xidxs) != pi.dim: raise TypeError("The dimension of the " + \ "distribution does not match the input variables.") if is_cond and len(yidxs) != pi.dim_y: raise TypeError("The conditioning dimension of the " + \ "distribution does not match the input variables.") if not is_cond and len(yidxs) != 0: raise TypeError("The conditioning dimension of the " + \ "distribution (Distribution) must be zero.") # Check xidxs contains new coordinates and that all the marginals are available if min(xidxs) < self.dim or len(xidxs) != len(set(xidxs)) or \ self.dim + len(xidxs) != max(xidxs)+1: raise TypeError("Some marginals are not defined or multiply defined") self.factors.append( (is_cond, pi, xidxs, yidxs) ) self.dim += len(xidxs) self.update_var_to_factor_map()
[docs] def update_var_to_factor_map(self): self._var_to_factor_map = np.zeros(self.dim, dtype=int) for i, factor in enumerate(self.factors): for var in factor[2]: self._var_to_factor_map[var] = i
@property
[docs] def n_factors(self): return len(self.factors)
[docs] def rvs(self, m, *args, **kwargs): out = np.zeros((m, self.dim)) touched = np.zeros(self.dim, dtype=bool) factor_list = self.factors[:] while len(factor_list) > 0: factor = factor_list.pop() filo_queue = factor[3][:] while len(filo_queue) > 0: # Search all the sub-tree idx = filo_queue.pop() if not touched[idx]: child_factor = self.factors[self.var_to_factor_map[idx]] child_factor_is_cond = child_factor[0] child_factor_pi = child_factor[1] child_factor_vars = child_factor[2] if child_factor_is_cond: # the factor is a conditional distribution child_factor_cond_var = child_factor[3] if all(touched[child_factor_cond_var]): # all the conditioning variables have been touched out[:,child_factor_vars] = child_factor_pi.rvs( m, out[:,child_factor_cond_var], *args, **kwargs) touched[child_factor_vars] = True factor_list.remove(child_factor) else: filo_queue.append(idx) # Re-append the factor filo_queue.extend( [cv for cv in child_factor_cond_var if not touched[cv]]) else: # the factor is a distribution out[:,child_factor_vars] = child_factor_pi.rvs(m, *args, **kwargs) touched[child_factor_vars] = True factor_list.remove(child_factor) if factor[0]: out[:,factor[2]] = factor[1].rvs(m, out[:,factor[3]], *args, **kwargs) else: out[:,factor[2]] = factor[1].rvs(m, *args, **kwargs) touched[factor[2]] = True return out
@cached([('factor_list','n_factors')]) @counted
[docs] def log_pdf(self, x, params=None, idxs_slice=slice(None,None,None), cache=None): r""" Evaluate :math:`\log \pi({\bf x})` Args: x (:class:`ndarray<numpy.ndarray>` [:math:`m,d`]): evaluation points params (dict): parameters idxs_slice (slice): if precomputed values are present, this parameter indicates at which of the points to evaluate. The number of indices represented by ``idxs_slice`` must match ``x.shape[0]``. cache (dict): cache Returns: (:class:`ndarray<numpy.ndarray>` [:math:`m`]) -- values of :math:`\log\pi` at the ``x`` points. """ if x.shape[1] != self.dim: raise ValueError("The dimension of the input does not match the dimension " + \ "of the distribution") try: factor_list_cache = cache['factor_list_cache'] except TypeError: factor_list_cache = [None] * self.n_factors out = np.zeros(x.shape[0]) for (is_cond, pi, xidxs, yidxs), fcache in zip(self.factors, factor_list_cache): if is_cond: out += pi.log_pdf(x[:,xidxs], x[:,yidxs], params=params, idxs_slice=idxs_slice, cache=fcache) else: out += pi.log_pdf(x[:,xidxs], params=params, idxs_slice=idxs_slice, cache=fcache) return out
@cached([('factor_list','n_factors')]) @counted
[docs] def grad_x_log_pdf(self, x, params=None, idxs_slice=slice(None,None,None), cache=None): r""" Evaluate :math:`\nabla_{\bf x} \log \pi({\bf x})` Args: x (:class:`ndarray<numpy.ndarray>` [:math:`m,d`]): evaluation points params (dict): parameters idxs_slice (slice): if precomputed values are present, this parameter indicates at which of the points to evaluate. The number of indices represented by ``idxs_slice`` must match ``x.shape[0]``. cache (dict): cache Returns: (:class:`ndarray<numpy.ndarray>` [:math:`m,d`]) -- values of :math:`\nabla_x\log\pi` at the ``x`` points. """ if x.shape[1] != self.dim: raise ValueError("The dimension of the input does not match the dimension " + \ "of the distribution") try: factor_list_cache = cache['factor_list_cache'] except TypeError: factor_list_cache = [None] * self.n_factors out = np.zeros(x.shape) for (is_cond, pi, xidxs, yidxs), fcache in zip(self.factors, factor_list_cache): if is_cond: gx = pi.grad_x_log_pdf(x[:,xidxs], x[:,yidxs], params=params, idxs_slice=idxs_slice, cache=fcache) out[:,yidxs] += gx[:,pi.dim:] else: gx = pi.grad_x_log_pdf(x[:,xidxs], params=params, idxs_slice=idxs_slice, cache=fcache) out[:,xidxs] += gx[:,:pi.dim] return out
@cached_tuple(['log_pdf','grad_x_log_pdf'],[('factor_list','n_factors')]) @counted
[docs] def tuple_grad_x_log_pdf( self, x, params=None, idxs_slice=slice(None,None,None), cache=None): r""" Evaluate :math:`\log \pi({\bf x}), \nabla_{\bf x} \log \pi({\bf x})` Args: x (:class:`ndarray<numpy.ndarray>` [:math:`m,d`]): evaluation points params (dict): parameters idxs_slice (slice): if precomputed values are present, this parameter indicates at which of the points to evaluate. The number of indices represented by ``idxs_slice`` must match ``x.shape[0]``. cache (dict): cache Returns: (:class:`tuple`) -- function and gradient evaluation """ if x.shape[1] != self.dim: raise ValueError("The dimension of the input does not match the dimension " + \ "of the distribution") try: factor_list_cache = cache['factor_list_cache'] except TypeError: factor_list_cache = [None] * self.n_factors lpdf = np.zeros(x.shape[0]) gxlpdf = np.zeros(x.shape) for (is_cond, pi, xidxs, yidxs), fcache in zip(self.factors, factor_list_cache): if is_cond: fx, gx = pi.tuple_grad_x_log_pdf( x[:,xidxs], x[:,yidxs], params=params, idxs_slice=idxs_slice, cache=fcache) lpdf += fx gxlpdf[:,yidxs] += gx[:,pi.dim:] else: fx, gx = pi.tuple_grad_x_log_pdf( x[:,xidxs], params=params, idxs_slice=idxs_slice, cache=fcache) lpdf += fx gxlpdf[:,xidxs] += gx[:,:pi.dim] return lpdf, gxlpdf
@cached([('factor_list','n_factors')],False) @counted
[docs] def hess_x_log_pdf(self, x, params=None, idxs_slice=slice(None,None,None), cache=None): r""" Evaluate :math:`\nabla^2_{\bf x} \log \pi({\bf x})` Args: x (:class:`ndarray<numpy.ndarray>` [:math:`m,d`]): evaluation points params (dict): parameters idxs_slice (slice): if precomputed values are present, this parameter indicates at which of the points to evaluate. The number of indices represented by ``idxs_slice`` must match ``x.shape[0]``. cache (dict): cache Returns: (:class:`ndarray<numpy.ndarray>` [:math:`m,d,d`]) -- values of :math:`\nabla^2_x\log\pi` at the ``x`` points. """ if x.shape[1] != self.dim: raise ValueError("The dimension of the input does not match the dimension " + \ "of the distribution") try: factor_list_cache = cache['factor_list_cache'] except TypeError: factor_list_cache = [None] * self.n_factors m = x.shape[0] out = np.zeros( (m, self.dim, self.dim) ) for (is_cond, pi, xidxs, yidxs), fcache in zip(self.factors, factor_list_cache): if is_cond: hx = pi.hess_x_log_pdf(x[:,xidxs], x[:,yidxs], params=params, idxs_slice=idxs_slice, cache=fcache) out[np.ix_(range(m),xidxs,yidxs)] += hx[:,:pi.dim,pi.dim:] out[np.ix_(range(m),yidxs,xidxs)] += hx[:,pi.dim:,:pi.dim] out[np.ix_(range(m),yidxs,yidxs)] += hx[:,pi.dim:,pi.dim:] else: hx = pi.hess_x_log_pdf(x[:,xidxs], params=params, idxs_slice=idxs_slice, cache=fcache) out[np.ix_(range(m),xidxs,xidxs)] += hx[:,:pi.dim,:pi.dim] return out
@cached([('factor_list','n_factors')],False) @counted
[docs] def action_hess_x_log_pdf(self, x, dx, params=None, idxs_slice=slice(None,None,None), cache=None): r""" Evaluate :math:`\nabla^2_{\bf x} \log \pi({\bf x})` Args: x (:class:`ndarray<numpy.ndarray>` [:math:`m,d`]): evaluation points dx (:class:`ndarray<numpy.ndarray>` [:math:`m,d`]): direction on which to evaluate the Hessian params (dict): parameters idxs_slice (slice): if precomputed values are present, this parameter indicates at which of the points to evaluate. The number of indices represented by ``idxs_slice`` must match ``x.shape[0]``. cache (dict): cache Returns: (:class:`ndarray<numpy.ndarray>` [:math:`m,d,d`]) -- values of :math:`\nabla^2_x\log\pi` at the ``x`` points. """ if x.shape[1] != self.dim: raise ValueError("The dimension of the input does not match the dimension " + \ "of the distribution") try: factor_list_cache = cache['factor_list_cache'] except TypeError: factor_list_cache = [None] * self.n_factors m = x.shape[0] out = np.zeros( (m, self.dim) ) for (is_cond, pi, xidxs, yidxs), fcache in zip(self.factors, factor_list_cache): if is_cond: ahx = pi.action_hess_x_log_pdf( x[:,xidxs], x[:,yidxs], dx[:,xidxs], dx[:,yidxs], params=params, idxs_slice=idxs_slice, cache=fcache) out[np.ix_(range(m),xidxs)] += ahx[:,:pi.dim] out[np.ix_(range(m),yidxs)] += ahx[:,pi.dim:] else: ahx = pi.action_hess_x_log_pdf( x[:,xidxs], dx[:,xidxs], params=params, idxs_slice=idxs_slice, cache=fcache) out[np.ix_(range(m),xidxs)] += ahx[:,:pi.dim] return out