Hi there,
I am very much interested in using the parallel setup for the inverse transport computation. I have the most up-to-date installation of TM, mpi4py and mpi_map. I can also successfully run the Example 0 (direct transport) in the MPI tutorial:
# Example 0: Minimization of the KL-divergence and sampling
nprocs = 2
# Define target distribution
mu = 3.
beta = 4.
target_density = DIST.GumbelDistribution(mu,beta)
# Define base density
base_density = DIST.StandardNormalDistribution(1)
# Define approximating transport map
order = 5
tm_approx = TM.Default_IsotropicIntegratedExponentialTriangularTransportMap(1, order, 'full')
# Define approximating density
tm_density = DIST.PushForwardTransportMapDistribution(tm_approx, base_density)
# Start pool of processes
mpi_pool = TM.get_mpi_pool()
mpi_pool.start(nprocs)
# Solve and sample
try:
qtype = 0 # Gauss quadrature
qparams = 1000 # Quadrature order
reg = None # No regularization
tol = 1e-8 # Optimization tolerance
ders = 1 # Use gradient and Hessian
log_entry_solve = tm_density.minimize_kl_divergence(
target_density, qtype=qtype, qparams=qparams,
regularization=reg, tol=tol, ders=ders,
mpi_pool=mpi_pool)
finally:
mpi_pool.stop()
log_entry_solve
2019-03-20 13:21:46 WARNING:mpi_map: MPI_Pool_v2.alloc_dmem DEPRECATED since v>2.4. Use MPI_Pool_v2.bcast_dmem instead.
2019-03-20 13:21:46 WARNING:mpi_map: MPI_Pool_v2.alloc_dmem DEPRECATED since v>2.4. Use MPI_Pool_v2.bcast_dmem instead.
Out[16]:
{'success': True,
'message': 'Optimization terminated successfully.',
'fval': 1.4265539022688227,
'nit': 40,
'n_fun_ev': 41,
'n_jac_ev': 41,
.......
Now when I try to replicate the same process for the inverse map estimation using the tutorial example on Gumbel distribution, I get an error:
import TransportMaps.Distributions as DIST
class GumbelDistribution(DIST.Distribution):
def __init__(self, mu, beta):
super(GumbelDistribution,self).__init__(1)
self.mu = mu
self.beta = beta
self.dist = stats.gumbel_r(loc=mu, scale=beta)
def pdf(self, x, params=None):
return self.dist.pdf(x).flatten()
def quadrature(self, qtype, qparams, *args, **kwargs):
if qtype == 0: # Monte-Carlo
x = self.dist.rvs(qparams)[:,np.newaxis]
w = np.ones(qparams)/float(qparams)
else: raise ValueError("Quadrature not defined")
return (x, w)
mu = 3.
beta = 4.
pi = GumbelDistribution(mu,beta)
x, w = pi.quadrature(0, 5000)
# linear adjustment
xmax = np.max(x)
xmin = np.min(x)
a = np.array([ 4*(xmin+xmax)/(xmin-xmax) ])
b = np.array([ 8./(xmax-xmin) ])
L = MAPS.FrozenLinearDiagonalTransportMap(a,b)
S = TM.Default_IsotropicIntegratedSquaredTriangularTransportMap(
1, 3, 'total')
rho = DIST.StandardNormalDistribution(1)
push_L_pi = DIST.PushForwardTransportMapDistribution(L, pi)
push_SL_pi = DIST.PushForwardTransportMapDistribution(
S, push_L_pi)
# Start pool of processes
nprocs = 2
mpi_pool = TM.get_mpi_pool()
mpi_pool.start(nprocs)
# Solve and sample
try:
qtype = 0 # Monte-Carlo quadratures from pi
qparams = 500 # Number of MC points
reg = None # No regularization
tol = 1e-3 # Optimization tolerance
ders = 2 # Use gradient and Hessian
log = push_SL_pi.minimize_kl_divergence(
rho, qtype=qtype, qparams=qparams, regularization=reg,
tol=tol, ders=ders,mpi_pool=mpi_pool)
finally:
mpi_pool.stop()
SL = MAPS.CompositeMap(S,L)
pull_SL_rho = DIST.PullBackTransportMapDistribution(SL, rho)
log
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-15-d1561f88f5fc> in <module>
54 log = push_SL_pi.minimize_kl_divergence(
55 rho, qtype=qtype, qparams=qparams, regularization=reg,
---> 56 tol=tol, ders=ders,mpi_pool=mpi_pool)
57
58 finally:
~/anaconda3/lib/python3.7/site-packages/TransportMaps/Distributions/TransportMapDistributions.py in minimize_kl_divergence(self, tar, qtype, qparams, parbase, partar, x0, regularization, tol, maxit, ders, fungrad, hessact, batch_size, mpi_pool, grad_check, hess_check)
525 tol=tol, maxit=maxit, ders=ders, fungrad=fungrad, hessact=hessact,
526 batch_size=batch_size,
--> 527 mpi_pool=mpi_pool, grad_check=grad_check, hess_check=hess_check)
528 return log
529
~/anaconda3/lib/python3.7/site-packages/TransportMaps/Maps/TriangularTransportMapBase.py in minimize_kl_divergence(self, d1, d2, qtype, qparams, x, w, params_d1, params_d2, x0, regularization, tol, maxit, ders, fungrad, hessact, precomp_type, batch_size, mpi_pool, grad_check, hess_check)
1018 print('mpi_pool_list is '+str(mpi_pool_list)) # added by Hassan
1019 for i, (a, avars, batch_size, mpi_pool) in enumerate(zip(
-> 1020 self.approx_list, self.active_vars, batch_size_list, mpi_pool_list)):
1021 f = ProductDistributionParametricPullbackComponentFunction(
1022 a, d2.base_distribution.get_component([i]) )
TypeError: zip argument #4 must support iteration
I print out the argument #4 and get
mpi_pool_list is <mpi_map.misc.MPI_Pool_v2 object at 0x7fa08444b8d0>.
I have tried this for my own data in 2D and 3D, and get the same error.
I have studied the turorial on MPI, but still it is still beyond me to see where this comes from.
Will appreciate if you have suggestions or thoughts.