#!/usr/bin/env python
#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps. If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Author: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#
import numpy as np
from TransportMaps.ObjectBase import TMO
from TransportMaps.Misc import cmdinput, required_kwargs
from TransportMaps.External import DOLFIN_SUPPORT
# dolfin (FEniCS) is an optional dependency: it is imported only when the
# TransportMaps.External probe reports it as available.
if DOLFIN_SUPPORT:
    import dolfin as dol

    # Release 2017.2.0 is special-cased throughout this module: for it the
    # MPI communicators are taken from mpi4py / petsc4py directly.
    if dol.__version__ == '2017.2.0':
        import mpi4py.MPI as MPI
        from petsc4py import PETSc

    # Reduce dolfin's console output.
    # NOTE(review): 30 presumably corresponds to dolfin's WARNING level -- confirm.
    dol.set_log_level(30)

# Public names exported by this module.
__all__ = ['Solver']
class Solver(TMO):
    r""" [Abstract] Generic class for a PDE solver

    It offers only stub functions and a function to convert degrees of freedom
    to fenics functions defined on the prescribed approximation space.

    The approximation space is provided through the keyword argument ``VEFS``
    and is stored in the attribute ``VEFS``. It is the object handed to
    ``dolfin.Function`` by :meth:`new_function` (presumably a
    ``dolfin.FunctionSpace`` -- TODO confirm).

    Attributes set by this class: ``_dolfin_version``, ``wcomm``, ``scomm``
    and ``VEFS``.
    """

    def __init__(self, **kwargs):
        r""" Record the dolfin version, set up MPI and call :meth:`set_up`.

        Keyword Args:
          VEFS: approximation space, forwarded to :meth:`set_up`.

        Raises:
          ImportError: if dolfin is not available.
        """
        super(Solver, self).__init__()
        if not DOLFIN_SUPPORT:
            raise ImportError("Please install FENICS (dolfin) in order to use this class")
        # Remembered so that :meth:`__setstate__` can detect a version
        # mismatch when the solver is un-pickled.
        self._dolfin_version = dol.__version__
        self.init_mpi()
        self.set_up(**kwargs)

    def init_mpi(self):
        r""" Store the world (``wcomm``) and self (``scomm``) MPI communicators.

        For dolfin ``2017.2.0`` the communicators are obtained from
        mpi4py / petsc4py; for any other version they are read from
        ``dolfin.MPI``.
        """
        # Taking care of MPI
        if dol.__version__ == '2017.2.0':
            self.wcomm = MPI.COMM_WORLD
            self.scomm = PETSc.Comm(MPI.COMM_SELF)
        else:
            self.wcomm = dol.MPI.comm_world
            self.scomm = dol.MPI.comm_self

    def __getstate__(self):
        r""" Return the picklable state of the solver.

        Only the dolfin version is stored; the communicators and ``VEFS``
        are not part of the state returned here.

        Returns:
          (:class:`dict`) -- ``{'_dolfin_version': ...}`` if the version was
          recorded, an empty dictionary otherwise.
        """
        if hasattr(self, '_dolfin_version'):
            return {'_dolfin_version': self._dolfin_version}
        return {}

    def __setstate__(self, state):
        r""" Restore the solver from a pickled state.

        After restoring ``state`` through the parent class, the stored dolfin
        version is compared against the installed one. On a mismatch a
        warning is logged and the user is asked interactively whether to
        continue; answering no terminates the interpreter with exit status 0.
        If no version was stored, only a warning is logged. Finally the MPI
        communicators are re-created and :meth:`set_up` is called.

        NOTE(review): :meth:`set_up` is invoked without arguments, so
        sub-classes presumably override it (and :meth:`__getstate__`) in
        order to rebuild ``VEFS`` from the pickled state -- confirm.

        Args:
          state (dict): state as produced by :meth:`__getstate__`.

        Raises:
          ImportError: if dolfin is not available.
        """
        # ``exit`` is only injected as a builtin by the ``site`` module;
        # ``sys.exit`` is always available.
        import sys

        # Same guard as in ``__init__``: everything below needs ``dol``.
        if not DOLFIN_SUPPORT:
            raise ImportError("Please install FENICS (dolfin) in order to use this class")

        super(Solver, self).__setstate__(state)

        if not hasattr(self, '_dolfin_version'):
            self.logger.warning(
                "The solver has no defined dolfin version. "
                "This may be incompatible with the installed version."
            )
        elif self._dolfin_version != dol.__version__:
            self.logger.warning(
                "The dolfin version of the solver does not match "
                "the dolfin version that was used to create it. "
                "Dolfin version expected: " + self._dolfin_version + ". "
                "Dolfin version installed: " + dol.__version__ + "."
            )
            # Keep asking until a valid answer is given ('N' is the default
            # passed to cmdinput).
            instr = None
            while instr not in ('y', 'Y', 'n', 'N'):
                instr = cmdinput("Do you want to continue? [y/N] ", 'N')
            if instr in ('n', 'N'):
                sys.exit(0)

        self.init_mpi()
        self.set_up()

    @required_kwargs('VEFS')
    def set_up(self, **kwargs):
        r""" Store the approximation space.

        Keyword Args:
          VEFS: approximation space; removed from ``kwargs`` and stored in
            the attribute ``VEFS``.
        """
        self.VEFS = kwargs.pop('VEFS')

    def solve(self, *args, **kwargs):
        r""" [Abstract] Stub to be implemented in sub-classes.

        Raises:
          NotImplementedError: always.
        """
        raise NotImplementedError("To be implemented in sub-classes")

    def solve_adjoint(self, *args, **kwargs):
        r""" [Abstract] Stub to be implemented in sub-classes.

        Raises:
          NotImplementedError: always.
        """
        raise NotImplementedError("To be implemented in sub-classes")

    def solve_action_hess_adjoint(self, *args, **kwargs):
        r""" [Abstract] Stub to be implemented in sub-classes.

        Raises:
          NotImplementedError: always.
        """
        raise NotImplementedError("To be implemented in sub-classes")

    def new_function(self):
        r""" Create a new dolfin function on the approximation space ``VEFS``.

        Returns:
          (``dolfin.Function``) -- result of ``dolfin.Function(self.VEFS)``.
        """
        return dol.Function(self.VEFS)

    def dof_to_fun(self, x):
        r""" Convert a vector of degrees of freedom into a dolfin function.

        Args:
          x: sized sequence of degrees of freedom (presumably a
            one-dimensional :class:`numpy.ndarray` -- TODO confirm).

        Returns:
          (``dolfin.Function``) -- new function on ``VEFS`` whose vector has
          been set from ``x``.
        """
        ndofs = len(x)
        fun = self.new_function()
        if dol.__version__ == '2017.2.0':
            # This release is called with an explicit index array as well.
            fun.vector().set_local(x, np.arange(ndofs, dtype=np.intc))
        else:
            fun.vector().set_local(x)
        # Finalize the insertion of the values into the vector.
        fun.vector().apply('insert')
        return fun