#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps. If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Authors: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#
import unittest
import numpy as np
import numpy.random as npr
import TransportMaps as TM
import TransportMaps.DerivativesChecks as DC
from TransportMaps import KL
from TransportMaps import mpi_map, mpi_map_alloc_dmem, mpi_bcast_dmem, \
MPI_SUPPORT, Maps
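#
# The derivative checks below compare analytic derivatives of the KL
# divergence (with respect to the map coefficients a) against finite
# differences. A minimal sketch of the idea, assuming central differences
# (the actual implementation used by the tests is
# TransportMaps.DerivativesChecks.fd_gradient_check; the helper below is
# purely illustrative and is never called):
def _fd_gradient_sketch(f, a, eps):
    # Perturb each coefficient in turn and form the centered difference
    # quotient; an analytic gradient should match this up to O(eps**2).
    g = np.zeros(a.size)
    for i in range(a.size):
        ap = a.copy(); ap[i] += eps
        am = a.copy(); am[i] -= eps
        g[i] = (f(ap) - f(am)) / (2. * eps)
    return g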
class KL_divergence_DerivativeChecks(object):
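    """Finite-difference consistency checks for the derivatives of the
    KL divergence with respect to the coefficients of the parametric
    transport map. Subclasses are expected to provide base_distribution,
    distribution, qtype, qparams, coeffs, dim and mpi_pool.
    """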
    def setUp(self):
npr.seed(1)
self.fd_eps = 1e-6
self.nprocs = 1
    def test_grad_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
def kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.kl_divergence(d1, d2, qtype=qtype, qparams=qparams,
mpi_pool_tuple=(None,mpi_pool))
return out
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, qtype=qtype, qparams=qparams,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(kl_divergence, grad_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_tuple_grad_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
def kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out,_ = KL.tuple_grad_a_kl_divergence(d1, d2, qtype=qtype, qparams=qparams,
mpi_pool_tuple=(None,mpi_pool))
return out
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
_,out = KL.tuple_grad_a_kl_divergence(d1, d2, qtype=qtype, qparams=qparams,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(kl_divergence, grad_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_hess_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, qtype=qtype, qparams=qparams,
mpi_pool_tuple=(None,mpi_pool))
return out
def hess_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.hess_a_kl_divergence(d1, d2, qtype=qtype, qparams=qparams,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(grad_a_kl_divergence, hess_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_action_storage_hess_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, qtype=qtype, qparams=qparams,
mpi_pool_tuple=(None,mpi_pool))
return out
def action_storage_hess_a_kl_divergence(a, v, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
(H, ) = KL.storage_hess_a_kl_divergence(
d1, d2, qtype=qtype, qparams=qparams,
mpi_pool_tuple=(None,mpi_pool))
out = KL.action_stored_hess_a_kl_divergence(H, v)
return out
v = np.random.randn( d2.n_coeffs )
flag = DC.action_hess_check(
grad_a_kl_divergence, action_storage_hess_a_kl_divergence,
self.coeffs, v, fd_dx=self.fd_eps, verbose=False)
self.assertTrue( flag )
    def test_action_hess_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
da = npr.randn(self.coeffs.size)
d2.coeffs = self.coeffs
A = KL.hess_a_kl_divergence(d1, d2, qtype=qtype, qparams=qparams,
mpi_pool_tuple=(None,mpi_pool))
ha_dot_da = np.dot(A, da)
aha = KL.action_hess_a_kl_divergence(
da, d1, d2, qtype=qtype, qparams=qparams,
mpi_pool_tuple=(None,mpi_pool))
self.assertTrue( np.allclose(ha_dot_da, aha) )
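    # The action test above checks the matrix-free Hessian action against
    # the dense Hessian, i.e. that
    #   action_hess_a_kl_divergence(da, ...) ~= hess_a_kl_divergence(...).dot(da)
    # up to floating point tolerance (np.allclose).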
@unittest.skip("Not needed")
    def test_grad_x_grad_t_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
x, w = d1.quadrature(qtype, qparams)
params = {'d1': d1, 'd2': d2,
'mpi_pool_tuple': (None, mpi_pool)}
flag = DC.fd_gradient_check(
KL.grad_t_kl_divergence,
KL.grad_x_grad_t_kl_divergence,
x, self.fd_eps, params=params, verbose=False)
self.assertTrue( flag )
    def test_batch_grad_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
params2 = None
def kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(kl_divergence, grad_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_batch_tuple_grad_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
params2 = None
def kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out,_ = KL.tuple_grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
_,out = KL.tuple_grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(kl_divergence, grad_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_batch_hess_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
params2 = None
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
def hess_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.hess_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(grad_a_kl_divergence, hess_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_batch_action_storage_hess_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
params2 = None
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
def action_storage_hess_a_kl_divergence(a, v, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
(H, ) = KL.storage_hess_a_kl_divergence(
d1, d2, params2=params2, qtype=qtype, qparams=qparams,
batch_size=batch_size, mpi_pool_tuple=(None,mpi_pool))
out = KL.action_stored_hess_a_kl_divergence(H, v)
return out
v = np.random.randn( d2.n_coeffs )
flag = DC.action_hess_check(
grad_a_kl_divergence, action_storage_hess_a_kl_divergence,
self.coeffs, v, fd_dx=self.fd_eps, verbose=False)
self.assertTrue( flag )
    def test_batch_action_hess_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
da = npr.randn(self.coeffs.size)
(x,w) = d1.quadrature(qtype, qparams)
params2 = None
d2.coeffs = self.coeffs
A = KL.hess_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
ha_dot_da = np.dot(A, da)
aha = KL.action_hess_a_kl_divergence(
da, d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
self.assertTrue( np.allclose(ha_dot_da, aha) )
    def test_precomp_grad_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
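        # The call above runs 'precomp_minimize_kl_divergence' on each
        # (possibly remote) worker: 'x' is scattered in chunks, while
        # 'params2' lives in the workers' distributed memory (dmem) and is
        # passed to the method through the keyword named in dmem_arg_in_list.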
def kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.kl_divergence(d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(kl_divergence, grad_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_precomp_tuple_grad_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out,_ = KL.tuple_grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
_,out = KL.tuple_grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(kl_divergence, grad_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_precomp_hess_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
def hess_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.hess_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(grad_a_kl_divergence, hess_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_precomp_action_storage_hess_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
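        # Distributed storage pattern: each worker assembles and keeps its
        # local chunk of the Hessian (dmem key 'hess_a_kl_divergence');
        # the action on v is then computed chunk-wise and the partial
        # results are summed across workers by TM.SumChunkReduce.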
def action_storage_hess_a_kl_divergence(a, v, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
scatter_tuple = (['x', 'w'],[x, w])
bcast_tuple = (['d1', 'd2'], [d1, d2])
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params2']
dmem_val_in_list = [params2]
dmem_key_out_list = ['hess_a_kl_divergence']
(H, ) = mpi_map_alloc_dmem(
KL.storage_hess_a_kl_divergence, scatter_tuple=scatter_tuple,
bcast_tuple=bcast_tuple, dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list, dmem_val_in_list=dmem_val_in_list,
dmem_key_out_list=dmem_key_out_list,
mpi_pool=mpi_pool, concatenate=False)
bcast_tuple = (['v'], [v])
dmem_key_in_list = ['hess_a_kl_divergence']
dmem_arg_in_list = ['H']
dmem_val_in_list = [H]
reduce_obj = TM.SumChunkReduce(axis=0)
out = mpi_map(KL.action_stored_hess_a_kl_divergence,
bcast_tuple=bcast_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
reduce_obj=reduce_obj,
mpi_pool=mpi_pool)
return out
v = np.random.randn( d2.n_coeffs )
flag = DC.action_hess_check(
grad_a_kl_divergence, action_storage_hess_a_kl_divergence,
self.coeffs, v, fd_dx=self.fd_eps, verbose=False)
self.assertTrue( flag )
    def test_precomp_action_hess_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
da = npr.randn(self.coeffs.size)
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
d2.coeffs = self.coeffs
A = KL.hess_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
ha_dot_da = np.dot(A, da)
aha = KL.action_hess_a_kl_divergence(
da, d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
self.assertTrue( np.allclose(ha_dot_da, aha) )
    def test_precomp_batch_grad_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size, mpi_pool_tuple=(None,mpi_pool))
return out
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(kl_divergence, grad_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_precomp_batch_tuple_grad_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out,_ = KL.tuple_grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size, mpi_pool_tuple=(None,mpi_pool))
return out
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
_,out = KL.tuple_grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(kl_divergence, grad_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_precomp_batch_hess_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
def hess_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.hess_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(grad_a_kl_divergence, hess_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_precomp_batch_action_storage_hess_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
def action_storage_hess_a_kl_divergence(a, v, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
scatter_tuple = (['x', 'w'],[x, w])
bcast_tuple = (['d1', 'd2'], [d1, d2])
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params2']
dmem_val_in_list = [params2]
dmem_key_out_list = ['hess_a_kl_divergence']
(H, ) = mpi_map_alloc_dmem(
KL.storage_hess_a_kl_divergence, scatter_tuple=scatter_tuple,
bcast_tuple=bcast_tuple, dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list, dmem_val_in_list=dmem_val_in_list,
dmem_key_out_list=dmem_key_out_list,
mpi_pool=mpi_pool, concatenate=False)
bcast_tuple = (['v'], [v])
dmem_key_in_list = ['hess_a_kl_divergence']
dmem_arg_in_list = ['H']
dmem_val_in_list = [H]
reduce_obj = TM.SumChunkReduce(axis=0)
out = mpi_map(KL.action_stored_hess_a_kl_divergence,
bcast_tuple=bcast_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
reduce_obj=reduce_obj,
mpi_pool=mpi_pool)
return out
v = np.random.randn( d2.n_coeffs )
flag = DC.action_hess_check(
grad_a_kl_divergence, action_storage_hess_a_kl_divergence,
self.coeffs, v, fd_dx=self.fd_eps, verbose=False)
self.assertTrue( flag )
    def test_precomp_batch_action_hess_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
da = npr.randn(self.coeffs.size)
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
d2.coeffs = self.coeffs
A = KL.hess_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
ha_dot_da = np.dot(A, da)
aha = KL.action_hess_a_kl_divergence(
da, d1, d2, params2=params2, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
self.assertTrue( np.allclose(ha_dot_da, aha) )
    def test_precomp_cached_grad_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
# Precomp
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Init cache
scatter_tuple = (['x'], [x])
(cache, ) = mpi_map_alloc_dmem("allocate_cache_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_out_list=['cache'],
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
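        # Caching pattern: the cache allocated above is reused across
        # evaluations; it must be reset whenever the coefficients change
        # (see reset_cache_minimize_kl_divergence below), and a call to
        # KL.kl_divergence can then prime it for the subsequent
        # KL.grad_a_kl_divergence evaluation.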
def kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Evaluate
out = KL.kl_divergence(d1, d2, params2=params2, cache=cache,
x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Fill cache with kl_divergence
KL.kl_divergence(d1, d2, params2=params2, cache=cache,
x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
# Evaluate grad_a_kl_divergence using cached values
out = KL.grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(kl_divergence, grad_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_precomp_cached_tuple_grad_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
# Precomp
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Init cache
scatter_tuple = (['x'], [x])
(cache, ) = mpi_map_alloc_dmem("allocate_cache_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_out_list=['cache'],
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Evaluate
out,_ = KL.tuple_grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Fill cache with kl_divergence
KL.kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
# Evaluate grad_a_kl_divergence using cached values
_,out = KL.tuple_grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(kl_divergence, grad_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_precomp_cached_hess_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
# Precomp
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Init cache
scatter_tuple = (['x'], [x])
(cache, ) = mpi_map_alloc_dmem("allocate_cache_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_out_list=['cache'],
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
def hess_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Fill cache with grad_a_kl_divergence
KL.grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
# Evaluate
out = KL.hess_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(grad_a_kl_divergence, hess_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_precomp_cached_action_storage_hess_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
# Precomp
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Init cache
scatter_tuple = (['x'], [x])
(cache, ) = mpi_map_alloc_dmem("allocate_cache_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_out_list=['cache'],
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
def action_storage_hess_a_kl_divergence(a, v, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Fill cache with grad_a_kl_divergence
KL.grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
# Evaluate
scatter_tuple = (['x', 'w'],[x, w])
bcast_tuple = (['d1', 'd2'], [d1, d2])
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params2']
dmem_val_in_list = [params2]
dmem_key_out_list = ['hess_a_kl_divergence']
(H, ) = mpi_map_alloc_dmem(
KL.storage_hess_a_kl_divergence, scatter_tuple=scatter_tuple,
bcast_tuple=bcast_tuple, dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list, dmem_val_in_list=dmem_val_in_list,
dmem_key_out_list=dmem_key_out_list,
mpi_pool=mpi_pool, concatenate=False)
bcast_tuple = (['v'], [v])
dmem_key_in_list = ['hess_a_kl_divergence']
dmem_arg_in_list = ['H']
dmem_val_in_list = [H]
reduce_obj = TM.SumChunkReduce(axis=0)
out = mpi_map(KL.action_stored_hess_a_kl_divergence,
bcast_tuple=bcast_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
reduce_obj=reduce_obj,
mpi_pool=mpi_pool)
return out
v = np.random.randn( d2.n_coeffs )
flag = DC.action_hess_check(
grad_a_kl_divergence, action_storage_hess_a_kl_divergence,
self.coeffs, v, fd_dx=self.fd_eps, verbose=False)
self.assertTrue( flag )
    def test_precomp_cached_action_hess_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
da = npr.randn(self.coeffs.size)
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
# Precomp
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Init cache
scatter_tuple = (['x'], [x])
(cache, ) = mpi_map_alloc_dmem("allocate_cache_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_out_list=['cache'],
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Update distribution coefficients
d2.coeffs = self.coeffs
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Fill cache with grad_a_kl_divergence
KL.grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
# Evaluate
A = KL.hess_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
ha_dot_da = np.dot(A, da)
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Fill cache with grad_a_kl_divergence
KL.grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
# Evaluate
aha = KL.action_hess_a_kl_divergence(da, d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
self.assertTrue( np.allclose(ha_dot_da, aha) )
    def test_precomp_cached_batch_grad_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
# Precomp
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Init cache
scatter_tuple = (['x'], [x])
(cache, ) = mpi_map_alloc_dmem("allocate_cache_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_out_list=['cache'],
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Evaluate
out = KL.kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Fill cache with kl_divergence
KL.kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
# Evaluate grad_a_kl_divergence using cached values
out = KL.grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(kl_divergence, grad_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_precomp_cached_batch_tuple_grad_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
# Precomp
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Init cache
scatter_tuple = (['x'], [x])
(cache, ) = mpi_map_alloc_dmem("allocate_cache_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_out_list=['cache'],
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Evaluate
out,_ = KL.tuple_grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Fill cache with kl_divergence
KL.kl_divergence(d1, d2, params2=params2, cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
# Evaluate grad_a_kl_divergence using cached values
_,out = KL.tuple_grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(kl_divergence, grad_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_precomp_cached_batch_hess_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
# Precomp
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Init cache
scatter_tuple = (['x'], [x])
(cache, ) = mpi_map_alloc_dmem("allocate_cache_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_out_list=['cache'],
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
def hess_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Fill cache with grad_a_kl_divergence
KL.grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
# Evaluate
out = KL.hess_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(grad_a_kl_divergence, hess_a_kl_divergence,
self.coeffs, self.fd_eps,
verbose=False)
self.assertTrue( flag )
    def test_precomp_cached_batch_action_storage_hess_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
# Precomp
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Init cache
scatter_tuple = (['x'], [x])
(cache, ) = mpi_map_alloc_dmem("allocate_cache_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_out_list=['cache'],
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Evaluate
out = KL.grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
return out
def action_storage_hess_a_kl_divergence(a, v, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Fill cache with grad_a_kl_divergence
KL.grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
# Evaluate
scatter_tuple = (['x', 'w'],[x, w])
bcast_tuple = (['d1', 'd2', 'batch_size'], [d1, d2, batch_size])
dmem_key_in_list = ['params2', 'cache']
dmem_arg_in_list = ['params2', 'cache']
dmem_val_in_list = [params2, cache]
dmem_key_out_list = ['hess_a_kl_divergence']
(H, ) = mpi_map_alloc_dmem(
KL.storage_hess_a_kl_divergence, scatter_tuple=scatter_tuple,
bcast_tuple=bcast_tuple, dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list, dmem_val_in_list=dmem_val_in_list,
dmem_key_out_list=dmem_key_out_list,
mpi_pool=mpi_pool, concatenate=False)
bcast_tuple = (['v'], [v])
dmem_key_in_list = ['hess_a_kl_divergence']
dmem_arg_in_list = ['H']
dmem_val_in_list = [H]
reduce_obj = TM.SumChunkReduce(axis=0)
out = mpi_map(KL.action_stored_hess_a_kl_divergence,
bcast_tuple=bcast_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
reduce_obj=reduce_obj,
mpi_pool=mpi_pool)
return out
v = np.random.randn( d2.n_coeffs )
flag = DC.action_hess_check(
grad_a_kl_divergence, action_storage_hess_a_kl_divergence,
self.coeffs, v, fd_dx=self.fd_eps, verbose=False)
self.assertTrue( flag )
    def test_precomp_cached_batch_action_hess_a_kl_divergence(self):
batch_size = 3
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
da = npr.randn(self.coeffs.size)
(x,w) = d1.quadrature(qtype, qparams)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
# Precomp
scatter_tuple = (['x'], [x])
mpi_map("precomp_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Init cache
scatter_tuple = (['x'], [x])
(cache, ) = mpi_map_alloc_dmem("allocate_cache_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_out_list=['cache'],
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Update distribution coefficients
d2.coeffs = self.coeffs
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Fill cache with grad_a_kl_divergence
KL.grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
# Evaluate
A = KL.hess_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
ha_dot_da = np.dot(A, da)
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Fill cache with grad_a_kl_divergence
KL.grad_a_kl_divergence(d1, d2, params2=params2,
cache=cache, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
# Evaluate
aha = KL.action_hess_a_kl_divergence(da, d1, d2, params2=params2,
cache=cache, x=x, w=w,
batch_size=batch_size,
mpi_pool_tuple=(None,mpi_pool))
self.assertTrue( np.allclose(ha_dot_da, aha) )
#
# Serial and parallel tests
#
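# The checks above are written as a mixin (no unittest.TestCase base):
# the subclasses below pair them with either a serial run (mpi_pool = None)
# or a pool of two worker processes started in setUp.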
class Serial_KL_divergence_DerivativeChecks(KL_divergence_DerivativeChecks):
    def setUp(self):
        super(Serial_KL_divergence_DerivativeChecks, self).setUp()
        self.mpi_pool = None
class ParallelPool_KL_divergence_DerivativeChecks(KL_divergence_DerivativeChecks):
    def setUp(self):
        super(ParallelPool_KL_divergence_DerivativeChecks, self).setUp()
        import_set = set([ (None, 'numpy', 'np') ])
        self.mpi_pool = TM.get_mpi_pool()
        self.mpi_pool.start(2)
        self.mpi_pool.mod_import(import_set)
    def tearDown(self):
        import time
        self.mpi_pool.stop()
        time.sleep(0.2)
#
# PullBack and PushForward test cases
#
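# The PushForward variants skip the Hessian-based checks (not implemented
# for push-forward distributions) and add a cached gradient check that
# distributes d2 itself to the workers.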
class Serial_PullBackTMD_KL_divergence_DerivativeChecks(
        Serial_KL_divergence_DerivativeChecks):
    def setUp(self):
        import TransportMaps.Distributions as DIST
        self.distribution = DIST.PullBackParametricTransportMapDistribution(
            self.tm_approx, self.distribution_pi)
        super(Serial_PullBackTMD_KL_divergence_DerivativeChecks, self).setUp()
class ParallelPool_PullBackTMD_KL_divergence_DerivativeChecks(
        ParallelPool_KL_divergence_DerivativeChecks):
    def setUp(self):
        import TransportMaps.Distributions as DIST
        self.distribution = DIST.PullBackParametricTransportMapDistribution(
            self.tm_approx, self.distribution_pi)
        super(ParallelPool_PullBackTMD_KL_divergence_DerivativeChecks, self).setUp()
class Serial_PushForwardTMD_KL_divergence_DerivativeChecks(
        Serial_KL_divergence_DerivativeChecks):
    def setUp(self):
        import TransportMaps.Distributions as DIST
        self.distribution = DIST.PushForwardParametricTransportMapDistribution(
            self.tm_approx, self.distribution_pi)
        super(Serial_PushForwardTMD_KL_divergence_DerivativeChecks, self).setUp()
@unittest.skip("Not implemented")
[docs] def test_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_batch_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_action_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_batch_action_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_grad_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_batch_grad_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_batch_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_cached_grad_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_cached_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_cached_batch_grad_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_cached_batch_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_action_storage_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_batch_action_storage_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_action_storage_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_batch_action_storage_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_cached_action_storage_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_cached_batch_action_storage_hess_a_kl_divergence(self):
pass
[docs] def test_cached_grad_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Distribute objects
mpi_bcast_dmem(d2=d2, mpi_pool=mpi_pool)
# Link tm to d2.transport_map
def link_tm_d2(d2):
return (d2.transport_map,)
(tm,) = mpi_map_alloc_dmem(
link_tm_d2, dmem_key_in_list=['d2'], dmem_arg_in_list=['d2'],
dmem_val_in_list=[d2], dmem_key_out_list=['tm'],
mpi_pool=mpi_pool)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
dmem_key_in_list = ['params2']
dmem_arg_in_list = ['params']
dmem_val_in_list = [params2]
# Init cache
scatter_tuple = (['x'], [x])
(cache, ) = mpi_map_alloc_dmem(
"allocate_cache_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_out_list=['cache'],
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Evaluate
out = KL.kl_divergence(d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
# Evaluate
# Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Fill cache with kl_divergence
KL.kl_divergence(d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
# Evaluate grad_a_kl_divergence using cached values
out = KL.grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(kl_divergence, grad_a_kl_divergence,
self.coeffs, self.fd_eps,
None, verbose=False)
self.assertTrue( flag )
class ParallelPool_PushForwardTMD_KL_divergence_DerivativeChecks(
        ParallelPool_KL_divergence_DerivativeChecks):
    def setUp(self):
        import TransportMaps.Distributions as DIST
        self.distribution = DIST.PushForwardParametricTransportMapDistribution(
            self.tm_approx, self.distribution_pi)
        super(ParallelPool_PushForwardTMD_KL_divergence_DerivativeChecks, self).setUp()
@unittest.skip("Not implemented")
[docs] def test_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_batch_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_action_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_batch_action_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_grad_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_batch_grad_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_batch_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_cached_grad_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_cached_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_cached_batch_grad_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_cached_batch_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_action_storage_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_batch_action_storage_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_action_storage_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_batch_action_storage_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_cached_action_storage_hess_a_kl_divergence(self):
pass
@unittest.skip("Not implemented")
[docs] def test_precomp_cached_batch_action_storage_hess_a_kl_divergence(self):
pass
[docs] def test_cached_grad_a_kl_divergence(self):
d1 = self.base_distribution
d2 = self.distribution
qtype = self.qtype
qparams = self.qparams
mpi_pool = self.mpi_pool
(x,w) = d1.quadrature(qtype, qparams)
# Distribute objects
mpi_bcast_dmem(d2=d2, mpi_pool=mpi_pool)
# Link tm to d2.transport_map
def link_tm_d2(d2):
return (d2.transport_map,)
(tm,) = mpi_map_alloc_dmem(
link_tm_d2, dmem_key_in_list=['d2'], dmem_arg_in_list=['d2'],
dmem_val_in_list=[d2], dmem_key_out_list=['tm'],
mpi_pool=mpi_pool)
# Init memory
params2 = {
'params_pi': None,
'params_t': {'components': [{} for i in range(self.dim)]} }
mpi_bcast_dmem(params2=params2, mpi_pool=mpi_pool)
# Init cache
scatter_tuple = (['x'], [x])
(cache, ) = mpi_map_alloc_dmem(
"allocate_cache_minimize_kl_divergence",
scatter_tuple=scatter_tuple,
dmem_key_out_list=['cache'],
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
def kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
            # Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Evaluate
out = KL.kl_divergence(d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
def grad_a_kl_divergence(a, params={}):
# Update distribution coefficients
d2.coeffs = a
            # Reset cache
dmem_key_in_list = ['cache']
dmem_arg_in_list = ['cache']
dmem_val_in_list = [cache]
mpi_map("reset_cache_minimize_kl_divergence",
dmem_key_in_list=dmem_key_in_list,
dmem_arg_in_list=dmem_arg_in_list,
dmem_val_in_list=dmem_val_in_list,
obj=d2.transport_map, mpi_pool=mpi_pool,
concatenate=False)
# Fill cache with kl_divergence
KL.kl_divergence(d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
# Evaluate grad_a_kl_divergence using cached values
out = KL.grad_a_kl_divergence(d1, d2, params2=params2, x=x, w=w,
mpi_pool_tuple=(None,mpi_pool))
return out
flag = DC.fd_gradient_check(kl_divergence, grad_a_kl_divergence,
self.coeffs, self.fd_eps,
None, verbose=False)
self.assertTrue( flag )
#
# Transport Map
#
class IntegratedExponentialTM(object):
    def setUp_tm(self):
        import SpectralToolbox.Spectral1D as S1D
        import TransportMaps.Maps.Functionals as FUNC
        # Build the transport map (isotropic for each entry)
        self.order = 4
        approx_list = []
        active_vars = []
        for i in range(self.dim):
            c_basis_list = [S1D.HermiteProbabilistsPolynomial()] * (i+1)
            c_orders_list = ([self.order] * i) + [0]
            c_approx = FUNC.LinearSpanTensorizedParametricFunctional(c_basis_list, spantype='full',
                                                                     order_list=c_orders_list)
            e_basis_list = [S1D.ConstantExtendedHermiteProbabilistsFunction()] * (i+1)
            e_orders_list = [self.order - 1] * (i+1)
            e_approx = FUNC.LinearSpanTensorizedParametricFunctional(e_basis_list, spantype='full',
                                                                     order_list=e_orders_list)
            approx = FUNC.IntegratedExponentialParametricMonotoneFunctional(c_approx, e_approx)
            approx_list.append( approx )
            active_vars.append( list(range(i+1)) )
        self.tm_approx = Maps.IntegratedExponentialParametricTriangularComponentwiseTransportMap(
            active_vars=active_vars,
            approx_list=approx_list
        )
        self.params = {}
        self.params['params_t'] = None
        self.coeffs = npr.randn(self.tm_approx.n_coeffs) / 10.
        self.tm_approx.coeffs = self.coeffs
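    # Each component built above follows (up to the library's conventions) the
    # integrated-exponential monotone parameterization
    #     T_k(x_1,...,x_k) = c(x_1,...,x_{k-1})
    #                        + int_0^{x_k} exp( e(x_1,...,x_{k-1}, t) ) dt,
    # where c is the order-4 linear span with the last variable held at order 0
    # and e is the order-3 span. The exponential makes dT_k/dx_k > 0 for any
    # coefficient values, so the random coefficients drawn above always yield a
    # valid (monotone) triangular map.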
#
# Serial and parallel transport map tests
#
class Serial_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks(
        IntegratedExponentialTM,
        Serial_PullBackTMD_KL_divergence_DerivativeChecks):
    def setUp(self):
        super(Serial_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
              self).setUp_tm()
        super(Serial_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
              self).setUp()
class Serial_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks(
        IntegratedExponentialTM,
        Serial_PushForwardTMD_KL_divergence_DerivativeChecks):
    def setUp(self):
        super(Serial_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
              self).setUp_tm()
        super(Serial_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
              self).setUp()
class ParallelPool_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks(
        IntegratedExponentialTM,
        ParallelPool_PullBackTMD_KL_divergence_DerivativeChecks):
    def setUp(self):
        super(ParallelPool_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
              self).setUp_tm()
        super(ParallelPool_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
              self).setUp()
class ParallelPool_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks(
        IntegratedExponentialTM,
        ParallelPool_PushForwardTMD_KL_divergence_DerivativeChecks):
    def setUp(self):
        super(ParallelPool_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
              self).setUp_tm()
        super(ParallelPool_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
              self).setUp()
class IntegratedSquaredTM(object):
    def setUp_tm(self):
        import SpectralToolbox.Spectral1D as S1D
        import TransportMaps.Maps.Functionals as FUNC
        # Build the transport map (isotropic for each entry)
        self.order = 4
        approx_list = []
        active_vars = []
        for i in range(self.dim):
            c_basis_list = [S1D.HermiteProbabilistsPolynomial()] * (i+1)
            c_orders_list = ([self.order] * i) + [0]
            c_approx = FUNC.LinearSpanTensorizedParametricFunctional(c_basis_list, spantype='full',
                                                                     order_list=c_orders_list)
            e_basis_list = [S1D.ConstantExtendedHermiteProbabilistsFunction()] * (i+1)
            e_orders_list = [self.order - 1] * (i+1)
            e_approx = FUNC.LinearSpanTensorizedParametricFunctional(e_basis_list, spantype='full',
                                                                     order_list=e_orders_list)
            approx = FUNC.IntegratedSquaredParametricMonotoneFunctional(c_approx, e_approx)
            approx_list.append( approx )
            active_vars.append( list(range(i+1)) )
        self.tm_approx = Maps.IntegratedSquaredParametricTriangularComponentwiseTransportMap(
            active_vars=active_vars,
            approx_list=approx_list
        )
        self.params = {}
        self.params['params_t'] = None
        coeffs = self.tm_approx.get_identity_coeffs()
        coeffs += npr.randn(len(coeffs)) / 100.
        self.tm_approx.coeffs = coeffs
        self.coeffs = coeffs
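    # Analogous construction with the integrated-squared parameterization,
    #     T_k(x_1,...,x_k) = c(x_1,...,x_{k-1})
    #                        + int_0^{x_k} e(x_1,...,x_{k-1}, t)**2 dt,
    # which is non-decreasing for any coefficients but can have vanishing
    # derivative where e = 0. Presumably for this reason the coefficients are
    # initialized at the identity map and perturbed only mildly (std 0.01)
    # rather than drawn at random as in IntegratedExponentialTM above.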
#
# Serial and parallel transport map tests
#
class Serial_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks(
        IntegratedSquaredTM,
        Serial_PullBackTMD_KL_divergence_DerivativeChecks):
    def setUp(self):
        super(Serial_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
              self).setUp_tm()
        super(Serial_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
              self).setUp()
class Serial_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks(
        IntegratedSquaredTM,
        Serial_PushForwardTMD_KL_divergence_DerivativeChecks):
    def setUp(self):
        super(Serial_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
              self).setUp_tm()
        super(Serial_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
              self).setUp()
class ParallelPool_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks(
        IntegratedSquaredTM,
        ParallelPool_PullBackTMD_KL_divergence_DerivativeChecks):
    def setUp(self):
        super(ParallelPool_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
              self).setUp_tm()
        super(ParallelPool_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
              self).setUp()
class ParallelPool_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks(
        IntegratedSquaredTM,
        ParallelPool_PushForwardTMD_KL_divergence_DerivativeChecks):
    def setUp(self):
        super(ParallelPool_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
              self).setUp_tm()
        super(ParallelPool_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
              self).setUp()
#
# Specific tests
#
class TMD_TestCase(object):
    def _setUp_tcase(self):
        import TransportMaps.Distributions as DIST
        self.dim = self.setup['dim']
        self.target_distribution = self.Tparams['target_distribution']
        self.support_map = self.Tparams['support_map']
        self.distribution_pi = DIST.PullBackTransportMapDistribution(
            self.support_map,
            self.target_distribution
        )
        self.base_distribution = self.Tparams['base_distribution']
        self.qtype = 3
        self.qparams = [2]*self.dim
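    # qtype=3 with qparams=[2]*dim requests (in this library's convention) a
    # tensorized Gaussian quadrature rule of order 2 per dimension for the base
    # distribution. A sketch of the 1D building block for a standard normal
    # base, up to normalization (assuming probabilists' Hermite nodes):
    #
    #     xi, wi = np.polynomial.hermite_e.hermegauss(2)
    #     wi = wi / wi.sum()   # nodes/weights of a 2-point probability rule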
class Linear1D_TMD_TestCase(TMD_TestCase):
    def _setUp_tcase(self):
        import TransportMaps.tests.TestFunctions as TF
        title, self.setup, self.Tparams = TF.get(0)
        super(Linear1D_TMD_TestCase,self)._setUp_tcase()
class ArcTan1D_TMD_TestCase(TMD_TestCase):
    def _setUp_tcase(self):
        import TransportMaps.tests.TestFunctions as TF
        title, self.setup, self.Tparams = TF.get(2)
        super(ArcTan1D_TMD_TestCase,self)._setUp_tcase()
class Exp1D_TMD_TestCase(TMD_TestCase):
    def _setUp_tcase(self):
        import TransportMaps.tests.TestFunctions as TF
        title, self.setup, self.Tparams = TF.get(3)
        super(Exp1D_TMD_TestCase,self)._setUp_tcase()
class Logistic1D_TMD_TestCase(TMD_TestCase):
    def _setUp_tcase(self):
        import TransportMaps.tests.TestFunctions as TF
        title, self.setup, self.Tparams = TF.get(4)
        super(Logistic1D_TMD_TestCase,self)._setUp_tcase()
class Gamma1D_TMD_TestCase(TMD_TestCase):
    def _setUp_tcase(self):
        import TransportMaps.tests.TestFunctions as TF
        title, self.setup, self.Tparams = TF.get(5)
        super(Gamma1D_TMD_TestCase,self)._setUp_tcase()
class Beta1D_TMD_TestCase(TMD_TestCase):
    def _setUp_tcase(self):
        import TransportMaps.tests.TestFunctions as TF
        title, self.setup, self.Tparams = TF.get(6)
        super(Beta1D_TMD_TestCase,self)._setUp_tcase()
class Gumbel1D_TMD_TestCase(TMD_TestCase):
    def _setUp_tcase(self):
        import TransportMaps.tests.TestFunctions as TF
        title, self.setup, self.Tparams = TF.get(7)
        super(Gumbel1D_TMD_TestCase,self)._setUp_tcase()
class Linear2D_TMD_TestCase(TMD_TestCase):
    def _setUp_tcase(self):
        import TransportMaps.tests.TestFunctions as TF
        title, self.setup, self.Tparams = TF.get(9)
        super(Linear2D_TMD_TestCase,self)._setUp_tcase()
class Banana2D_TMD_TestCase(TMD_TestCase):
    def _setUp_tcase(self):
        import TransportMaps.tests.TestFunctions as TF
        title, self.setup, self.Tparams = TF.get(10)
        super(Banana2D_TMD_TestCase,self)._setUp_tcase()
#
# Serial PullBack tests
#
class Linear1D_Serial_IEPBTMD_KLdiv_DerivativeChecks(
        Linear1D_TMD_TestCase,
        Serial_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Linear1D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Linear1D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class ArcTan1D_Serial_IEPBTMD_KLdiv_DerivativeChecks(
        ArcTan1D_TMD_TestCase,
        Serial_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(ArcTan1D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(ArcTan1D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Exp1D_Serial_IEPBTMD_KLdiv_DerivativeChecks(
        Exp1D_TMD_TestCase,
        Serial_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Exp1D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Exp1D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Logistic1D_Serial_IEPBTMD_KLdiv_DerivativeChecks(
        Logistic1D_TMD_TestCase,
        Serial_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Logistic1D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Logistic1D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Gamma1D_Serial_IEPBTMD_KLdiv_DerivativeChecks(
        Gamma1D_TMD_TestCase,
        Serial_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Gamma1D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Gamma1D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Beta1D_Serial_IEPBTMD_KLdiv_DerivativeChecks(
        Beta1D_TMD_TestCase,
        Serial_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Beta1D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Beta1D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Gumbel1D_Serial_IEPBTMD_KLdiv_DerivativeChecks(
        Gumbel1D_TMD_TestCase,
        Serial_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Gumbel1D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Gumbel1D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Linear2D_Serial_IEPBTMD_KLdiv_DerivativeChecks(
        Linear2D_TMD_TestCase,
        Serial_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Linear2D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Linear2D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Banana2D_Serial_IEPBTMD_KLdiv_DerivativeChecks(
        Banana2D_TMD_TestCase,
        Serial_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Banana2D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Banana2D_Serial_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
# #
# # Serial PushForward tests
# #
# class Linear1D_Serial_IEPFTMD_KLdiv_DerivativeChecks(
# Linear1D_TMD_TestCase,
# Serial_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Linear1D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Linear1D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class ArcTan1D_Serial_IEPFTMD_KLdiv_DerivativeChecks(
# ArcTan1D_TMD_TestCase,
# Serial_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(ArcTan1D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(ArcTan1D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Exp1D_Serial_IEPFTMD_KLdiv_DerivativeChecks(
# Exp1D_TMD_TestCase,
# Serial_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Exp1D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Exp1D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Logistic1D_Serial_IEPFTMD_KLdiv_DerivativeChecks(
# Logistic1D_TMD_TestCase,
# Serial_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Logistic1D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Logistic1D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Gamma1D_Serial_IEPFTMD_KLdiv_DerivativeChecks(
# Gamma1D_TMD_TestCase,
# Serial_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Gamma1D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Gamma1D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Beta1D_Serial_IEPFTMD_KLdiv_DerivativeChecks(
# Beta1D_TMD_TestCase,
# Serial_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Beta1D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Beta1D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Gumbel1D_Serial_IEPFTMD_KLdiv_DerivativeChecks(
# Gumbel1D_TMD_TestCase,
# Serial_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Gumbel1D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Gumbel1D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Linear2D_Serial_IEPFTMD_KLdiv_DerivativeChecks(
# Linear2D_TMD_TestCase,
# Serial_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Linear2D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Linear2D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Banana2D_Serial_IEPFTMD_KLdiv_DerivativeChecks(
# Banana2D_TMD_TestCase,
# Serial_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Banana2D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Banana2D_Serial_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
#
# ParallelPool PullBack tests
#
class Linear1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks(
        Linear1D_TMD_TestCase,
        ParallelPool_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Linear1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Linear1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class ArcTan1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks(
        ArcTan1D_TMD_TestCase,
        ParallelPool_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(ArcTan1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(ArcTan1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Exp1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks(
        Exp1D_TMD_TestCase,
        ParallelPool_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Exp1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Exp1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Logistic1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks(
        Logistic1D_TMD_TestCase,
        ParallelPool_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Logistic1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Logistic1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Gamma1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks(
        Gamma1D_TMD_TestCase,
        ParallelPool_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Gamma1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Gamma1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Beta1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks(
        Beta1D_TMD_TestCase,
        ParallelPool_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Beta1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Beta1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Gumbel1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks(
        Gumbel1D_TMD_TestCase,
        ParallelPool_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Gumbel1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Gumbel1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Linear2D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks(
        Linear2D_TMD_TestCase,
        ParallelPool_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Linear2D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Linear2D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Banana2D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks(
        Banana2D_TMD_TestCase,
        ParallelPool_IntegratedExponentialPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Banana2D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Banana2D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks,self).setUp()
# #
# # ParallelPool PushForward tests
# #
# class Linear1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks(
# Linear1D_TMD_TestCase,
# ParallelPool_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Linear1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Linear1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class ArcTan1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks(
# ArcTan1D_TMD_TestCase,
# ParallelPool_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(ArcTan1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(ArcTan1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Exp1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks(
# Exp1D_TMD_TestCase,
# ParallelPool_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Exp1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Exp1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Logistic1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks(
# Logistic1D_TMD_TestCase,
# ParallelPool_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Logistic1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Logistic1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Gamma1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks(
# Gamma1D_TMD_TestCase,
# ParallelPool_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Gamma1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Gamma1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Beta1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks(
# Beta1D_TMD_TestCase,
# ParallelPool_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Beta1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Beta1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Gumbel1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks(
# Gumbel1D_TMD_TestCase,
# ParallelPool_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Gumbel1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Gumbel1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Linear2D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks(
# Linear2D_TMD_TestCase,
# ParallelPool_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Linear2D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Linear2D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Banana2D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks(
# Banana2D_TMD_TestCase,
# ParallelPool_IntegratedExponentialPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Banana2D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Banana2D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks,self).setUp()
#
# Integrated Squared Serial PullBack tests
#
class Linear1D_Serial_ISPBTMD_KLdiv_DerivativeChecks(
        Linear1D_TMD_TestCase,
        Serial_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Linear1D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Linear1D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class ArcTan1D_Serial_ISPBTMD_KLdiv_DerivativeChecks(
        ArcTan1D_TMD_TestCase,
        Serial_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(ArcTan1D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(ArcTan1D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Exp1D_Serial_ISPBTMD_KLdiv_DerivativeChecks(
        Exp1D_TMD_TestCase,
        Serial_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Exp1D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Exp1D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Logistic1D_Serial_ISPBTMD_KLdiv_DerivativeChecks(
        Logistic1D_TMD_TestCase,
        Serial_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Logistic1D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Logistic1D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Gamma1D_Serial_ISPBTMD_KLdiv_DerivativeChecks(
        Gamma1D_TMD_TestCase,
        Serial_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Gamma1D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Gamma1D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Beta1D_Serial_ISPBTMD_KLdiv_DerivativeChecks(
        Beta1D_TMD_TestCase,
        Serial_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Beta1D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Beta1D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Gumbel1D_Serial_ISPBTMD_KLdiv_DerivativeChecks(
        Gumbel1D_TMD_TestCase,
        Serial_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Gumbel1D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Gumbel1D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Linear2D_Serial_ISPBTMD_KLdiv_DerivativeChecks(
        Linear2D_TMD_TestCase,
        Serial_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Linear2D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Linear2D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Banana2D_Serial_ISPBTMD_KLdiv_DerivativeChecks(
        Banana2D_TMD_TestCase,
        Serial_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Banana2D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Banana2D_Serial_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
# #
# # Integrated Squared Serial PushForward tests
# #
# class Linear1D_Serial_ISPFTMD_KLdiv_DerivativeChecks(
# Linear1D_TMD_TestCase,
# Serial_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Linear1D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Linear1D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class ArcTan1D_Serial_ISPFTMD_KLdiv_DerivativeChecks(
# ArcTan1D_TMD_TestCase,
# Serial_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(ArcTan1D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(ArcTan1D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Exp1D_Serial_ISPFTMD_KLdiv_DerivativeChecks(
# Exp1D_TMD_TestCase,
# Serial_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Exp1D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Exp1D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Logistic1D_Serial_ISPFTMD_KLdiv_DerivativeChecks(
# Logistic1D_TMD_TestCase,
# Serial_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Logistic1D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Logistic1D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Gamma1D_Serial_ISPFTMD_KLdiv_DerivativeChecks(
# Gamma1D_TMD_TestCase,
# Serial_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Gamma1D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Gamma1D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Beta1D_Serial_ISPFTMD_KLdiv_DerivativeChecks(
# Beta1D_TMD_TestCase,
# Serial_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Beta1D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Beta1D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Gumbel1D_Serial_ISPFTMD_KLdiv_DerivativeChecks(
# Gumbel1D_TMD_TestCase,
# Serial_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Gumbel1D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Gumbel1D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Linear2D_Serial_ISPFTMD_KLdiv_DerivativeChecks(
# Linear2D_TMD_TestCase,
# Serial_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Linear2D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Linear2D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Banana2D_Serial_ISPFTMD_KLdiv_DerivativeChecks(
# Banana2D_TMD_TestCase,
# Serial_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Banana2D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Banana2D_Serial_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
#
# Integrated Squared ParallelPool PullBack tests
#
class Linear1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks(
        Linear1D_TMD_TestCase,
        ParallelPool_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Linear1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Linear1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class ArcTan1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks(
        ArcTan1D_TMD_TestCase,
        ParallelPool_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(ArcTan1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(ArcTan1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Exp1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks(
        Exp1D_TMD_TestCase,
        ParallelPool_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Exp1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Exp1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Logistic1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks(
        Logistic1D_TMD_TestCase,
        ParallelPool_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Logistic1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Logistic1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Gamma1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks(
        Gamma1D_TMD_TestCase,
        ParallelPool_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Gamma1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Gamma1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Beta1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks(
        Beta1D_TMD_TestCase,
        ParallelPool_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Beta1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Beta1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Gumbel1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks(
        Gumbel1D_TMD_TestCase,
        ParallelPool_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Gumbel1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Gumbel1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Linear2D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks(
        Linear2D_TMD_TestCase,
        ParallelPool_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Linear2D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Linear2D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
class Banana2D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks(
        Banana2D_TMD_TestCase,
        ParallelPool_IntegratedSquaredPBTMD_KL_divergence_DerivativeChecks,
        unittest.TestCase):
    def setUp(self):
        super(Banana2D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
        super(Banana2D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks,self).setUp()
# #
# # Integrated Squared ParallelPool PushForward tests
# #
# class Linear1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks(
# Linear1D_TMD_TestCase,
# ParallelPool_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Linear1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Linear1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class ArcTan1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks(
# ArcTan1D_TMD_TestCase,
# ParallelPool_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(ArcTan1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(ArcTan1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Exp1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks(
# Exp1D_TMD_TestCase,
# ParallelPool_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Exp1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Exp1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Logistic1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks(
# Logistic1D_TMD_TestCase,
# ParallelPool_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Logistic1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Logistic1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Gamma1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks(
# Gamma1D_TMD_TestCase,
# ParallelPool_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Gamma1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Gamma1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Beta1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks(
# Beta1D_TMD_TestCase,
# ParallelPool_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Beta1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Beta1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Gumbel1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks(
# Gumbel1D_TMD_TestCase,
# ParallelPool_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Gumbel1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Gumbel1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Linear2D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks(
# Linear2D_TMD_TestCase,
# ParallelPool_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Linear2D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Linear2D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
# class Banana2D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks(
# Banana2D_TMD_TestCase,
# ParallelPool_IntegratedSquaredPFTMD_KL_divergence_DerivativeChecks,
# unittest.TestCase):
# def setUp(self):
# super(Banana2D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self)._setUp_tcase()
# super(Banana2D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks,self).setUp()
def build_suite(ttype='all'):
# INTEGRATED EXPONENTIAL
# Serial PullBack tests
suite_linear1d_se_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Linear1D_Serial_IEPBTMD_KLdiv_DerivativeChecks )
suite_arctan1d_se_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
ArcTan1D_Serial_IEPBTMD_KLdiv_DerivativeChecks )
suite_exp1d_se_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Exp1D_Serial_IEPBTMD_KLdiv_DerivativeChecks )
suite_logistic1d_se_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Logistic1D_Serial_IEPBTMD_KLdiv_DerivativeChecks )
suite_gamma1d_se_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Gamma1D_Serial_IEPBTMD_KLdiv_DerivativeChecks )
suite_beta1d_se_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Beta1D_Serial_IEPBTMD_KLdiv_DerivativeChecks )
suite_gumbel1d_se_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Gumbel1D_Serial_IEPBTMD_KLdiv_DerivativeChecks )
suite_linear2d_se_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Linear2D_Serial_IEPBTMD_KLdiv_DerivativeChecks )
suite_banana2d_se_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Banana2D_Serial_IEPBTMD_KLdiv_DerivativeChecks )
# # Serial PushForward tests
# suite_linear1d_se_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Linear1D_Serial_IEPFTMD_KLdiv_DerivativeChecks )
# suite_arctan1d_se_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# ArcTan1D_Serial_IEPFTMD_KLdiv_DerivativeChecks )
# suite_exp1d_se_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Exp1D_Serial_IEPFTMD_KLdiv_DerivativeChecks )
# suite_logistic1d_se_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Logistic1D_Serial_IEPFTMD_KLdiv_DerivativeChecks )
# suite_gamma1d_se_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Gamma1D_Serial_IEPFTMD_KLdiv_DerivativeChecks )
# suite_beta1d_se_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Beta1D_Serial_IEPFTMD_KLdiv_DerivativeChecks )
# suite_gumbel1d_se_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Gumbel1D_Serial_IEPFTMD_KLdiv_DerivativeChecks )
# suite_linear2d_se_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Linear2D_Serial_IEPFTMD_KLdiv_DerivativeChecks )
# suite_banana2d_se_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Banana2D_Serial_IEPFTMD_KLdiv_DerivativeChecks )
# ParallelPool PullBack tests
suite_linear1d_pa_pool_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Linear1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks )
suite_arctan1d_pa_pool_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
ArcTan1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks )
suite_exp1d_pa_pool_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Exp1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks )
suite_logistic1d_pa_pool_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Logistic1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks )
suite_gamma1d_pa_pool_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Gamma1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks )
suite_beta1d_pa_pool_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Beta1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks )
suite_gumbel1d_pa_pool_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Gumbel1D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks )
suite_linear2d_pa_pool_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Linear2D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks )
suite_banana2d_pa_pool_iepbtmd = unittest.TestLoader().loadTestsFromTestCase(
Banana2D_ParallelPool_IEPBTMD_KLdiv_DerivativeChecks )
# # ParallelPool PushForward tests
# suite_linear1d_pa_pool_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Linear1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks )
# suite_arctan1d_pa_pool_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# ArcTan1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks )
# suite_exp1d_pa_pool_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Exp1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks )
# suite_logistic1d_pa_pool_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Logistic1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks )
# suite_gamma1d_pa_pool_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Gamma1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks )
# suite_beta1d_pa_pool_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Beta1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks )
# suite_gumbel1d_pa_pool_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Gumbel1D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks )
# suite_linear2d_pa_pool_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Linear2D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks )
# suite_banana2d_pa_pool_iepftmd = unittest.TestLoader().loadTestsFromTestCase(
# Banana2D_ParallelPool_IEPFTMD_KLdiv_DerivativeChecks )
# INTEGRATED SQUARED
# Serial PullBack tests
suite_linear1d_se_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Linear1D_Serial_ISPBTMD_KLdiv_DerivativeChecks )
suite_arctan1d_se_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
ArcTan1D_Serial_ISPBTMD_KLdiv_DerivativeChecks )
suite_exp1d_se_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Exp1D_Serial_ISPBTMD_KLdiv_DerivativeChecks )
suite_logistic1d_se_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Logistic1D_Serial_ISPBTMD_KLdiv_DerivativeChecks )
suite_gamma1d_se_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Gamma1D_Serial_ISPBTMD_KLdiv_DerivativeChecks )
suite_beta1d_se_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Beta1D_Serial_ISPBTMD_KLdiv_DerivativeChecks )
suite_gumbel1d_se_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Gumbel1D_Serial_ISPBTMD_KLdiv_DerivativeChecks )
suite_linear2d_se_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Linear2D_Serial_ISPBTMD_KLdiv_DerivativeChecks )
suite_banana2d_se_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Banana2D_Serial_ISPBTMD_KLdiv_DerivativeChecks )
# # Serial PushForward tests
# suite_linear1d_se_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Linear1D_Serial_ISPFTMD_KLdiv_DerivativeChecks )
# suite_arctan1d_se_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# ArcTan1D_Serial_ISPFTMD_KLdiv_DerivativeChecks )
# suite_exp1d_se_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Exp1D_Serial_ISPFTMD_KLdiv_DerivativeChecks )
# suite_logistic1d_se_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Logistic1D_Serial_ISPFTMD_KLdiv_DerivativeChecks )
# suite_gamma1d_se_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Gamma1D_Serial_ISPFTMD_KLdiv_DerivativeChecks )
# suite_beta1d_se_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Beta1D_Serial_ISPFTMD_KLdiv_DerivativeChecks )
# suite_gumbel1d_se_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Gumbel1D_Serial_ISPFTMD_KLdiv_DerivativeChecks )
# suite_linear2d_se_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Linear2D_Serial_ISPFTMD_KLdiv_DerivativeChecks )
# suite_banana2d_se_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Banana2D_Serial_ISPFTMD_KLdiv_DerivativeChecks )
# ParallelPool PullBack tests
suite_linear1d_pa_pool_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Linear1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks )
suite_arctan1d_pa_pool_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
ArcTan1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks )
suite_exp1d_pa_pool_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Exp1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks )
suite_logistic1d_pa_pool_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Logistic1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks )
suite_gamma1d_pa_pool_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Gamma1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks )
suite_beta1d_pa_pool_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Beta1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks )
suite_gumbel1d_pa_pool_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Gumbel1D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks )
suite_linear2d_pa_pool_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Linear2D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks )
suite_banana2d_pa_pool_ispbtmd = unittest.TestLoader().loadTestsFromTestCase(
Banana2D_ParallelPool_ISPBTMD_KLdiv_DerivativeChecks )
# # ParallelPool PushForward tests
# suite_linear1d_pa_pool_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Linear1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks )
# suite_arctan1d_pa_pool_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# ArcTan1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks )
# suite_exp1d_pa_pool_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Exp1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks )
# suite_logistic1d_pa_pool_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Logistic1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks )
# suite_gamma1d_pa_pool_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Gamma1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks )
# suite_beta1d_pa_pool_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Beta1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks )
# suite_gumbel1d_pa_pool_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Gumbel1D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks )
# suite_linear2d_pa_pool_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Linear2D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks )
# suite_banana2d_pa_pool_ispftmd = unittest.TestLoader().loadTestsFromTestCase(
# Banana2D_ParallelPool_ISPFTMD_KLdiv_DerivativeChecks )
# GROUP SUITES
# Serial
suites_list = []
if ttype in ['all','serial']:
suites_list += [
# INTEGRATED EXPONENTIAL
# Serial Pull Back
suite_linear1d_se_iepbtmd,
suite_arctan1d_se_iepbtmd,
suite_exp1d_se_iepbtmd, suite_logistic1d_se_iepbtmd,
suite_gamma1d_se_iepbtmd, suite_beta1d_se_iepbtmd,
suite_gumbel1d_se_iepbtmd,
suite_linear2d_se_iepbtmd,
suite_banana2d_se_iepbtmd,
# # Serial Push Forward
# suite_linear1d_se_iepftmd,
# suite_arctan1d_se_iepftmd, suite_exp1d_se_iepftmd,
# suite_logistic1d_se_iepftmd, suite_gamma1d_se_iepftmd,
# suite_beta1d_se_iepftmd, suite_gumbel1d_se_iepftmd,
# suite_linear2d_se_iepftmd, suite_banana2d_se_iepftmd,
# INTEGRATED SQUARED
# Serial Pull Back
suite_linear1d_se_ispbtmd, suite_arctan1d_se_ispbtmd,
suite_exp1d_se_ispbtmd, suite_logistic1d_se_ispbtmd,
suite_gamma1d_se_ispbtmd, suite_beta1d_se_ispbtmd,
suite_gumbel1d_se_ispbtmd,
suite_linear2d_se_ispbtmd,
suite_banana2d_se_ispbtmd,
# # Serial Push Forward
# suite_linear1d_se_ispftmd,
# suite_arctan1d_se_ispftmd, suite_exp1d_se_ispftmd,
# suite_logistic1d_se_ispftmd, suite_gamma1d_se_ispftmd,
# suite_beta1d_se_ispftmd, suite_gumbel1d_se_ispftmd,
# suite_linear2d_se_ispftmd, suite_banana2d_se_ispftmd,
]
# Parallel
if ttype in ['all','parallel'] and MPI_SUPPORT:
suites_list += [
# INTEGRATED EXPONENTIAL
# ParallelPool Pull Back
suite_linear1d_pa_pool_iepbtmd, suite_arctan1d_pa_pool_iepbtmd,
suite_exp1d_pa_pool_iepbtmd, suite_logistic1d_pa_pool_iepbtmd,
suite_gamma1d_pa_pool_iepbtmd, suite_beta1d_pa_pool_iepbtmd,
suite_gumbel1d_pa_pool_iepbtmd, suite_linear2d_pa_pool_iepbtmd,
suite_banana2d_pa_pool_iepbtmd,
# # ParallelPool Push Forward
# suite_linear1d_pa_pool_iepftmd,
# suite_arctan1d_pa_pool_iepftmd, suite_exp1d_pa_pool_iepftmd,
# suite_logistic1d_pa_pool_iepftmd, suite_gamma1d_pa_pool_iepftmd,
# suite_beta1d_pa_pool_iepftmd, suite_gumbel1d_pa_pool_iepftmd,
# suite_linear2d_pa_pool_iepftmd, suite_banana2d_pa_pool_iepftmd,
# INTEGRATED SQUARED
# ParallelPool Pull Back
suite_linear1d_pa_pool_ispbtmd, suite_arctan1d_pa_pool_ispbtmd,
suite_exp1d_pa_pool_ispbtmd, suite_logistic1d_pa_pool_ispbtmd,
suite_gamma1d_pa_pool_ispbtmd, suite_beta1d_pa_pool_ispbtmd,
suite_gumbel1d_pa_pool_ispbtmd, suite_linear2d_pa_pool_ispbtmd,
suite_banana2d_pa_pool_ispbtmd,
# # ParallelPool Push Forward
# suite_linear1d_pa_pool_ispftmd,
# suite_arctan1d_pa_pool_ispftmd, suite_exp1d_pa_pool_ispftmd,
# suite_logistic1d_pa_pool_ispftmd, suite_gamma1d_pa_pool_ispftmd,
# suite_beta1d_pa_pool_ispftmd, suite_gumbel1d_pa_pool_ispftmd,
# suite_linear2d_pa_pool_ispftmd, suite_banana2d_pa_pool_ispftmd
]
all_suites = unittest.TestSuite( suites_list )
return all_suites
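# Convenience sketch (not part of the original suite machinery): load and run
# a single TestCase class in isolation, which is handy when debugging one
# derivative check, e.g.
#     run_single(Linear1D_Serial_IEPBTMD_KLdiv_DerivativeChecks)
def run_single(test_case_class, failfast=False):
    suite = unittest.TestLoader().loadTestsFromTestCase(test_case_class)
    unittest.TextTestRunner(verbosity=2, failfast=failfast).run(suite)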
def run_tests(
ttype='serial',
failfast=False
):
all_suites = build_suite(ttype)
# RUN
unittest.TextTestRunner(
verbosity=2,
failfast=failfast
).run(all_suites)
if __name__ == '__main__':
run_tests()