# Source code for TransportMaps.Samplers.Routines

#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps.  If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Author: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#

import numpy as np
import numpy.linalg as npla
import scipy.stats as stats
import scipy.signal as signal

# Names exported by a star-import of this module.
__all__ = ['ess']

def _normalized_columns(samps):
    """Center every column of ``samps`` and scale it to unit Euclidean norm."""
    columns = []
    for d in range(samps.shape[1]):
        centered = samps[:, d] - np.mean(samps[:, d])
        columns.append(centered / npla.norm(centered))
    return columns


def _positive_lag_correlation(s1, s2):
    """Return the correlation of ``s1`` with ``s2`` at lags ``0, ..., m-1``."""
    # mode='full' spans lags -(m-1), ..., m-1; index m-1 is lag zero.
    return signal.correlate(s1, s2, mode='full')[len(s1) - 1:]


def _plot_correlation_decay(xcorr, confint, dim, do_xcorr, plot_lag, fig):
    """Draw one stem plot per correlation sequence with its confidence band."""
    # Imported lazily so that matplotlib is only required when plotting.
    import matplotlib.pyplot as plt
    if fig is None:
        fig = plt.figure()
    if do_xcorr:
        # Lower-triangular grid, in the same (d1, d2 <= d1) order as ``xcorr``.
        axes = [fig.add_subplot(dim, dim, d1 * dim + d2 + 1)
                for d1 in range(dim) for d2 in range(d1 + 1)]
    else:
        axes = [fig.add_subplot(dim, 1, d1 + 1) for d1 in range(dim)]
    # Never ask for more lags than are available.
    nlags = max(0, min(plot_lag + 1, len(confint)))
    lags = np.arange(nlags)
    for ax, corr in zip(axes, xcorr):
        ax.vlines(lags, np.zeros(nlags), corr[:nlags])
        ax.plot(confint[:nlags], '--r')
        ax.plot(-confint[:nlags], '--r')
        # Booleans (not the strings 'off'/'on') are what matplotlib expects.
        ax.tick_params(axis='x', which='both',
                       bottom=False, top=False, labelbottom=False)
        ax.tick_params(axis='y', which='both',
                       left=False, right=False, labelleft=False)
    # Lay out the figure that was actually drawn on, not the "current" one.
    fig.tight_layout()


def ess(samps, quantile=0.99, do_xcorr=False,
        plotting=False, plot_lag=50, fig=None):
    r""" Compute the Effective Sample Size (ESS) of a sample

    The minimum ESS over all the dimensions is returned.
    Cross-correlation can be optionally used as well in the determination of
    the ESS.
    Plotting of the correlation decay can be shown.

    The ESS is computed as :math:`\lfloor m/\kappa \rfloor`, where

    .. math::

       \kappa = 1 + 2 \sum_{i \geq 1,\; |c_i| \geq b_i} |c_i| \;,

    :math:`c_i` is the (auto- or cross-) correlation at lag :math:`i` and
    :math:`b_i` is the ``quantile``-confidence bound for the
    :math:`i`-th value of correlation
    (i.e. only significant correlation values are summed up).

    Args:
      samps (:class:`ndarray<numpy.ndarray>` [:math:`m,d`]):
        :math:`d`-dimensional sample on which to compute the ESS.
        A one-dimensional array of length :math:`m` is treated as
        :math:`d=1`.
      quantile (float): confidence interval quantile
      do_xcorr (bool): whether to compute and use the cross-correlation
        between dimensions in addition to the auto-correlation
      plotting (bool): whether to plot auto/cross-correlation decays
      plot_lag (int): how many lags to plot (clamped to the available lags)
      fig (figure): handle to a figure; a new one is created if ``None``

    Returns:
      (:class:`int`) -- minimum ESS across the :math:`d` dimensions
        (and across all pairs of dimensions if ``do_xcorr`` is set)

    Raises:
      ValueError: if ``samps`` has more than two dimensions
    """
    samps = np.asarray(samps)
    if samps.ndim == 1:
        samps = samps[:, np.newaxis]
    if samps.ndim != 2:
        raise ValueError("samps must be a 2-dimensional array [m, d]")
    nsamps, dim = samps.shape

    # Normalize each column once; pairs reuse the result.
    columns = _normalized_columns(samps)
    if do_xcorr:
        pairs = [(d1, d2) for d1 in range(dim) for d2 in range(d1 + 1)]
    else:
        pairs = [(d, d) for d in range(dim)]
    xcorr = [_positive_lag_correlation(columns[d1], columns[d2])
             for d1, d2 in pairs]

    # Confidence interval: lag i is estimated from m - i overlapping terms.
    var = 1. / np.arange(nsamps, 0, -1)
    alpha = 1. - (1. - quantile) / 2.
    confint = stats.norm.ppf(alpha) * np.sqrt(var)

    # ESS
    min_ess = nsamps
    for corr in xcorr:
        # Drop lag 0 *before* thresholding so that the lag-0 term, and only
        # that term, is excluded regardless of whether it is significant.
        tail = np.abs(corr[1:])
        kappa = 1. + 2. * np.sum(tail[tail >= confint[1:]])
        min_ess = min(min_ess, int(np.floor(float(nsamps) / kappa)))

    if plotting:
        _plot_correlation_decay(xcorr, confint, dim, do_xcorr, plot_lag, fig)
    return min_ess