# Source code for TransportMaps.Algorithms.Adaptivity.RegressionAdaptivity

#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps.  If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Authors: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#

from ...Builders import L2RegressionBuilder
from ...L2 import L2_misfit

__all__ = [
    'ToleranceSequentialL2RegressionBuilder'
]


class ToleranceSequentialL2RegressionBuilder(L2RegressionBuilder):
    r""" Solve a sequence of :math:`L^2` regression problems until a tolerance is met.

    Each candidate map in the sequence is fitted to the target map with the
    parent :class:`L2RegressionBuilder`. After every successful fit the
    misfit returned by :func:`L2_misfit` is compared against ``eps`` and the
    sequence stops as soon as the misfit drops below it.

    Args:
      eps: tolerance compared against the value returned by
        :func:`L2_misfit`.
      regression_params_list (list): one set of regression parameters per
        candidate map. The first entry initializes the parent builder, so
        the list must not be empty. Entries are modified in place when a
        warm start ``'x0'`` is stored in them.
      monitor_params (dict): keyword arguments forwarded to
        :func:`L2_misfit` when checking the tolerance.
    """

    def __init__(self, eps, regression_params_list, monitor_params):
        self.eps = eps
        self.regression_params_list = regression_params_list
        self.monitor_params = monitor_params
        super(ToleranceSequentialL2RegressionBuilder,
              self).__init__(self.regression_params_list[0])

    @staticmethod
    def _transfer_coeffs(src, dst):
        r""" Copy coefficients from ``src`` to ``dst`` for shared multi-indices.

        Both arguments must expose ``multi_idxs`` and ``coeffs``. A
        coefficient is copied only when its multi-index is found in ``dst``
        (first match wins); multi-indices missing from ``dst`` are skipped
        instead of overwriting an unrelated coefficient.
        """
        for i_src, midx_src in enumerate(src.multi_idxs):
            for i_dst, midx_dst in enumerate(dst.multi_idxs):
                if midx_src == midx_dst:
                    dst.coeffs[i_dst] = src.coeffs[i_src]
                    break

    def solve(self, transport_map_list, target_map, **extra_reg_params):
        r""" Fit the candidate maps in order until the tolerance is reached.

        Args:
          transport_map_list (list): candidate maps, paired element-wise
            with ``regression_params_list`` (extra entries of the longer
            list are ignored).
          target_map: map to be approximated.
          extra_reg_params: extra keyword arguments forwarded to the parent
            ``solve``.

        Returns:
          (:class:`tuple`) -- ``(map, log_list)`` where ``log_list`` is the
            log of the last regression attempted and ``map`` is

            * the first map whose misfit is below ``eps``, or
            * the last successfully fitted map if a later regression fails
              or the sequence is exhausted, or
            * ``target_map`` if the very first regression fails.

        Raises:
          ValueError: if there is no candidate map to fit.
        """
        tm = None          # Last successfully fitted map
        log_list = None    # Log of the last regression attempted
        for transport_map, regression_params in zip(
                transport_map_list, self.regression_params_list):
            # The parent builder is configured through this attribute
            self.regression_params = regression_params
            if tm is not None:
                # Warm start: reuse the coefficients of the previous map
                for prev, new in zip(tm.approx_list,
                                     transport_map.approx_list):
                    self._transfer_coeffs(prev.c, new.c)  # Constant part
                    self._transfer_coeffs(prev.h, new.h)  # Integrated part
                # NOTE(review): the original indentation was lost; the warm
                # start is assumed to belong to this branch -- confirm.
                self.regression_params['x0'] = transport_map.coeffs
            # The returned map is not used: ``transport_map`` itself is kept,
            # presumably because the parent updates it in place -- verify.
            _, log_list = super(
                ToleranceSequentialL2RegressionBuilder,
                self).solve(transport_map, target_map, **extra_reg_params)
            if not log_list[-1]['success']:
                if tm is None:
                    # If no map is available return the target map
                    return target_map, log_list
                # Otherwise fall back on the last successful map
                break
            tm = transport_map
            # Check error
            err = L2_misfit(target_map, tm, **self.monitor_params)
            if err < self.eps:
                break
        if log_list is None:
            raise ValueError(
                "transport_map_list must contain at least one map")
        return tm, log_list