import math
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
from mpi4py import MPI
from pylops.utils.backend import get_module
from pylops_mpi import DistributedArray, MPILinearOperator, Partition
from pylops_mpi.Distributed import DistributedMixIn
def halo_block_split(
global_shape: tuple,
comm: MPI.Comm,
grid_shape: Optional[tuple] = None,
) -> tuple:
r"""Split a global array over a Cartesian process grid.
Compute the local slice owned by the calling rank when ``global_shape`` is
distributed over ``grid_shape``. This helper follows the same Cartesian
partitioning used internally by :class:`MPIHalo`.
Parameters
----------
global_shape : :obj:`tuple`
Shape of the global array before flattening.
comm : :obj:`mpi4py.MPI.Comm`
MPI communicator containing the ranks in the process grid.
grid_shape : :obj:`tuple`, optional
Number of ranks along each array axis. When ``None``, all ranks are
placed along the last axis.
Returns
-------
local_slice : :obj:`tuple`
Tuple of :class:`slice` objects selecting the local block owned by the
calling rank.
Raises
------
ValueError
If ``grid_shape`` does not contain exactly ``comm.Get_size()`` ranks.
"""
ndim = len(global_shape)
size = comm.Get_size()
# default: put all ranks on the last axis
if grid_shape is None:
grid_shape = (1,) * (ndim - 1) + (size,)
if math.prod(grid_shape) != size:
raise ValueError(f"grid_shape {grid_shape} does not match comm size {size}")
cart = comm.Create_cart(grid_shape, periods=[False] * ndim, reorder=True)
coords = cart.Get_coords(cart.Get_rank())
slices = []
for gdim, procs_on_axis, coord in zip(global_shape, grid_shape, coords):
block_size = math.ceil(gdim / procs_on_axis)
start = coord * block_size
end = min(start + block_size, gdim)
if coord == procs_on_axis - 1:
sl = slice(start, None)
else:
sl = slice(start, end)
slices.append(sl)
return tuple(slices)
[docs]
class MPIHalo(DistributedMixIn, MPILinearOperator):
r"""MPI Halo
Apply haloing to all dimensions of a flattened, 1-dimensional
:class:`pylops_mpi.DistributedArray` after local reshaping to a
N-dimensional array.
The Halo operator is applied over a Cartesian process grid, where each
rank owns a local block of the global N-dimensional array.
Parameters
----------
dims : :obj:`tuple`
Number of samples for each dimension.
halo : :obj:`int` or :obj:`tuple`
Number of halo samples to add around each local block. A scalar value
applies the same halo to both sides of every axis. A tuple of length
``ndim`` applies a symmetric halo per axis. A tuple of length
``2 * ndim`` specifies the halo to apply at the start and at the end
for each axis.
proc_grid_shape : :obj:`tuple`
Number of MPI ranks along each dimension.
comm : :obj:`mpi4py.MPI.Comm`, optional
MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
dtype : :obj:`str`, optional
Type of elements in the input array.
Attributes
----------
shape : :obj:`tuple`
Operator shape
Notes
-----
The MPIHalo operator extends each rank's local array with a **halo** (ghost cells)
of width ``halo`` along the haloed axes, providing the neighboring rank's data for
stencil-like operators. Ranks are arranged in an N-dimensional Cartesian grid as
provided by ``proc_grid_shape``, whereby each rank owns a contiguous block of the
global array. The ``halo`` is normalised to a tuple of length ``2 * ndim``, containing
one ``(minus, plus)`` halo-width pair for each axis. The tuple is flattened in
axis order as
.. math::
(h_{0,-}, h_{0,+}, h_{1,-}, h_{1,+}, \ldots)
where :math:`h_{i,-}`` and :math:`h_{i,+}`` represent the halo widths on the
negative and positive side of the i-th axis, respectively. For convenience,
``halo`` may be provided as a scalar when the same symmetric halo is
required along all axes, or as a tuple of length ``ndim`` when symmetric halos
of different width are applied to different axis. Ghost cells on the global
boundary of an axis are zero by default.
In the forward mode, each rank exchanges boundary slices with its left and
right neighbors along each axis via ``MPI_Sendrecv``. Ranks at a global
boundary have ``MPI.PROC_NULL`` as their neighbor on that side, so those
ghost regions remain zero and no exchange is attempted. Once the exchange
is complete, local PyLops operators can be applied independently
on each rank's extended block, typically wrapped into a
:class:`pylops_mpi.basicoperators.MPIBlockDiag` operator.
In the adjoint mode, the reverse operation is performed the original
local domain is extracted by removing the ghost cells.
Finally, note that the Halo operator is not linear operator per se; instead,
it is meant to sandwitch any linear operator to implement equivalent
behaviours to the serial version of such an operator.
"""
def __init__(
self,
dims: tuple,
halo: Union[int, tuple],
proc_grid_shape: Optional[tuple] = None,
comm: MPI.Comm = MPI.COMM_WORLD,
dtype: Any = np.float64,
) -> None:
self.global_dims = tuple(dims)
self.ndim = len(dims)
self.comm = comm
self.dtype = dtype
if proc_grid_shape is None:
proc_grid_shape = (1,) * (self.ndim - 1) + (self.comm.Get_size(),)
self.proc_grid_shape = tuple(proc_grid_shape)
if math.prod(self.proc_grid_shape) != self.comm.Get_size():
raise ValueError(
f"grid_shape {self.proc_grid_shape} does not match comm size {self.comm.Get_size()}"
)
self.cart_comm, self.neigh = self._build_topo()
self.halo = self._parse_halo(halo)
self.local_dims = self._calc_local_dims()
self.local_extent = self._calc_local_extent()
self._validate_exchange_widths()
self._local_dim_sizes = []
# For uneven global dimensions, MPIHalo's Cartesian block sizes differ
# from DistributedArray's default flat split. Store those sizes so
# _rmatvec can build an adjoint output with the same local ownership
# as the original Halo input.
comm_group = self.comm.Get_group()
cart_group = self.cart_comm.Get_group()
for rank in range(self.comm.Get_size()):
cart_rank = MPI.Group.Translate_ranks(comm_group, [rank], cart_group)[0]
coords = self.cart_comm.Get_coords(cart_rank)
local_size = 1
for gdim, coord, grid_procs in zip(self.global_dims, coords, self.proc_grid_shape):
block_size = math.ceil(gdim / grid_procs)
start = coord * block_size
end = min(start + block_size, gdim)
local_size *= end - start
self._local_dim_sizes.append(local_size)
self._local_extent_sizes = self._allgather(
self.comm,
None,
int(np.prod(self.local_extent)),
)
self.shape = (
int(np.sum(self._local_extent_sizes)),
int(np.prod(self.global_dims)),
)
super().__init__(shape=self.shape, dtype=np.dtype(dtype), base_comm=comm)
def _parse_halo(self, h: Union[int, tuple]) -> tuple:
"""Normalize halo input to a 2 * ndim tuple of per-side widths for each axis of the N-dimensional array.
Accepts a scalar, a tuple of length-1, one value per axis (the same value is assigned to both sides),
or explicit minus/plus pairs for each axis.
"""
if isinstance(h, (int, np.int64, np.int32)):
halo = (h,) * (2 * self.ndim)
trimmed = list(halo)
for ax in range(self.ndim):
if trimmed[2 * ax] and self.neigh[("-", ax)] == MPI.PROC_NULL:
trimmed[2 * ax] = 0
if trimmed[2 * ax + 1] and self.neigh[("+", ax)] == MPI.PROC_NULL:
trimmed[2 * ax + 1] = 0
halo = tuple(trimmed)
if any(h < 0 for h in halo):
raise ValueError("Halo widths must be non-negative")
return halo
h = tuple(h)
if len(h) == 1:
halo = h * (2 * self.ndim)
elif len(h) == self.ndim:
halo = sum(tuple([(d, d) for d in h]), ())
elif len(h) == 2 * self.ndim:
halo = h
else:
raise ValueError(f"Invalid halo length {len(h)} for ndim={self.ndim}")
if any(h < 0 for h in halo):
raise ValueError("Halo widths must be non-negative")
return halo
def _build_topo(self) -> Tuple[MPI.Comm, Dict[Tuple[str, int], int]]:
"""Create the Cartesian communicator and map neighboring ranks on the distribution axis."""
cart_comm = self.comm.Create_cart(
self.proc_grid_shape,
periods=[False] * self.ndim,
reorder=True,
)
neigh = {}
for ax in range(self.ndim):
before, after = cart_comm.Shift(ax, 1)
neigh[("-", ax)] = before
neigh[("+", ax)] = after
return cart_comm, neigh
def _calc_local_dims(self) -> tuple:
"""Compute this rank's local block shape before halo padding."""
rank = self.cart_comm.Get_rank()
coords = self.cart_comm.Get_coords(rank)
local = []
for ax, (gdim, coord, grid_procs) in enumerate(
zip(self.global_dims, coords, self.proc_grid_shape)
):
block_size = math.ceil(gdim / grid_procs)
start = coord * block_size
end = min(start + block_size, gdim)
local.append(end - start)
return tuple(local)
def _calc_local_extent(self) -> tuple:
"""Compute this rank's local block shape after halo padding."""
ext = []
for ax in range(self.ndim):
minus_halo, plus_halo = self.halo[2 * ax], self.halo[2 * ax + 1]
ext.append(self.local_dims[ax] + minus_halo + plus_halo)
return tuple(ext)
def _validate_local_array_shape(
self, x: DistributedArray, expected_shape: tuple, name: str
) -> None:
"""Raise if a distributed input does not match this rank's expected local shape."""
local_shapes = self.cart_comm.allgather(
(x.local_array.size, int(np.prod(expected_shape)), expected_shape)
)
for rank, (actual_size, expected_size, shape) in enumerate(local_shapes):
if actual_size != expected_size:
raise ValueError(
"MPIHalo input local shapes do not match the Cartesian block "
f"decomposition: rank {rank}: {name} local array has size "
f"{actual_size}, expected {expected_size} for local shape {shape}"
)
def _validate_exchange_widths(self) -> None:
"""
Raise if the requested halos cannot be exchanged with one-hop neighbors.
For example:
- Halo width Larger than local block size or that of the remote neighbors.
"""
width_error = 1
mismatch_error = 2
local_error = 0
for ax in range(self.ndim):
before, after = self.halo[2 * ax], self.halo[2 * ax + 1]
minus_nbr, plus_nbr = self.neigh[("-", ax)], self.neigh[("+", ax)]
local_dim = self.local_dims[ax]
if before > local_dim and minus_nbr != MPI.PROC_NULL:
local_error |= width_error
if after > local_dim and plus_nbr != MPI.PROC_NULL:
local_error |= width_error
plus_neighbor_before = self.cart_comm.sendrecv(
before, dest=minus_nbr, source=plus_nbr
)
minus_neighbor_after = self.cart_comm.sendrecv(after, dest=plus_nbr, source=minus_nbr)
if plus_nbr != MPI.PROC_NULL and after != plus_neighbor_before:
local_error |= mismatch_error
if minus_nbr != MPI.PROC_NULL and before != minus_neighbor_after:
local_error |= mismatch_error
global_error = self.cart_comm.allreduce(local_error, op=MPI.BOR)
if global_error & width_error:
raise ValueError(
"MPIHalo halo widths are not supported by the current one-hop "
"exchange: halo width exceeds local block size"
)
if global_error & mismatch_error:
raise ValueError(
"MPIHalo halo widths are not supported by the current one-hop "
"exchange: halo width does not match neighbor halo width"
)
def _exchange_along_axis(self, ncp: Any, arr: Any, axis: int, before: int, after: int, engine: str) -> None:
"""Exchange boundary/halo slices with neighboring ranks along one axis."""
minus_nbr, plus_nbr = self.neigh[("-", axis)], self.neigh[("+", axis)]
# slice definitions
slicer = [slice(None)] * self.ndim
# send before
if before and minus_nbr != MPI.PROC_NULL:
snd_s = slicer.copy()
snd_s[axis] = slice(before, 2 * before)
snd = arr[tuple(snd_s)].copy()
rcv = ncp.empty_like(snd)
rcv = self._sendrecv(
self.cart_comm,
None,
snd,
rcv,
dest=minus_nbr,
source=minus_nbr,
engine=engine,
)
rcv_s = slicer.copy()
rcv_s[axis] = slice(0, before)
arr[tuple(rcv_s)] = rcv
# send after
if after and plus_nbr != MPI.PROC_NULL:
snd_s = slicer.copy()
snd_s[axis] = slice(-2 * after, -after)
rcv_s = slicer.copy()
rcv_s[axis] = slice(-after, None)
snd = arr[tuple(snd_s)].copy()
rcv = ncp.empty_like(snd)
rcv = self._sendrecv(
self.cart_comm,
None,
snd,
rcv,
dest=plus_nbr,
source=plus_nbr,
engine=engine,
)
arr[tuple(rcv_s)] = rcv
def _matvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
if x.partition != Partition.SCATTER:
raise ValueError(
f"x should have partition={Partition.SCATTER} Got {x.partition} instead..."
)
self._validate_local_array_shape(x, self.local_dims, "x")
y = DistributedArray(
global_shape=self.shape[0],
partition=Partition.SCATTER,
local_shapes=self._local_extent_sizes,
base_comm=x.base_comm,
base_comm_nccl=x.base_comm_nccl,
engine=x.engine,
dtype=self.dtype,
)
core = x.local_array.reshape(self.local_dims)
halo_arr = ncp.zeros(self.local_extent, dtype=self.dtype)
# insert core
core_slices = [
slice(left, left + ldim)
for left, ldim in zip(self.halo[::2], self.local_dims)
]
halo_arr[tuple(core_slices)] = core
# exchange along each axis
for ax in range(self.ndim):
before, after = self.halo[2 * ax], self.halo[2 * ax + 1]
self._exchange_along_axis(
ncp, halo_arr, axis=ax, before=before, after=after, engine=x.engine
)
y[:] = halo_arr.ravel()
return y
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
if x.partition != Partition.SCATTER:
raise ValueError(
f"x should have partition={Partition.SCATTER} Got {x.partition} instead..."
)
self._validate_local_array_shape(x, self.local_extent, "x")
res = DistributedArray(
global_shape=self.shape[1],
partition=Partition.SCATTER,
local_shapes=self._local_dim_sizes,
base_comm=x.base_comm,
base_comm_nccl=x.base_comm_nccl,
engine=x.engine,
dtype=self.dtype,
)
arr = x.local_array.reshape(self.local_extent)
core_slices = [
slice(left, left + ldim)
for left, ldim in zip(self.halo[::2], self.local_dims)
]
core = arr[tuple(core_slices)]
res[:] = core.ravel()
return res