Source code for TransportMaps.MPI

import pickle
import numpy as np

from .Misc import deprecate
from .External import MPI_SUPPORT
from .ObjectBase import TMO

__all__ = [
    'get_mpi_pool', 'MPIPoolContext', 'mpi_eval',
    'mpi_map', 'mpi_map_alloc_dmem',
    'mpi_alloc_dmem', 'mpi_bcast_dmem', 'mpi_scatter_dmem',
    'SumChunkReduce', 'TupleSumChunkReduce',
    'TensorDotReduce', 'ExpectationReduce',
    'AbsExpectationReduce', 'TupleExpectationReduce',
    'distributed_sampling',
]


def get_mpi_pool():
    r""" Get a pool of processors

    Returns:
      (:class:`mpi_map.MPI_Pool<mpi_map.MPI_Pool>`) -- pool of processors

    Usage example::

        import numpy as np
        import numpy.random as npr
        from TransportMaps import get_mpi_pool, mpi_map

        class Operator(object):
            def __init__(self, a):
                self.a = a
            def sum(self, x, n=1):
                out = x
                for i in range(n):
                    out += self.a
                return out

        op = Operator(2.)
        x = npr.randn(100, 5)
        n = 2

        pool = get_mpi_pool()
        pool.start(3)
        try:
            xsum = mpi_map("sum", scatter_tuple=(['x'], [x]),
                           bcast_tuple=(['n'], [n]),
                           obj=op, mpi_pool=pool)
        finally:
            pool.stop()
    """
    if MPI_SUPPORT:
        import mpi_map as mpi_map_mod
        return mpi_map_mod.MPI_Pool_v2()
    else:
        raise RuntimeError("MPI is not supported")

class MPIPoolContext(object):
    def __init__(self, nprocs):
        if nprocs > 1:
            self.mpi_pool = get_mpi_pool()
            self.mpi_pool.start(nprocs)
        else:
            self.mpi_pool = None

    def __enter__(self):
        return self.mpi_pool

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.mpi_pool is not None:
            self.mpi_pool.stop()
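
# Usage sketch (the ``Operator`` class and its ``scale`` method are hypothetical).
# ``MPIPoolContext`` yields a started pool when ``nprocs > 1`` and ``None``
# otherwise, so the same call runs serially or in parallel; note that with
# ``nprocs > 1`` the object must be picklable (i.e. defined at module level).
def _example_mpi_pool_context(nprocs=1):
    import numpy.random as npr
    from TransportMaps.MPI import MPIPoolContext, mpi_map

    class Operator(object):
        def __init__(self, a):
            self.a = a
        def scale(self, x):
            return self.a * x

    op = Operator(2.)
    x = npr.randn(100, 5)
    with MPIPoolContext(nprocs) as pool:
        # pool is None for nprocs <= 1 and mpi_map falls back to serial evaluation
        return mpi_map('scale', scatter_tuple=(['x'], [x]), obj=op, mpi_pool=pool)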

@deprecate('mpi_eval', 'v2.0', 'mpi_map')
def mpi_eval(f, scatter_tuple=None, bcast_tuple=None,
             dmem_key_in_list=None, dmem_arg_in_list=None, dmem_val_in_list=None,
             dmem_key_out_list=None,
             obj=None, reduce_obj=None, reduce_tuple=None,
             import_set=None, mpi_pool=None, splitted=False, concatenate=True):
    r""" Interface for the parallel evaluation of a generic function on points ``x``

    Args:
      f (:class:`object` or :class:`str`): function or string identifying the
        function in object ``obj``
      scatter_tuple (tuple): tuple containing 2 lists of ``[keys]`` and ``[arguments]``
        which will be scattered to the processes.
      bcast_tuple (tuple): tuple containing 2 lists of ``[keys]`` and ``[arguments]``
        which will be broadcast to the processes.
      dmem_key_in_list (list): list of strings containing the keys
        to be fetched (or created with default ``None`` if missing) from the
        distributed memory and provided as input to ``f``.
      dmem_val_in_list (list): list of objects corresponding to the keys defined
        in ``dmem_key_in_list``, used in case we are not executing in parallel
      dmem_key_out_list (list): list of keys to be assigned to the outputs
        beside the first one
      obj (object): object where the function ``f`` is defined
      reduce_obj (object): object :class:`ReduceObject` defining the reduce
        method to be applied (if any)
      reduce_tuple (object): tuple containing 2 lists of ``[keys]`` and ``[arguments]``
        which will be scattered to the processes to be used by ``reduce_obj``
      import_set (set): set of pairs ``(module_name, as_field)`` to be imported
        as ``import module_name as as_field`` (ignored; retained for backward
        compatibility)
      mpi_pool (:class:`mpi_map.MPI_Pool<mpi_map.MPI_Pool>`): pool of processors
      splitted (bool): whether the scattering input is already split or not
      concatenate (bool): whether to concatenate the output (the output of ``f``
        must be a :class:`numpy.ndarray<numpy.ndarray>` object)
    """
    if dmem_key_out_list is None:
        return mpi_map(f=f, scatter_tuple=scatter_tuple, bcast_tuple=bcast_tuple,
                       dmem_key_in_list=dmem_key_in_list,
                       dmem_arg_in_list=dmem_arg_in_list,
                       dmem_val_in_list=dmem_val_in_list,
                       obj=obj, reduce_obj=reduce_obj, reduce_tuple=reduce_tuple,
                       mpi_pool=mpi_pool, splitted=splitted, concatenate=concatenate)
    else:
        out = mpi_map_alloc_dmem(
            f=f, scatter_tuple=scatter_tuple, bcast_tuple=bcast_tuple,
            dmem_key_in_list=dmem_key_in_list,
            dmem_arg_in_list=dmem_arg_in_list,
            dmem_val_in_list=dmem_val_in_list,
            dmem_key_out_list=dmem_key_out_list,
            obj=obj, reduce_obj=reduce_obj, reduce_tuple=reduce_tuple,
            mpi_pool=mpi_pool, splitted=splitted, concatenate=concatenate)
        return (None,) + out

def mpi_map(f, scatter_tuple=None, bcast_tuple=None,
            dmem_key_in_list=None, dmem_arg_in_list=None, dmem_val_in_list=None,
            obj=None, obj_val=None,
            reduce_obj=None, reduce_tuple=None,
            mpi_pool=None, splitted=False, concatenate=True):
    r""" Interface for the parallel evaluation of a generic function on points ``x``

    Args:
      f (:class:`object` or :class:`str`): function or string identifying the
        function in object ``obj``
      scatter_tuple (tuple): tuple containing 2 lists of ``[keys]`` and ``[arguments]``
        which will be scattered to the processes.
      bcast_tuple (tuple): tuple containing 2 lists of ``[keys]`` and ``[arguments]``
        which will be broadcast to the processes.
      dmem_key_in_list (list): list of strings containing the keys
        to be fetched (or created with default ``None`` if missing) from the
        distributed memory and provided as input to ``f``.
      dmem_val_in_list (list): list of objects corresponding to the keys defined
        in ``dmem_key_in_list``, used in case we are not executing in parallel
      obj (object or str): object where the function ``f`` is defined
      obj_val (object): object to be used in case not executing in parallel
        and ``obj`` is a string
      reduce_obj (object): object :class:`ReduceObject` defining the reduce
        method to be applied (if any)
      reduce_tuple (object): tuple containing 2 lists of ``[keys]`` and ``[arguments]``
        which will be scattered to the processes to be used by ``reduce_obj``
      mpi_pool (:class:`mpi_map.MPI_Pool<mpi_map.MPI_Pool>`): pool of processors
      splitted (bool): whether the scattering input is already split or not
      concatenate (bool): whether to concatenate the output (the output of ``f``
        must be a :class:`numpy.ndarray<numpy.ndarray>` object)
    """
    # Init un-set arguments
    if scatter_tuple is None:
        scatter_tuple = ([], [])
    if bcast_tuple is None:
        bcast_tuple = ([], [])
    if dmem_key_in_list is None:
        dmem_key_in_list = []
    if dmem_arg_in_list is None:
        dmem_arg_in_list = []
    if dmem_val_in_list is None:
        dmem_val_in_list = []
    if reduce_tuple is None:
        reduce_tuple = ([], [])

    # Start evaluation
    if mpi_pool is None:
        # Prepare arguments
        args = {}
        for key, val in zip(scatter_tuple[0], scatter_tuple[1]):
            if splitted:
                if len(val) != 1:
                    raise ValueError(
                        "Serial execution but splitted input is %d long" % len(val))
                args[key] = val[0]
            else:
                args[key] = val
        for key, val in zip(bcast_tuple[0], bcast_tuple[1]):
            args[key] = val
        for key, val in zip(dmem_arg_in_list, dmem_val_in_list):
            args[key] = val
        reduce_args = {}
        for key, val in zip(reduce_tuple[0], reduce_tuple[1]):
            if splitted:
                if len(val) != 1:
                    raise ValueError(
                        "Serial execution but splitted input is %d long" % len(val))
                reduce_args[key] = val[0]
            else:
                reduce_args[key] = val
        # Retrieve function
        if obj is not None:
            if isinstance(obj, str):
                f = getattr(obj_val, f)
            else:
                f = getattr(obj, f)
        # Evaluate
        fval = f(**args)
        # Reduce if necessary
        if reduce_obj is not None:
            fval = reduce_obj.outer_reduce(
                [reduce_obj.inner_reduce(fval, **reduce_args)], **reduce_args)
    else:
        # If object is not None and is not a string, first broadcast the object
        obj_in = obj
        if obj is not None and not isinstance(obj, str):
            obj_distr = pickle.loads(pickle.dumps(obj_in))
            if issubclass(type(obj), TMO):
                obj_distr.reset_counters()
            mpi_bcast_dmem(obj=obj_distr, mpi_pool=mpi_pool, reserved=True)
            obj_in = 'obj'
        # Prepare arguments
        obj_scatter = mpi_pool.split_data(scatter_tuple[1], scatter_tuple[0],
                                          splitted=splitted)
        obj_bcast = {}
        for key, val in zip(bcast_tuple[0], bcast_tuple[1]):
            obj_bcast[key] = val
        obj_args_reduce = mpi_pool.split_data(reduce_tuple[1], reduce_tuple[0],
                                              splitted=splitted)
        # Evaluate
        fval = mpi_pool.map(f,
                            obj_scatter=obj_scatter, obj_bcast=obj_bcast,
                            dmem_key_in_list=dmem_key_in_list,
                            dmem_arg_in_list=dmem_arg_in_list,
                            obj=obj_in,
                            reduce_obj=reduce_obj, reduce_args=obj_args_reduce)
        # Put pieces together
        if reduce_obj is None and concatenate:
            if isinstance(fval[0], tuple):
                out = []
                for i in range(len(fval[0])):
                    out.append(np.concatenate([fv[i] for fv in fval]))
                fval = tuple(out)
            else:
                fval = np.concatenate(fval, axis=0)
        # If the input object was not a string, clear it from memory
        if obj is not None and not isinstance(obj, str):
            obj_child_lst = mpi_pool.pop_dmem(obj_in)
            if issubclass(type(obj), TMO):
                # Update counters
                obj.update_ncalls_tree(obj_child_lst[0][0])
                for (obj_child,) in obj_child_lst:
                    obj.update_nevals_tree(obj_child)
                    obj.update_teval_tree(obj_child)
    return fval
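
# Minimal serial sketch of ``mpi_map``: with ``mpi_pool=None`` the scatter and
# broadcast keys are simply assembled into keyword arguments of ``f``; with a
# started pool the rows of ``x`` would be split across processes instead.
# ``shift`` is a hypothetical function used only for illustration.
def _example_mpi_map_serial():
    import numpy as np
    from TransportMaps.MPI import mpi_map

    def shift(x, c):
        return x + c

    x = np.arange(6.).reshape(3, 2)
    return mpi_map(shift,
                   scatter_tuple=(['x'], [x]),
                   bcast_tuple=(['c'], [1.]),
                   mpi_pool=None)  # equivalent to shift(x=x, c=1.)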

def mpi_map_alloc_dmem(f, scatter_tuple=None, bcast_tuple=None,
                       dmem_key_in_list=None, dmem_arg_in_list=None, dmem_val_in_list=None,
                       dmem_key_out_list=None,
                       obj=None, obj_val=None,
                       reduce_obj=None, reduce_tuple=None,
                       mpi_pool=None, splitted=False, concatenate=True):
    r""" Interface for the parallel evaluation of a generic function on points ``x``

    Args:
      f (:class:`object` or :class:`str`): function or string identifying the
        function in object ``obj``
      scatter_tuple (tuple): tuple containing 2 lists of ``[keys]`` and ``[arguments]``
        which will be scattered to the processes.
      bcast_tuple (tuple): tuple containing 2 lists of ``[keys]`` and ``[arguments]``
        which will be broadcast to the processes.
      dmem_key_in_list (list): list of strings containing the keys
        to be fetched (or created with default ``None`` if missing) from the
        distributed memory and provided as input to ``f``.
      dmem_val_in_list (list): list of objects corresponding to the keys defined
        in ``dmem_key_in_list``, used in case we are not executing in parallel
      dmem_key_out_list (list): list of keys to be assigned to the outputs
        beside the first one
      obj (object): object where the function ``f`` is defined
      obj_val (object): object to be used in case not executing in parallel
        and ``obj`` is a string
      reduce_obj (object): object :class:`ReduceObject` defining the reduce
        method to be applied (if any)
      reduce_tuple (object): tuple containing 2 lists of ``[keys]`` and ``[arguments]``
        which will be scattered to the processes to be used by ``reduce_obj``
      mpi_pool (:class:`mpi_map.MPI_Pool<mpi_map.MPI_Pool>`): pool of processors
      splitted (bool): whether the scattering input is already split or not
      concatenate (bool): whether to concatenate the output (the output of ``f``
        must be a :class:`numpy.ndarray<numpy.ndarray>` object)
    """
    # Init un-set arguments
    if scatter_tuple is None:
        scatter_tuple = ([], [])
    if bcast_tuple is None:
        bcast_tuple = ([], [])
    if dmem_key_in_list is None:
        dmem_key_in_list = []
    if dmem_arg_in_list is None:
        dmem_arg_in_list = []
    if dmem_val_in_list is None:
        dmem_val_in_list = []
    if dmem_key_out_list is None:
        dmem_key_out_list = []
    if reduce_tuple is None:
        reduce_tuple = ([], [])

    # Start evaluation
    if mpi_pool is None:
        # Prepare arguments
        args = {}
        for key, val in zip(scatter_tuple[0], scatter_tuple[1]):
            if splitted:
                if len(val) != 1:
                    raise ValueError(
                        "Serial execution but splitted input is %d long" % len(val))
                args[key] = val[0]
            else:
                args[key] = val
        for key, val in zip(bcast_tuple[0], bcast_tuple[1]):
            args[key] = val
        for key, val in zip(dmem_arg_in_list, dmem_val_in_list):
            args[key] = val
        reduce_args = {}
        for key, val in zip(reduce_tuple[0], reduce_tuple[1]):
            if splitted:
                if len(val) != 1:
                    raise ValueError(
                        "Serial execution but splitted input is %d long" % len(val))
                reduce_args[key] = val[0]
            else:
                reduce_args[key] = val
        # Retrieve function
        if obj is not None:
            if isinstance(obj, str):
                f = getattr(obj_val, f)
            else:
                f = getattr(obj, f)
        # Evaluate
        pars = f(**args)
        if not isinstance(pars, tuple):
            pars = (pars,)
    else:
        # Prepare arguments
        obj_scatter = mpi_pool.split_data(scatter_tuple[1], scatter_tuple[0],
                                          splitted=splitted)
        obj_bcast = {}
        for key, val in zip(bcast_tuple[0], bcast_tuple[1]):
            obj_bcast[key] = val
        obj_args_reduce = mpi_pool.split_data(reduce_tuple[1], reduce_tuple[0],
                                              splitted=splitted)
        # Evaluate
        pars = tuple([None] * len(dmem_key_out_list))
        mpi_pool.map_alloc_dmem(
            f, obj_scatter=obj_scatter, obj_bcast=obj_bcast,
            dmem_key_in_list=dmem_key_in_list,
            dmem_arg_in_list=dmem_arg_in_list,
            dmem_key_out_list=dmem_key_out_list,
            obj=obj,
            reduce_obj=reduce_obj, reduce_args=obj_args_reduce)
    return pars
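
# Sketch of ``mpi_map_alloc_dmem`` run serially: the outputs of ``f`` are
# returned directly as a tuple. With a pool, each output chunk would instead
# remain on the workers under the keys in ``dmem_key_out_list`` and a tuple of
# ``None`` placeholders is returned. ``stats`` is a hypothetical function.
def _example_mpi_map_alloc_dmem_serial():
    import numpy as np
    from TransportMaps.MPI import mpi_map_alloc_dmem

    def stats(x):
        return np.mean(x, axis=0), np.std(x, axis=0)

    x = np.random.randn(100, 3)
    mean, std = mpi_map_alloc_dmem(stats,
                                   scatter_tuple=(['x'], [x]),
                                   dmem_key_out_list=['mean', 'std'],
                                   mpi_pool=None)
    return mean, std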

@deprecate("mpi_alloc_dmem", ">2.0b2", "Use mpi_bcast_dmem instead.")
def mpi_alloc_dmem(mpi_pool=None, **kwargs):
    mpi_bcast_dmem(mpi_pool=mpi_pool, **kwargs)

def mpi_bcast_dmem(mpi_pool=None, reserved=False, **kwargs):
    r""" Allocate the keyword arguments ``kwargs`` in the distributed memory
    (broadcast to every process).

    This executes only if an ``mpi_pool`` is provided.

    Args:
      mpi_pool (:class:`mpi_map.MPI_Pool<mpi_map.MPI_Pool>`): pool of processors
      reserved (bool): whether the ``kwargs`` dictionary can contain the
        reserved key ``obj``
    """
    if mpi_pool is not None:
        if not reserved and 'obj' in kwargs:
            raise ValueError(
                "The key obj is reserved for default argument passing. "
                "If you want to use it, pass reserved=True.")
        mpi_pool.bcast_dmem(**kwargs)
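
# Sketch of the distributed-memory pattern: broadcast a constant once with
# ``mpi_bcast_dmem`` and refer to it by key in later ``mpi_map`` calls. Serially
# (``mpi_pool=None``) the broadcast is a no-op and the value is taken from
# ``dmem_val_in_list``. ``weighted_sum`` is a hypothetical function.
def _example_dmem(mpi_pool=None):
    import numpy as np
    from TransportMaps.MPI import mpi_bcast_dmem, mpi_map

    def weighted_sum(x, w):
        return np.dot(x, w)

    w = np.array([0.25, 0.75])
    x = np.random.randn(10, 2)
    mpi_bcast_dmem(w=w, mpi_pool=mpi_pool)  # no-op when mpi_pool is None
    return mpi_map(weighted_sum,
                   scatter_tuple=(['x'], [x]),
                   dmem_key_in_list=['w'],   # key in the distributed memory
                   dmem_arg_in_list=['w'],   # keyword-argument name of f
                   dmem_val_in_list=[w],     # value used in the serial fallback
                   mpi_pool=mpi_pool)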

def mpi_scatter_dmem(mpi_pool=None, **kwargs):
    r""" Scatter the keyword arguments ``kwargs`` across the distributed memory
    of the processes.

    This executes only if an ``mpi_pool`` is provided.

    Args:
      mpi_pool (:class:`mpi_map.MPI_Pool<mpi_map.MPI_Pool>`): pool of processors
    """
    if mpi_pool is not None:
        mpi_pool.scatter_dmem(**kwargs)

#
# MPI REDUCE OPERATIONS
#

class SumChunkReduce(object):
    r""" Define the summation of the chunks operation.

    The chunks resulting from the output of the MPI evaluation are summed along
    their ``axis``.

    Args:
      axis (tuple [2]): tuple containing list of axes to be used in the
        :func:`sum<numpy.sum>` operation
    """
    def __init__(self, axis=None):
        self.axis = axis

    def inner_reduce(self, x, *args, **kwargs):
        return x

    def outer_reduce(self, x, *args, **kwargs):
        return np.sum(x, axis=self.axis)
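
# Sketch of a reduce: each chunk returns a partial sum and ``SumChunkReduce``
# adds the per-process chunks together, so serial and parallel results coincide.
# ``partial_sum`` is a hypothetical function.
def _example_sum_chunk_reduce():
    import numpy as np
    from TransportMaps.MPI import mpi_map, SumChunkReduce

    def partial_sum(x):
        return np.sum(x)

    x = np.arange(12.).reshape(6, 2)
    return mpi_map(partial_sum,
                   scatter_tuple=(['x'], [x]),
                   reduce_obj=SumChunkReduce(),
                   mpi_pool=None)  # 66.0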

class TupleSumChunkReduce(SumChunkReduce):
    r""" Define the summation of the chunks operation over lists of tuples.

    The chunks resulting from the output of the MPI evaluation are summed along
    their ``axis``.

    Args:
      axis (tuple [2]): tuple containing list of axes to be used in the
        :func:`sum<numpy.sum>` operation
    """
    def outer_reduce(self, x, *args, **kwargs):
        out = []
        for i in range(len(x[0])):
            xin = [xx[i] for xx in x]
            out.append(super(TupleSumChunkReduce, self).outer_reduce(xin, *args, **kwargs))
        return tuple(out)

class TensorDotReduce(object):
    r""" Define the reduce tensordot operation carried out through the mpi_map function

    Args:
      axis (tuple [2]): tuple containing list of axes to be used in the
        :func:`tensordot<numpy.tensordot>` operation
    """
    def __init__(self, axis):
        self.axis = axis

    def inner_reduce(self, x, w):
        if x.shape[self.axis[0]] > 0:
            return np.tensordot(x, w, self.axis)
        else:
            return 0.

    def outer_reduce(self, x, w):
        return sum(x)

class ExpectationReduce(TensorDotReduce):
    r""" Define the expectation operation carried out through the mpi_map function """
    def __init__(self):
        super(ExpectationReduce, self).__init__((0, 0))
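
# Sketch of a weighted expectation: ``ExpectationReduce`` contracts the first
# axis of each chunk output with the corresponding quadrature weights ``w``
# (i.e. it computes ``sum_i w_i f(x_i)``) and sums the per-process results.
# ``square`` is a hypothetical integrand.
def _example_expectation_reduce():
    import numpy as np
    from TransportMaps.MPI import mpi_map, ExpectationReduce

    def square(x):
        return x ** 2

    x = np.random.randn(1000, 2)
    w = np.ones(1000) / 1000.  # uniform Monte Carlo weights
    return mpi_map(square,
                   scatter_tuple=(['x'], [x]),
                   reduce_obj=ExpectationReduce(),
                   reduce_tuple=(['w'], [w]),
                   mpi_pool=None)  # approximately [1., 1.]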

class AbsExpectationReduce(ExpectationReduce):
    r""" Define the expectation of the absolute value: :math:`\mathbb{E}[\vert {\bf X} \vert]` """
    def inner_reduce(self, x, w):
        return super(AbsExpectationReduce, self).inner_reduce(np.abs(x), w)

class TupleExpectationReduce(ExpectationReduce):
    r""" Define the expectation operation applied on a tuple

    If we are given a tuple :math:`(x_1, x_2)`, the inner reduce returns
    :math:`(\langle x_1, w \rangle, \langle x_2, w \rangle)`.

    Given a list of tuples :math:`\{(x_i, y_i)\}_{i=0}^{n}`, the outer reduce
    gives :math:`(\sum x_i, \sum y_i)`.
    """
    def inner_reduce(self, x, w):
        out = []
        for xx in x:
            out.append(super(TupleExpectationReduce, self).inner_reduce(xx, w))
        return tuple(out)

    def outer_reduce(self, x, w):
        out = []
        tdim = len(x[0])
        for i in range(tdim):
            xin = [xx[i] for xx in x]
            out.append(super(TupleExpectationReduce, self).outer_reduce(xin, w))
        return tuple(out)
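
# Sketch: when ``f`` returns a tuple, ``TupleExpectationReduce`` applies the
# weighted contraction entry-wise, returning one expectation per entry.
# ``first_two_moments`` is a hypothetical integrand.
def _example_tuple_expectation_reduce():
    import numpy as np
    from TransportMaps.MPI import mpi_map, TupleExpectationReduce

    def first_two_moments(x):
        return x, x ** 2

    x = np.random.randn(1000, 2)
    w = np.ones(1000) / 1000.
    mean, second_moment = mpi_map(first_two_moments,
                                  scatter_tuple=(['x'], [x]),
                                  reduce_obj=TupleExpectationReduce(),
                                  reduce_tuple=(['w'], [w]),
                                  mpi_pool=None)
    return mean, second_moment  # approximately [0., 0.] and [1., 1.]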

#
# Distributed operations
#

def distributed_sampling(d, qtype, qparams, mpi_pool=None):
    nprocs = mpi_pool.nprocs if mpi_pool is not None else 1
    qparams_split = qparams // nprocs
    qparams_remainder = qparams % nprocs
    qparams_list = [
        qparams_split if i >= qparams_remainder else qparams_split + 1
        for i in range(nprocs)]
    mass_list = [float(n) / float(qparams) for n in qparams_list]
    scatter_tuple = (['qparams', 'mass'], [qparams_list, mass_list])
    bcast_tuple = (['qtype'], [qtype])
    (x, w) = mpi_map_alloc_dmem(
        'quadrature', scatter_tuple=scatter_tuple, splitted=True,
        bcast_tuple=bcast_tuple,
        dmem_key_out_list=['x', 'w'],
        obj=d, mpi_pool=mpi_pool,
        concatenate=False)
    return (x, w)
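
# Usage sketch for ``distributed_sampling``: ``d`` must expose a
# ``quadrature(qtype, qparams, mass)`` method. ``MCSampler`` is a hypothetical
# stand-in for a TransportMaps distribution (a real, module-level distribution
# is needed for ``nprocs > 1``). Serially the points and weights are returned
# directly; with a pool each process keeps its own chunk of ``x`` and ``w`` in
# distributed memory and ``(None, None)`` is returned.
def _example_distributed_sampling(nprocs=1):
    import numpy as np
    from TransportMaps.MPI import MPIPoolContext, distributed_sampling

    class MCSampler(object):
        def __init__(self, dim):
            self.dim = dim
        def quadrature(self, qtype, qparams, mass=1.):
            # qtype is ignored here: plain Monte Carlo with uniform weights
            x = np.random.randn(qparams, self.dim)
            w = np.full(qparams, mass / qparams)
            return x, w

    d = MCSampler(2)
    with MPIPoolContext(nprocs) as pool:
        x, w = distributed_sampling(d, 0, 1000, mpi_pool=pool)
    return x, w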