[1]:
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg as npla
import scipy.stats as stats
import scipy.linalg as scila
import SpectralToolbox.Spectral1D as S1D
import TransportMaps as TM
import TransportMaps.Maps as MAPS
import TransportMaps.Distributions as DIST
import TransportMaps.Diagnostics as DIAG
import TransportMaps.KL as KL
Multidimensional distributions
Let \({\bf X}\sim \mathcal{N}({\bf 0},{\bf \Sigma})\) be an auxiliary random variable with density \(\pi_{\rm x}\), where \({\bf 0}\) is a two-element vector and \({\bf \Sigma}\) is a \(2 \times 2\) symmetric positive definite matrix.
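Let’s also consider the map \(B^\star:\mathbb{R}^2 \rightarrow \mathbb{R}^2\) defined by
\[
B^\star({\bf x}) = \begin{bmatrix} a x_1 \\ \dfrac{x_2}{a} - b\left( (a x_1)^2 + a^2 \right) \end{bmatrix} \;,
\]
where \(a>0\) and \(b \in \mathbb{R}\).
We define the target density as
\[
\pi({\bf y}) = \left(B^\star_\sharp \pi_{\rm x}\right)({\bf y}) = \pi_{\rm x}\left( (B^\star)^{-1}({\bf y}) \right) \left\vert \det \nabla_{\bf y} (B^\star)^{-1}({\bf y}) \right\vert \;,
\]
where
\[
(B^\star)^{-1}({\bf y}) = \begin{bmatrix} \dfrac{y_1}{a} \\ a y_2 + a b \left( y_1^2 + a^2 \right) \end{bmatrix}
\]
and consequently, since \(\nabla_{\bf y} (B^\star)^{-1}\) is lower triangular with diagonal \((1/a,\, a)\),
\[
\left\vert \det \nabla_{\bf y} (B^\star)^{-1}({\bf y}) \right\vert = 1 \;.
\]
This leads to the following expression for the target density:
\[
\pi({\bf y}) = \pi_{\rm x}\left( \frac{y_1}{a},\; a y_2 + a b \left( y_1^2 + a^2 \right) \right) \;.
\]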
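As a sanity check on this derivation, here is a minimal NumPy sketch (independent of TransportMaps) of \(B^\star\), its inverse and the resulting log-density. The helpers `banana_map`, `banana_inv`, `gauss_logpdf` and `banana_logpdf` are hypothetical, and the parameter values are copied from the next cell. The sketch verifies that the inverse is exact and that \(\pi\) integrates to one, as the unit Jacobian determinant implies.

```python
import numpy as np

def banana_map(x, a, b):
    """Hypothetical helper: B*(x) = [a x1, x2/a - b((a x1)^2 + a^2)], row-wise."""
    y = np.empty_like(x)
    y[:, 0] = a * x[:, 0]
    y[:, 1] = x[:, 1] / a - b * ((a * x[:, 0])**2 + a**2)
    return y

def banana_inv(y, a, b):
    """Hypothetical helper: (B*)^{-1}(y) = [y1/a, a y2 + a b (y1^2 + a^2)], row-wise."""
    x = np.empty_like(y)
    x[:, 0] = y[:, 0] / a
    x[:, 1] = a * y[:, 1] + a * b * (y[:, 0]**2 + a**2)
    return x

def gauss_logpdf(x, mu, sigma):
    """Log-density of N(mu, sigma) at the rows of x."""
    d = x - mu
    sol = np.linalg.solve(sigma, d.T).T
    _, logdet = np.linalg.slogdet(2 * np.pi * sigma)
    return -0.5 * np.sum(d * sol, axis=1) - 0.5 * logdet

def banana_logpdf(y, a, b, mu, sigma):
    """log pi(y) = log pi_x((B*)^{-1}(y)); no Jacobian term since |det| = 1."""
    return gauss_logpdf(banana_inv(y, a, b), mu, sigma)

a, b = 1.0, 1.0
mu = np.zeros(2)
sigma = np.array([[1.0, 0.9], [0.9, 1.0]])

# Round trip: (B*)^{-1}(B*(x)) should recover x up to rounding.
rng = np.random.default_rng(0)
x = rng.standard_normal((5, 2))
roundtrip_err = np.max(np.abs(banana_inv(banana_map(x, a, b), a, b) - x))

# Riemann sum of pi on a box wide enough to hold essentially all the mass.
y1 = np.linspace(-6.0, 6.0, 401)
y2 = np.linspace(-45.0, 6.0, 801)
Y1, Y2 = np.meshgrid(y1, y2)
Y = np.column_stack((Y1.ravel(), Y2.ravel()))
pdf = np.exp(banana_logpdf(Y, a, b, mu, sigma))
mass = pdf.sum() * (y1[1] - y1[0]) * (y2[1] - y2[0])
print(roundtrip_err, mass)  # expect ~0 and ~1
```

Note that the box used here is much taller than the plotting grid below: the banana's tails bend far downward in \(y_2\), so a grid that looks adequate for contours can miss a noticeable fraction of the mass.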
Let’s define \(\nu_\pi\) …
[2]:
a = 1.
b = 1.
mu = np.zeros(2)
sigma2 = np.array([[1., 0.9],[0.9, 1.]])
pi = DIST.BananaDistribution(a, b, mu, sigma2)
Let’s see what the PDF of this distribution looks like…
[3]:
ndiscr = 100
x = np.linspace(-4,4,ndiscr)
y = np.linspace(-9,3,ndiscr)
xx,yy = np.meshgrid(x,y)
X2d = np.vstack( (xx.flatten(),yy.flatten()) ).T
pdf2d = pi.pdf(X2d).reshape(xx.shape)
[4]:
levels_pdf2d = np.linspace(np.min(pdf2d),np.max(pdf2d),10)
plt.figure()
plt.contour(xx, yy, pdf2d, levels=levels_pdf2d);
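Let us also define the exact transport map \(T^\star = B^\star \circ L\) that we are seeking, where \(L({\bf x}) = {\bf L}{\bf x}\) is the linear map given by the lower-triangular Cholesky factor \({\bf L}\) of \({\bf \Sigma} = {\bf L}{\bf L}^\top\). Since \(L\) maps a standard normal random variable to \({\bf X}\sim \mathcal{N}({\bf 0},{\bf \Sigma})\), the composition \(T^\star\) maps it to the target.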
[5]:
lin_tm = MAPS.AffineTransportMap(c=mu, L=npla.cholesky(sigma2))  # L: x -> mu + chol(Sigma) x
ban_tm = MAPS.FrozenBananaMap(a, b)                               # B*
Tstar = MAPS.CompositeTransportMap(ban_tm, lin_tm)                # T* = B* o L
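The exact map \(T^\star\) is itself lower triangular: \({\bf L}\) is lower triangular, and the first component of \(B^\star\) depends only on \(x_1\). This is what allows a triangular parametrization to represent it. A minimal NumPy sketch of this check follows; the `tstar` and `fd_jacobian` helpers are hypothetical (not TransportMaps functions), and the Jacobian is approximated by central finite differences.

```python
import numpy as np

a, b = 1.0, 1.0
sigma = np.array([[1.0, 0.9], [0.9, 1.0]])
Lchol = np.linalg.cholesky(sigma)  # lower triangular, sigma = Lchol @ Lchol.T

def tstar(z):
    """Hypothetical helper: T*(z) = B*(L z) for a single point z (zero mean)."""
    x = Lchol @ z
    return np.array([a * x[0],
                     x[1] / a - b * ((a * x[0])**2 + a**2)])

def fd_jacobian(f, z, eps=1e-6):
    """Central finite-difference Jacobian of f: R^2 -> R^2 at z."""
    J = np.zeros((2, 2))
    for j in range(2):
        dz = np.zeros(2)
        dz[j] = eps
        J[:, j] = (f(z + dz) - f(z - dz)) / (2 * eps)
    return J

J = fd_jacobian(tstar, np.array([0.7, -1.3]))
print(J)  # J[0, 1] is zero: T*_1 does not depend on z_2
# det(grad T*) = det(L) * det(grad B*) = det(L), because det(grad B*) = a * (1/a) = 1.
print(np.linalg.det(J), np.prod(np.diag(Lchol)))
```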
[6]:
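x1 = np.linspace(-7, 7, ndiscr)
X1d = np.vstack( (x1, np.zeros(ndiscr)) ).T  # points (x1, 0): T_1 depends on x_1 only
t1 = Tstar(X1d)[:,0]                         # first component T_1(x_1)
t2 = ban_tm(lin_tm(X2d))[:,1]                # second component T_2(x_1, x_2) on the 2D grid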
[7]:
plt.figure(figsize=(10,3.5))
plt.subplot(121); plt.title(r"$T_1(x_1)$");
plt.plot(x1, t1);
plt.subplot(122); plt.title(r"$T_2(x_1,x_2)$");
plt.contour(xx, yy, t2.reshape(xx.shape));
Integrated squared parametrization
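The triangular transport map \(T\), in the \(2\)-dimensional case, takes the form
\[
T({\bf x}) = \begin{bmatrix} T_1(x_1) \\ T_2(x_1, x_2) \end{bmatrix} \;.
\]
Each component is parameterized as
\[
T_k({\bf x}_{1:k}) = c_k({\bf x}_{1:k-1}) + \int_0^{x_k} \left( h_k({\bf x}_{1:k-1}, t) \right)^2 \, dt \;, \qquad k = 1, 2 \;,
\]
where \(c_k\) and \(h_k\) are themselves two parametric approximations (\(c_1\) is a constant, since \({\bf x}_{1:0}\) is empty). Because the integrand is a square, \(\partial_{x_k} T_k = h_k^2 \geq 0\), so each component is non-decreasing in its last variable.
We use polynomial approximations for both of these functions, such that
\[
c_k({\bf x}_{1:k-1}) = \sum_{{\bf i} \in \mathcal{I}_c^{(k)}} c^{(k)}_{\bf i} \, \Phi_{\bf i}({\bf x}_{1:k-1}) \;, \qquad
h_k({\bf x}_{1:k}) = \sum_{{\bf i} \in \mathcal{I}_h^{(k)}} h^{(k)}_{\bf i} \, \Phi_{\bf i}({\bf x}_{1:k}) \;,
\]
where \(\Phi_{\bf i}\) are multivariate polynomials indexed by the multi-indices \({\bf i}\), and the coefficients \(c^{(k)}_{\bf i}\), \(h^{(k)}_{\bf i}\) are the parameters to be optimized.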
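To illustrate why the integrated-squared form always gives a component that increases in its last argument, here is a minimal NumPy sketch of a single component \(T_2\). It is not the TransportMaps implementation: the monomial basis, the full tensor-product index set, the random coefficients and the Gauss-Legendre evaluation of the integral are all assumptions made for this example.

```python
import numpy as np
from itertools import product

order = 2
# Hypothetical full tensor-product index set {(i, j): 0 <= i, j <= order}.
idx_h = list(product(range(order + 1), repeat=2))
rng = np.random.default_rng(1)
coef_c = rng.standard_normal(order + 1)   # c_2(x1) = sum_i coef_c[i] * x1^i
coef_h = rng.standard_normal(len(idx_h))  # h_2(x1, t) = sum coef_h[m] * x1^i * t^j

def c2(x1):
    return sum(cf * x1**i for i, cf in enumerate(coef_c))

def h2(x1, t):
    return sum(cf * x1**i * t**j for cf, (i, j) in zip(coef_h, idx_h))

# (order + 1)-point Gauss-Legendre is exact up to degree 2*order + 1,
# which covers h_2^2 (degree 2*order in t).
gl_x, gl_w = np.polynomial.legendre.leggauss(order + 1)

def T2(x1, x2):
    """T_2(x1, x2) = c_2(x1) + int_0^{x2} h_2(x1, t)^2 dt, with [-1, 1] mapped to [0, x2]."""
    t = 0.5 * x2 * (gl_x + 1.0)
    return c2(x1) + 0.5 * x2 * np.sum(gl_w * h2(x1, t)**2)

x2_grid = np.linspace(-3.0, 3.0, 61)
vals = np.array([T2(0.4, s) for s in x2_grid])
print(np.all(np.diff(vals) > 0))  # increasing in x2 whatever the coefficients
```

The design point is that monotonicity holds for every coefficient vector, so an optimizer can move the coefficients freely without ever producing a non-invertible component.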
Let’s build a second-order approximation of the transport map…
[8]:
order = 2
T = MAPS.assemble_IsotropicIntegratedSquaredTriangularTransportMap(
2, order, 'full')
… select the reference distribution \(\nu_\rho\) …
[9]:
rho = DIST.StandardNormalDistribution(2)
… construct the pushforward object \(T_{\sharp}\nu_\rho\) and the pullback object \(T^{\sharp}\nu_\pi\) …
[10]:
push_rho = DIST.PushForwardParametricTransportMapDistribution(T, rho)
pull_pi = DIST.PullBackParametricTransportMapDistribution(T, pi)
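For reference, when \(T\) is a monotone triangular diffeomorphism and \(\rho\) denotes the density of \(\nu_\rho\), these two objects have the densities
\[
T_\sharp \rho({\bf y}) = \rho\left( T^{-1}({\bf y}) \right) \left\vert \det \nabla T^{-1}({\bf y}) \right\vert \;, \qquad
T^\sharp \pi({\bf x}) = \pi\left( T({\bf x}) \right) \left\vert \det \nabla T({\bf x}) \right\vert \;,
\]
and, since \(\nabla T\) is lower triangular,
\[
\det \nabla T({\bf x}) = \prod_{k=1}^{2} \partial_{x_k} T_k({\bf x}_{1:k}) = \prod_{k=1}^{2} h_k({\bf x}_{1:k-1}, x_k)^2
\]
for the integrated-squared parametrization.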
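We are then ready to set up and solve the problem
\[
\min_{T \in \mathcal{T}_\triangle} \mathcal{D}_{\rm KL}\left( T_\sharp \nu_\rho \,\|\, \nu_\pi \right) \;,
\]
where \(\mathcal{T}_\triangle\) is the set of monotone triangular maps described by the chosen parametrization. Since the KL divergence is invariant under the invertible change of variables \(T\), this is equivalent to solving
\[
\min_{T \in \mathcal{T}_\triangle} \mathcal{D}_{\rm KL}\left( \nu_\rho \,\|\, T^\sharp \nu_\pi \right) \;,
\]
whose objective equals \(-\mathbb{E}_{\nu_\rho}\left[ \log \pi \circ T + \log \det \nabla T \right]\) up to an additive constant independent of \(T\).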
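A minimal NumPy sketch of the objective \(\mathcal{D}_{\rm KL}(\nu_\rho \,\|\, T^\sharp \nu_\pi)\), estimated with a tensor-product Gauss-Hermite rule in the spirit of `qtype=3`, `qparams=[10]*2` (whether TransportMaps uses exactly this rule is an assumption). It evaluates the divergence for the exact map \(T^\star\), where it should vanish, and for the identity map, where it should be positive. All helper names are hypothetical.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

a, b = 1.0, 1.0
sigma = np.array([[1.0, 0.9], [0.9, 1.0]])
Lchol = np.linalg.cholesky(sigma)
sigma_inv = np.linalg.inv(sigma)
_, logdet_sigma = np.linalg.slogdet(sigma)

def log_pi(y):
    """Banana log-density log pi_x((B*)^{-1}(y)), with pi_x = N(0, sigma)."""
    x = np.column_stack((y[:, 0] / a,
                         a * y[:, 1] + a * b * (y[:, 0]**2 + a**2)))
    quad = np.einsum('ni,ij,nj->n', x, sigma_inv, x)
    return -0.5 * quad - np.log(2 * np.pi) - 0.5 * logdet_sigma

def log_rho(z):
    """Standard normal log-density in 2D."""
    return -0.5 * np.sum(z**2, axis=1) - np.log(2 * np.pi)

def tstar(z):
    """Exact map T*(z) = B*(L z), row-wise."""
    x = z @ Lchol.T
    return np.column_stack((a * x[:, 0],
                            x[:, 1] / a - b * ((a * x[:, 0])**2 + a**2)))

# log det grad T* = log det L + log det grad B* = log det L (constant in z).
logdet_tstar = np.sum(np.log(np.diag(Lchol)))

# Tensor-product Gauss-Hermite rule for the standard normal, 10 nodes per dimension.
g, w = hermegauss(10)
w = w / np.sqrt(2 * np.pi)  # normalize so the weights sum to 1
Z = np.array([[gi, gj] for gi in g for gj in g])
W = np.array([wi * wj for wi in w for wj in w])

def kl(T, logdet_T):
    """D_KL(nu_rho || T^# nu_pi) = E_rho[log rho - log pi(T) - log det grad T]."""
    return np.sum(W * (log_rho(Z) - log_pi(T(Z)) - logdet_T(Z)))

kl_exact = kl(tstar, lambda z: np.full(len(z), logdet_tstar))
kl_identity = kl(lambda z: z, lambda z: np.zeros(len(z)))
print(kl_exact, kl_identity)  # ~0 for T*, strictly positive for the identity
```

For this target every integrand is a polynomial of degree at most 4 in \({\bf z}\), so the 10-node rule evaluates both divergences exactly up to rounding; for a fitted map the quadrature only approximates the expectation.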
[11]:
qtype = 3 # Gauss quadrature
qparams = [10] * 2 # Quadrature order
reg = None # No regularization
tol = 1e-5 # Optimization tolerance
ders = 2 # Use gradient and Hessian
log = KL.minimize_kl_divergence(
rho, pull_pi, qtype=qtype, qparams=qparams, regularization=reg,
tol=tol, ders=ders)
Let’s check the PDF approximation against the exact Banana distribution…
[12]:
approx_pdf = push_rho.pdf(X2d).reshape(xx.shape)
plt.figure()
plt.contour(xx, yy, pdf2d, levels=levels_pdf2d);
plt.contour(xx, yy, approx_pdf, linestyles='dashed', levels=levels_pdf2d);
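The contour plot gives a qualitative comparison. One way to condense it into a single number is a grid approximation of the total variation distance \(\tfrac{1}{2}\int \vert \pi - T_\sharp\rho \vert\), which could be applied to `pdf2d` and `approx_pdf` with spacings `x[1] - x[0]` and `y[1] - y[0]`. The helper `grid_tv_distance` below is hypothetical (not a TransportMaps function), ignores any mass outside the grid, and is checked here on two unit Gaussians shifted by \(\delta\), whose exact distance is \(\operatorname{erf}\left(\delta / (2\sqrt{2})\right)\).

```python
import math
import numpy as np

def grid_tv_distance(p, q, dx, dy):
    """Hypothetical helper: 0.5 * sum |p - q| dx dy on a uniform grid.
    Mass outside the grid is not accounted for."""
    return 0.5 * np.sum(np.abs(p - q)) * dx * dy

# Check on N((0, 0), I) vs N((delta, 0), I).
delta = 0.5
x = np.linspace(-8.0, 8.5, 661)
y = np.linspace(-8.0, 8.0, 641)
xx, yy = np.meshgrid(x, y)
p = np.exp(-0.5 * (xx**2 + yy**2)) / (2 * np.pi)
q = np.exp(-0.5 * ((xx - delta)**2 + yy**2)) / (2 * np.pi)
tv = grid_tv_distance(p, q, x[1] - x[0], y[1] - y[0])
print(tv, math.erf(delta / (2 * math.sqrt(2))))
```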
and compare the approximate transport map with the exact transport map…
[13]:
t1_approx = T(X1d)[:,0]
t2_approx = T(X2d)[:,1]
plt.figure(figsize=(10,3.5))
plt.subplot(121); plt.title(r"$T_1(x_1)$")
plt.plot(x1, t1, 'k');
plt.plot(x1, t1_approx, 'r');
plt.subplot(122); plt.title(r"$T_2(x_1,x_2)$");
plt.contour(xx, yy, t2.reshape(xx.shape));
plt.contour(xx, yy, t2_approx.reshape(xx.shape),
linestyles='dashed');
Sampling from the approximation
Monte-Carlo sampling
[14]:
M = 1000
samples = push_rho.rvs(M)
plt.figure()
plt.contour(xx, yy, pdf2d);
plt.scatter(samples[:,0], samples[:,1], c='k', s=1.);
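Sampling from \(T_\sharp\nu_\rho\) amounts to drawing reference samples and pushing them through the map. The NumPy sketch below does this with the exact map \(T^\star\) in place of the fitted \(T\), so the result can be checked analytically; that `push_rho.rvs` works this way internally is an assumption. With \(\Sigma_{11} = 1\), the exact mean is \(\mathbb{E}[y_2] = -b\left(a^2 \Sigma_{11} + a^2\right) = -2\) for \(a = b = 1\).

```python
import numpy as np

a, b = 1.0, 1.0
sigma = np.array([[1.0, 0.9], [0.9, 1.0]])
Lchol = np.linalg.cholesky(sigma)

def tstar(z):
    """Exact map T*(z) = B*(L z), row-wise (hypothetical helper)."""
    x = z @ Lchol.T
    return np.column_stack((a * x[:, 0],
                            x[:, 1] / a - b * ((a * x[:, 0])**2 + a**2)))

rng = np.random.default_rng(42)
M = 100_000
z = rng.standard_normal((M, 2))  # samples from nu_rho
samples = tstar(z)               # samples from T*_# nu_rho = nu_pi

mean_y2 = samples[:, 1].mean()
print(mean_y2)  # close to -b * (a**2 * sigma[0, 0] + a**2) = -2
```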
Gauss quadratures
[15]:
(xq,wq) = push_rho.quadrature(qtype=3, qparams=[2,2])  # Gauss quadrature, order 2 per dimension
plt.figure()
plt.contour(xx, yy, pdf2d);
plt.scatter(xq[:,0], xq[:,1], c='k', s=50.*wq+10.);  # marker size grows with the weight
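Quadrature rules for the target can be obtained the same way: the nodes of a Gauss rule for \(\nu_\rho\) are pushed through the map while the weights are kept. The NumPy sketch below again uses the exact map \(T^\star\) and a Gauss-Hermite rule with 2 nodes per dimension; reading `qparams=[2,2]` as 2 nodes per dimension is an assumption. Because \(T^\star_2\) is a quadratic polynomial in \({\bf z}\), this tiny rule already computes \(\mathbb{E}[y_2] = -2\) exactly.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

a, b = 1.0, 1.0
sigma = np.array([[1.0, 0.9], [0.9, 1.0]])
Lchol = np.linalg.cholesky(sigma)

def tstar(z):
    """Exact map T*(z) = B*(L z), row-wise (hypothetical helper)."""
    x = z @ Lchol.T
    return np.column_stack((a * x[:, 0],
                            x[:, 1] / a - b * ((a * x[:, 0])**2 + a**2)))

g, w = hermegauss(2)        # nodes -1, +1 for the weight exp(-z^2 / 2)
w = w / np.sqrt(2 * np.pi)  # normalize so the weights sum to 1
zq = np.array([[gi, gj] for gi in g for gj in g])
wq = np.array([wi * wj for wi in w for wj in w])

xq = tstar(zq)              # pushed-forward nodes; weights are unchanged
E_y2 = np.sum(wq * xq[:, 1])
print(E_y2)
```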