__all__ = ["active_grid_comm",
"block_gather",
"local_block_split",
"MPIMatrixMult",
]
import math
import numpy as np
from mpi4py import MPI
from typing import Tuple, Literal
from pylops.utils.backend import get_module
from pylops.utils.typing import DTypeLike, NDArray
from pylops_mpi import (
DistributedArray,
MPILinearOperator,
Partition
)
def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
r"""Active grid for distributed matrix multiplication.
Configure a square process grid from a parent MPI communicator and
select a subset of "active" processes. Each process in ``base_comm``
is assigned to a logical 2D grid of size :math:`P' \times P'`,
where :math:`P' = \bigl\lfloor \sqrt{P} \bigr\rfloor`. Only the first
:math:`active_{dim} \times active_{dim}` processes
(by row-major order) are considered "active". Inactive ranks return
immediately with no new communicator.
Parameters
----------
base_comm : :obj:`mpi4py.MPI.Comm`
MPI Parent Communicator. (e.g., ``mpi4py.MPI.COMM_WORLD``).
N : :obj:`int`
Number of rows of the global data domain.
M : :obj:`int`
Number of columns of the global data domain.
Returns
-------
comm : :obj:`mpi4py.MPI.Comm`
Sub-communicator including only active ranks.
rank : :obj:`int`
Rank within the new sub-communicator (or original rank
if inactive).
row : :obj:`int`
Grid row index of this process in the active grid (or the row computed
from the original rank on the full grid if inactive).
col : :obj:`int`
Grid column index of this process in the active grid (or the column
computed from the original rank on the full grid if inactive).
is_active : :obj:`bool`
Flag indicating whether this rank is in the active sub-grid.
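Examples
--------
A minimal usage sketch (to be run with ``mpirun``, not as a doctest), assuming
the script is launched with ``mpirun -n 5`` for a :math:`100 \times 80`
problem: the grid is :math:`2 \times 2` (``isqrt(5) = 2``), so ranks 0-3 are
active and rank 4 is flagged as inactive::

    from mpi4py import MPI

    comm, rank, row, col, is_active = active_grid_comm(MPI.COMM_WORLD, 100, 80)
    if is_active:
        # only ranks 0-3 reach this point; ``comm`` has size 4
        print(f"rank {rank} -> grid position ({row}, {col})")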
"""
rank = base_comm.Get_rank()
size = base_comm.Get_size()
p_prime = math.isqrt(size)
row, col = divmod(rank, p_prime)
active_dim = min(N, M, p_prime)
is_active = (row < active_dim and col < active_dim)
if not is_active:
return None, rank, row, col, False
active_ranks = [r for r in range(size)
if (r // p_prime) < active_dim and (r % p_prime) < active_dim]
new_group = base_comm.Get_group().Incl(active_ranks)
new_comm = base_comm.Create_group(new_group)
p_prime_new = math.isqrt(len(active_ranks))
new_rank = new_comm.Get_rank()
new_row, new_col = divmod(new_rank, p_prime_new)
return new_comm, new_rank, new_row, new_col, True
def local_block_split(global_shape: Tuple[int, int],
rank: int,
comm: MPI.Comm) -> Tuple[slice, slice]:
r"""Local sub‐block of a 2D global array
Compute the local sub‐block of a 2D global array for a process in a square
process grid.
Parameters
----------
global_shape : :obj:`tuple`
Dimensions of the global 2D array ``(n_rows, n_cols)``.
rank : :obj:`int`
Rank of the MPI process in `comm` for which to get the owned block partition.
comm : :obj:`mpi4py.MPI.Comm`
MPI communicator whose total number of processes :math:`P`
must be a perfect square, :math:`P = P'^2`.
Returns
-------
Tuple[slice, slice]
Two `slice` objects `(row_slice, col_slice)` representing the sub‐block
of the global array owned by this rank.
Raises
------
ValueError
If `rank` is not an integer value or out of range.
RuntimeError
If the number of processes participating in the provided communicator
is not a perfect square.
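Examples
--------
A minimal sketch (to be run with ``mpirun``, not as a doctest), assuming
4 MPI ranks (a :math:`2 \times 2` grid) and a global :math:`5 \times 4`
array; ``G`` is an illustrative array replicated on every rank. Rank 3
(grid position ``(1, 1)``) owns ``(slice(3, 5), slice(2, 4))``::

    from mpi4py import MPI
    import numpy as np

    comm = MPI.COMM_WORLD
    G = np.arange(20, dtype=np.float64).reshape(5, 4)
    rs, cs = local_block_split((5, 4), comm.Get_rank(), comm)
    G_local = G[rs, cs]   # block of the global array owned by this rank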
"""
size = comm.Get_size()
p_prime = math.isqrt(size)
if p_prime * p_prime != size:
raise RuntimeError(f"Number of processes must be a square number, "
f"provided {size} instead...")
if not (isinstance(rank, int) and 0 <= rank < size):
raise ValueError(f"rank must be an integer in [0, {size}), got {rank!r}")
pr, pc = divmod(rank, p_prime)
orig_r, orig_c = global_shape
new_r = math.ceil(orig_r / p_prime) * p_prime
new_c = math.ceil(orig_c / p_prime) * p_prime
blkr, blkc = new_r // p_prime, new_c // p_prime
rs, cs = pr * blkr, pc * blkc
re, ce = min(rs + blkr, orig_r), min(cs + blkc, orig_c)
return slice(rs, re), slice(cs, ce)
def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Comm):
r"""Local block from 2D block distributed matrix
Gather the local blocks of a 2D block-distributed matrix, distributed
amongst a square process grid, into the full global array.
Parameters
----------
x : :obj:`pylops_mpi.DistributedArray`
The distributed array to gather locally.
orig_shape : :obj:`tuple`
Global shape ``(N, M)`` of the global array to be gathered.
comm : :obj:`mpi4py.MPI.Comm`
MPI communicator whose size must be a perfect square (:math:`P = P'^2`).
Returns
-------
Array
The reconstructed 2D array of shape ``orig_shape``, assembled from
the distributed blocks.
Raises
------
RuntimeError
If the number of processes participating in the provided communicator
is not a perfect square.
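Examples
--------
A minimal sketch (to be run with ``mpirun``, not as a doctest), assuming
4 MPI ranks (a :math:`2 \times 2` grid) and a :math:`5 \times 4` global
array distributed with :func:`local_block_split`; ``G`` is an illustrative
array replicated on every rank::

    from mpi4py import MPI
    import numpy as np
    from pylops_mpi import DistributedArray, Partition

    comm = MPI.COMM_WORLD
    G = np.arange(20, dtype=np.float64).reshape(5, 4)
    rs, cs = local_block_split((5, 4), comm.Get_rank(), comm)
    sizes = comm.allgather(G[rs, cs].size)
    x = DistributedArray(global_shape=int(np.sum(sizes)), local_shapes=sizes,
                         partition=Partition.SCATTER)
    x[:] = G[rs, cs].ravel()
    G_full = block_gather(x, (5, 4), comm)   # (5, 4) array rebuilt on each rank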
"""
ncp = get_module(x.engine)
p_prime = math.isqrt(comm.Get_size())
if p_prime * p_prime != comm.Get_size():
raise RuntimeError(f"Communicator size must be a perfect square, got {comm.Get_size()!r}")
all_blks = comm.allgather(x.local_array)
nr, nc = orig_shape
br, bc = math.ceil(nr / p_prime), math.ceil(nc / p_prime)
C = ncp.zeros((nr, nc), dtype=all_blks[0].dtype)
for rank in range(p_prime * p_prime):
pr, pc = divmod(rank, p_prime)
rs, cs = pr * br, pc * bc
re, ce = min(rs + br, nr), min(cs + bc, nc)
if len(all_blks[rank]) != 0:
C[rs:re, cs:ce] = all_blks[rank].reshape(re - rs, ce - cs)
return C
class _MPIBlockMatrixMult(MPILinearOperator):
r"""MPI Blocked Matrix multiplication
Implement distributed matrix-matrix multiplication between a matrix
:math:`\mathbf{A}` blocked over rows (i.e., blocks of rows are stored
over different ranks) and the input model and data vectors, which are both
to be interpreted as matrices blocked over columns.
Parameters
----------
A : :obj:`numpy.ndarray`
Local block of the matrix of shape :math:`[N_{loc} \times K]`
where :math:`N_{loc}` is the number of rows stored on this MPI rank and
``K`` is the global number of columns.
M : :obj:`int`
Global leading dimension (i.e., number of columns) of the matrices
representing the input model and data vectors.
saveAt : :obj:`bool`, optional
Save :math:`\mathbf{A}` and ``A.H`` to speed up the computation of adjoint
(``True``) or create ``A.H`` on-the-fly (``False``)
Note that ``saveAt=True`` will double the amount of required memory.
Default is ``False``.
base_comm : :obj:`mpi4py.MPI.Comm`, optional
MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
dtype : :obj:`str`, optional
Type of elements in input array.
Attributes
----------
shape : :obj:`tuple`
Operator shape
Raises
------
Exception
If the operator is created with a non-square number of MPI ranks.
ValueError
If input vector does not have the correct partition type.
Notes
-----
This operator performs a matrix-matrix multiplication, whose forward
operation can be described as :math:`Y = A \cdot X` where:
- :math:`\mathbf{A}` is the distributed matrix operator of shape :math:`[N \times K]`
- :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]`
- :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]`
whilst the adjoint operation is represented by
:math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}` where
:math:`\mathbf{A}^H` is the complex conjugate and transpose of :math:`\mathbf{A}`.
This implementation is based on 1D block distributions of the operator
matrix and of the reshaped model and data vectors, with each block
replicated :math:`\sqrt{P}` times across a square process grid
(:math:`\sqrt{P}\times\sqrt{P}`). More specifically:
- The matrix :math:`\mathbf{A}` is distributed across MPI processes in a block-row fashion
and each process holds a local block of :math:`\mathbf{A}` with shape
:math:`[N_{loc} \times K]`
- The operand matrix :math:`\mathbf{X}` is distributed in a block-column fashion and
each process holds a local block of :math:`\mathbf{X}` with shape
:math:`[K \times M_{loc}]`
- Communication is minimized by using a 2D process grid layout
**Forward Operation step-by-step**
1. **Input Preparation**: The input vector ``x`` (flattened from matrix :math:`\mathbf{X}`
of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local``
is the number of columns assigned to the current process.
2. **Local Computation**: Each process computes ``A_local @ X_local`` where:
- ``A_local`` is the local block of matrix :math:`\mathbf{A}` (shape ``N_local x K``)
- ``X_local`` is the broadcasted operand (shape ``K x M_local``)
3. **Row-wise Gather**: Results from all processes in each row are gathered
using ``allgather`` to ensure that each rank has a block-column of the
output matrix.
**Adjoint Operation step-by-step**
The adjoint operation performs the conjugate transpose multiplication:
1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(N, M_local)``
representing the local columns of the input matrix.
2. **Local Adjoint Computation**: Each process computes
``A_local.H @ X_tile`` where ``A_local.H`` is either pre-computed
and stored in ``At`` (if ``saveAt=True``), or computed on-the-fly as
``A.T.conj()`` (if ``saveAt=False``). Each process multiplies its
transposed local :math:`\mathbf{A}` block ``A_local^H`` (shape ``K x N_block``)
with the extracted ``X_tile`` (shape ``N_block x M_local``),
producing a partial result of shape ``(K, M_local)``.
This computes the local contribution of columns of ``A^H`` to the final
result.
3. **Row-wise Reduction**: Since the full result :math:`\mathbf{Y} = \mathbf{A}^H \cdot \mathbf{X}` is the
sum of the contributions from all column blocks of :math:`\mathbf{A}^H`, processes in
the same row perform an ``allreduce`` sum to combine their partial results.
This gives the complete ``(K, M_local)`` result for their assigned column block.
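Examples
--------
A minimal layout sketch (to be run with ``mpirun``, not as a doctest),
assuming 4 MPI ranks (a :math:`2 \times 2` grid) and global sizes
:math:`N=4`, :math:`K=3`, :math:`M=2`; the operator is normally created
through :func:`MPIMatrixMult` with ``kind="block"`` and the array names
are illustrative::

    from mpi4py import MPI
    import numpy as np
    from pylops_mpi import DistributedArray, Partition

    comm = MPI.COMM_WORLD
    row_id, col_id = divmod(comm.Get_rank(), 2)
    A_glob = np.arange(12, dtype=np.float64).reshape(4, 3)
    X_glob = np.ones((3, 2))
    A_loc = A_glob[2 * col_id:2 * (col_id + 1), :]   # (N_loc, K) row block
    X_loc = X_glob[:, row_id:row_id + 1]             # (K, M_loc) column block
    Aop = MPIMatrixMult(A_loc, M=2, kind="block")
    sizes = comm.allgather(X_loc.size)
    x = DistributedArray(global_shape=int(np.sum(sizes)), local_shapes=sizes,
                         partition=Partition.SCATTER)
    x[:] = X_loc.ravel()
    y = Aop.matvec(x)   # each rank holds its (N, M_loc) block of Y, flattened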
"""
def __init__(
self,
A: NDArray,
M: int,
saveAt: bool = False,
base_comm: MPI.Comm = MPI.COMM_WORLD,
dtype: DTypeLike = "float64",
) -> None:
rank = base_comm.Get_rank()
size = base_comm.Get_size()
# The ranks are arranged on a square P_prime x P_prime grid;
# the number of processes must therefore be a perfect square
self._P_prime = math.isqrt(size)
self._C = self._P_prime
if self._P_prime * self._C != size:
raise Exception(f"Number of processes must be a square number, provided {size} instead...")
self._col_id = rank % self._P_prime
self._row_id = rank // self._P_prime
self.base_comm = base_comm
self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id)
self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id)
self.A = A.astype(np.dtype(dtype))
if saveAt:
self.At = A.T.conj()
self.N = self._row_comm.allreduce(self.A.shape[0], op=MPI.SUM)
self.K = A.shape[1]
self.M = M
block_cols = int(math.ceil(self.M / self._P_prime))
blk_rows = int(math.ceil(self.N / self._P_prime))
self._row_start = self._col_id * blk_rows
self._row_end = min(self.N, self._row_start + blk_rows)
self._col_start = self._row_id * block_cols
self._col_end = min(self.M, self._col_start + block_cols)
self._local_ncols = max(0, self._col_end - self._col_start)
self._rank_col_lens = self.base_comm.allgather(self._local_ncols)
total_ncols = np.sum(self._rank_col_lens)
self.dims = (self.K, total_ncols)
self.dimsd = (self.N, total_ncols)
shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
def _matvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
if x.partition != Partition.SCATTER:
raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
output_dtype = np.result_type(self.dtype, x.dtype)
y = DistributedArray(
global_shape=(self.N * self.dimsd[1]),
local_shapes=[(self.N * c) for c in self._rank_col_lens],
mask=x.mask,
partition=Partition.SCATTER,
engine=x.engine,
dtype=output_dtype,
base_comm=self.base_comm
)
my_own_cols = self._rank_col_lens[self.rank]
x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
X_local = x_arr.astype(output_dtype)
Y_local = ncp.vstack(
self._row_comm.allgather(
ncp.matmul(self.A, X_local)
)
)
y[:] = Y_local.flatten()
return y
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
if x.partition != Partition.SCATTER:
raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
# - If A is real: A^H = A^T,
# so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real)
# - If A is complex: A^H is complex,
# so result will be complex regardless of x
if np.iscomplexobj(self.A):
output_dtype = np.result_type(self.dtype, x.dtype)
else:
# Real matrix: A^T @ x preserves input type complexity
output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype
# But still need to check type promotion for precision
output_dtype = np.result_type(self.dtype, output_dtype)
y = DistributedArray(
global_shape=(self.K * self.dimsd[1]),
local_shapes=[self.K * c for c in self._rank_col_lens],
mask=x.mask,
partition=Partition.SCATTER,
engine=x.engine,
dtype=output_dtype,
base_comm=self.base_comm
)
x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(output_dtype)
X_tile = x_arr[self._row_start:self._row_end, :]
A_local = self.At if hasattr(self, "At") else self.A.T.conj()
Y_local = ncp.matmul(A_local, X_tile)
y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM)
y[:] = y_layer.flatten()
return y
class _MPISummaMatrixMult(MPILinearOperator):
r"""MPI SUMMA Matrix multiplication
Implements distributed matrix-matrix multiplication using the SUMMA algorithm
between a matrix :math:`\mathbf{A}` distributed over a 2D process grid and
input model and data vectors, which are both interpreted as matrices
distributed in block fashion wherein each process owns a tile of the matrix.
Parameters
----------
A : :obj:`numpy.ndarray`
Local block of the matrix of shape :math:`[N_{loc} \times K_{loc}]`
where :math:`N_{loc}` and :math:`K_{loc}` are the number of rows and
columns stored on this MPI rank.
M : :obj:`int`
Global leading dimension (i.e., number of columns) of the matrices
representing the input model and data vectors.
saveAt : :obj:`bool`, optional
Save :math:`\mathbf{A}` and ``A.H`` to speed up the computation of adjoint
(``True``) or create ``A.H`` on-the-fly (``False``).
Note that ``saveAt=True`` will double the amount of required memory.
Default is ``False``.
base_comm : :obj:`mpi4py.MPI.Comm`, optional
MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
dtype : :obj:`str`, optional
Type of elements in input array.
Attributes
----------
shape : :obj:`tuple`
Operator shape
Raises
------
Exception
If the operator is created with a non-square number of MPI ranks.
ValueError
If input vector does not have the correct partition type.
Notes
-----
This operator performs distributed matrix-matrix multiplication using the
SUMMA (Scalable Universal Matrix Multiplication Algorithm), whose forward
operation can be described as :math:`\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}` where:
- :math:`\mathbf{A}` is the distributed matrix operator of shape :math:`[N \times K]`
- :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]`
- :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]`
The adjoint operation is represented by
:math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}` where
:math:`\mathbf{A}^H` is the complex conjugate transpose of :math:`\mathbf{A}`.
This implementation is based on a 2D block distribution across a square process
grid (:math:`\sqrt{P}\times\sqrt{P}`). The matrices are distributed as follows:
- The matrix :math:`\mathbf{A}` is distributed across MPI processes in 2D blocks where
each process holds a local block of :math:`\mathbf{A}` with shape :math:`[N_{loc} \times K_{loc}]`
where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`K_{loc} = \frac{K}{\sqrt{P}}`.
- The operand matrix :math:`\mathbf{X}` is also distributed across MPI processes in 2D blocks where
each process holds a local block of :math:`\mathbf{X}` with shape :math:`[K_{loc} \times M_{loc}]`
where :math:`K_{loc} = \frac{K}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`.
- The result matrix :math:`\mathbf{Y}` is also distributed across MPI processes in 2D blocks where
each process holds a local block of :math:`\mathbf{Y}` with shape :math:`[N_{loc} \times M_{loc}]`
where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`.
**Forward Operation (SUMMA Algorithm)**
The forward operation implements the SUMMA algorithm:
1. **Input Preparation**: The input vector ``x`` is reshaped to ``(K_{loc}, M_{loc})`` representing
the local block assigned to the current process.
2. **SUMMA Iteration**: For each step ``k`` in the SUMMA algorithm, :math:`k \in [0, \sqrt{P})`:
a. **Broadcast A blocks**: Process in column ``k`` broadcasts its :math:`\mathbf{A}`
block to all other processes in the same process row.
b. **Broadcast X blocks**: Process in row ``k`` broadcasts its :math:`\mathbf{X}`
block to all other processes in the same process column.
c. **Local Computation**: Each process computes the partial matrix
product ``A_broadcast @ X_broadcast`` and accumulates it to its
local result.
3. **Result Assembly**: After all :math:`\sqrt{P}` SUMMA iterations, each process has computed
its local block of the result matrix :math:`\mathbf{Y}`.
**Adjoint Operation (SUMMA Algorithm)**
The adjoint operation performs the conjugate transpose multiplication using
a modified SUMMA algorithm:
1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(N_{loc}, M_{loc})``
representing the local block of the input matrix.
2. **SUMMA Adjoint Iteration**: For each step ``k`` in the adjoint SUMMA algorithm:
a. **Broadcast A^H blocks**: The conjugate transpose of :math:`\mathbf{A}` blocks is
communicated between processes. If ``saveAt=True``, the pre-computed
``A.H`` is used; otherwise, ``A.T.conj()`` is computed on-the-fly.
b. **Broadcast Y blocks**: Process in row ``k`` broadcasts its :math:`\mathbf{Y}`
block to all other processes in the same process column.
c. **Local Adjoint Computation**: Each process computes the partial
matrix product ``A_H_broadcast @ Y_broadcast`` and accumulates it
to the local result.
3. **Result Assembly**: After all adjoint SUMMA iterations, each process has
computed its local block of the result matrix :math:`\mathbf{X}_{adj}`.
The implementation handles padding automatically to ensure proper block sizes
for the square process grid, and unpadding is performed before returning results.
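Examples
--------
A minimal layout sketch (to be run with ``mpirun``, not as a doctest),
assuming 4 MPI ranks (a :math:`2 \times 2` grid) and global sizes
:math:`N=5`, :math:`K=3`, :math:`M=3`, so that padding to the grid is
handled internally; the operator is normally created through
:func:`MPIMatrixMult` with ``kind="summa"`` and the array names are
illustrative::

    from mpi4py import MPI
    import numpy as np
    from pylops_mpi import DistributedArray, Partition

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    A_glob = np.arange(15, dtype=np.float64).reshape(5, 3)
    X_glob = np.ones((3, 3))
    ra, ca = local_block_split((5, 3), rank, comm)   # (N_loc, K_loc) tile of A
    rx, cx = local_block_split((3, 3), rank, comm)   # (K_loc, M_loc) tile of X
    Aop = MPIMatrixMult(A_glob[ra, ca], M=3, kind="summa")
    sizes = comm.allgather(X_glob[rx, cx].size)
    x = DistributedArray(global_shape=int(np.sum(sizes)), local_shapes=sizes,
                         partition=Partition.SCATTER)
    x[:] = X_glob[rx, cx].ravel()
    y = Aop.matvec(x)   # each rank holds its (N_loc, M_loc) tile of Y, flattened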
"""
def __init__(
self,
A: NDArray,
M: int,
saveAt: bool = False,
base_comm: MPI.Comm = MPI.COMM_WORLD,
dtype: DTypeLike = "float64",
) -> None:
rank = base_comm.Get_rank()
size = base_comm.Get_size()
# The ranks are arranged on a square P_prime x P_prime grid;
# the number of processes must therefore be a perfect square
self._P_prime = math.isqrt(size)
if self._P_prime * self._P_prime != size:
raise Exception(f"Number of processes must be a square number, provided {size} instead...")
self._row_id, self._col_id = divmod(rank, self._P_prime)
self.base_comm = base_comm
self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id)
self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id)
self.A = A.astype(np.dtype(dtype))
self.N = self._col_comm.allreduce(A.shape[0])
self.K = self._row_comm.allreduce(A.shape[1])
self.M = M
self._N_padded = math.ceil(self.N / self._P_prime) * self._P_prime
self._K_padded = math.ceil(self.K / self._P_prime) * self._P_prime
self._M_padded = math.ceil(self.M / self._P_prime) * self._P_prime
bn = self._N_padded // self._P_prime
bk = self._K_padded // self._P_prime
bm = self._M_padded // self._P_prime # noqa: F841
pr = (bn - A.shape[0]) if self._row_id == self._P_prime - 1 else 0
pc = (bk - A.shape[1]) if self._col_id == self._P_prime - 1 else 0
if pr > 0 or pc > 0:
self.A = np.pad(self.A, [(0, pr), (0, pc)], mode='constant')
if saveAt:
self.At = self.A.T.conj()
self.dims = (self.K, self.M)
self.dimsd = (self.N, self.M)
shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
def _matvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
if x.partition != Partition.SCATTER:
raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
output_dtype = np.result_type(self.dtype, x.dtype)
# Calculate local shapes for block distribution
bn = self._N_padded // self._P_prime # block size in N dimension
bm = self._M_padded // self._P_prime # block size in M dimension
local_n = bn if self._row_id != self._P_prime - 1 else self.N - (self._P_prime - 1) * bn
local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm
local_shapes = self.base_comm.allgather(local_n * local_m)
y = DistributedArray(global_shape=(self.N * self.M),
mask=x.mask,
local_shapes=local_shapes,
partition=Partition.SCATTER,
engine=x.engine,
dtype=output_dtype,
base_comm=self.base_comm)
# Calculate expected padded dimensions for x
bk = self._K_padded // self._P_prime # block size in K dimension
# The input x corresponds to blocks from matrix B (K x M)
# This process should receive a block of size (local_k x local_m)
local_k = bk if self._row_id != self._P_prime - 1 else self.K - (self._P_prime - 1) * bk
# Reshape x.local_array to its 2D block form
x_block = x.local_array.reshape((local_k, local_m))
# Pad the block to the full padded size if necessary
pad_k = bk - local_k
pad_m = bm - local_m
if pad_k > 0 or pad_m > 0:
x_block = ncp.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant')
Y_local = ncp.zeros((self.A.shape[0], bm), dtype=output_dtype)
for k in range(self._P_prime):
Atemp = self.A.copy() if self._col_id == k else ncp.empty_like(self.A)
Xtemp = x_block.copy() if self._row_id == k else ncp.empty_like(x_block)
self._row_comm.Bcast(Atemp, root=k)
self._col_comm.Bcast(Xtemp, root=k)
Y_local += ncp.dot(Atemp, Xtemp)
Y_local_unpadded = Y_local[:local_n, :local_m]
y[:] = Y_local_unpadded.flatten()
return y
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
if x.partition != Partition.SCATTER:
raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
# Calculate local shapes for block distribution
bk = self._K_padded // self._P_prime # block size in K dimension
bm = self._M_padded // self._P_prime # block size in M dimension
# Calculate actual local shape for this process (considering original dimensions)
# Adjust for edge/corner processes that might have smaller blocks
local_k = bk if self._row_id != self._P_prime - 1 else self.K - (self._P_prime - 1) * bk
local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm
local_shapes = self.base_comm.allgather(local_k * local_m)
# - If A is real: A^H = A^T,
# so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real)
# - If A is complex: A^H is complex,
# so result will be complex regardless of x
if np.iscomplexobj(self.A):
output_dtype = np.result_type(self.dtype, x.dtype)
else:
# Real matrix: A^T @ x preserves input type complexity
output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype
# But still need to check type promotion for precision
output_dtype = np.result_type(self.dtype, output_dtype)
y = DistributedArray(
global_shape=(self.K * self.M),
mask=x.mask,
local_shapes=local_shapes,
partition=Partition.SCATTER,
engine=x.engine,
dtype=output_dtype,
base_comm=self.base_comm
)
# Calculate expected padded dimensions for x
bn = self._N_padded // self._P_prime # block size in N dimension
# The input x corresponds to blocks from the result (N x M)
# This process should receive a block of size (local_n x local_m)
local_n = bn if self._row_id != self._P_prime - 1 else self.N - (self._P_prime - 1) * bn
# Reshape x.local_array to its 2D block form
x_block = x.local_array.reshape((local_n, local_m))
# Pad the block to the full padded size if necessary
pad_n = bn - local_n
pad_m = bm - local_m
if pad_n > 0 or pad_m > 0:
x_block = ncp.pad(x_block, [(0, pad_n), (0, pad_m)], mode='constant')
A_local = self.At if hasattr(self, "At") else self.A.T.conj()
Y_local = ncp.zeros((self.A.shape[1], bm), dtype=output_dtype)
for k in range(self._P_prime):
requests = []
ATtemp = ncp.empty_like(A_local)
srcA = k * self._P_prime + self._row_id
tagA = (100 + k) * 1000 + self.rank
requests.append(self.base_comm.Irecv(ATtemp, source=srcA, tag=tagA))
if self._row_id == k:
fixed_col = self._col_id
for moving_col in range(self._P_prime):
destA = fixed_col * self._P_prime + moving_col
tagA = (100 + k) * 1000 + destA
requests.append(self.base_comm.Isend(A_local, dest=destA, tag=tagA))
Xtemp = x_block.copy() if self._row_id == k else ncp.empty_like(x_block)
requests.append(self._col_comm.Ibcast(Xtemp, root=k))
MPI.Request.Waitall(requests)
Y_local += ncp.dot(ATtemp, Xtemp)
Y_local_unpadded = Y_local[:local_k, :local_m]
y[:] = Y_local_unpadded.flatten()
return y
def MPIMatrixMult(
A: NDArray,
M: int,
saveAt: bool = False,
base_comm: MPI.Comm = MPI.COMM_WORLD,
kind: Literal["summa", "block"] = "summa",
dtype: DTypeLike = "float64"):
r"""
MPI Distributed Matrix Multiplication Operator
This operator performs distributed matrix-matrix multiplication
using either the SUMMA (Scalable Universal Matrix Multiplication
Algorithm) [1]_ or a 1D block-row decomposition algorithm, depending on
the specified ``kind`` parameter.
Parameters
----------
A : :obj:`numpy.ndarray`
Local block of the matrix operator.
M : :obj:`int`
Global number of columns in the operand and result matrices.
saveAt : :obj:`bool`, optional
If ``True``, store both :math:`\mathbf{A}` and its conjugate transpose
:math:`\mathbf{A}^H` to accelerate adjoint operations (uses twice the
memory). Default is ``False``.
base_comm : :obj:`mpi4py.MPI.Comm`, optional
MPI communicator to use. Defaults to ``MPI.COMM_WORLD``.
kind : :obj:`str`, optional
Algorithm used to perform the matrix multiplication: ``'block'`` for the
1D block-row/column decomposition, or ``'summa'`` for the SUMMA algorithm.
Default is ``'summa'``.
dtype : :obj:`str`, optional
Type of elements in input array. Defaults to ``numpy.float64``.
Attributes
----------
shape : :obj:`tuple`
Operator shape
kind : :obj:`str`
Selected distributed matrix multiply algorithm (``'block'`` or ``'summa'``).
Raises
------
NotImplementedError
If ``kind`` is not one of ``'summa'`` or ``'block'``.
Exception
If the MPI communicator does not form a compatible grid for the
selected algorithm.
Notes
-----
The forward operator computes:
.. math::
\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}
where:
- :math:`\mathbf{A}` is the distributed operator matrix of shape :math:`[N \times K]`
- :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]`
- :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]`
The adjoint (conjugate-transpose) operation computes:
.. math::
\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}
where :math:`\mathbf{A}^H` is the complex-conjugate transpose of :math:`\mathbf{A}`.
Based on the choice of ``kind``, the distribution layouts of the operator and model and
data vectors differ as follows:
:summa:
2D block-grid distribution over a square process grid :math:`[\sqrt{P} \times \sqrt{P}]`:
- :math:`\mathbf{A}` and :math:`\mathbf{X}` (and :math:`\mathbf{Y}`) are partitioned into
:math:`[N_{loc} \times K_{loc}]` and :math:`[K_{loc} \times M_{loc}]` tiles on each
rank, respectively.
- Each SUMMA iteration broadcasts row- and column-blocks of :math:`\mathbf{A}` and
:math:`\mathbf{X}` (forward) or :math:`\mathbf{Y}` (adjoint) and accumulates local
partial products.
:block:
1D block-row distribution over a :math:`[1 \times P]` grid:
- :math:`\mathbf{A}` is partitioned into :math:`[N_{loc} \times K]` blocks across ranks.
- :math:`\mathbf{X}` (and :math:`\mathbf{Y}`) are partitioned into :math:`[K \times M_{loc}]` blocks.
- Local multiplication is followed by row-wise gather (forward) or
allreduce (adjoint) across ranks.
.. [1] van de Geijn, R. A., and Watts, J. "SUMMA: Scalable Universal
Matrix Multiplication Algorithm", 1995.
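Examples
--------
A minimal end-to-end sketch (to be run with ``mpirun``, not as a doctest),
assuming a script launched with ``mpirun -n 4`` and global sizes
:math:`N=6`, :math:`K=5`, :math:`M=4`; the array names are illustrative::

    from mpi4py import MPI
    import numpy as np
    from pylops_mpi import DistributedArray, Partition

    comm, rank, row, col, is_active = active_grid_comm(MPI.COMM_WORLD, 6, 4)
    if is_active:
        A_glob = np.arange(30, dtype=np.float64).reshape(6, 5)   # replicated on each rank
        X_glob = np.ones((5, 4))
        ra, ca = local_block_split((6, 5), rank, comm)   # tile of A owned here
        rx, cx = local_block_split((5, 4), rank, comm)   # tile of X owned here
        Aop = MPIMatrixMult(A_glob[ra, ca], M=4, base_comm=comm, kind="summa")
        sizes = comm.allgather(X_glob[rx, cx].size)
        x = DistributedArray(global_shape=int(np.sum(sizes)), local_shapes=sizes,
                             partition=Partition.SCATTER, base_comm=comm)
        x[:] = X_glob[rx, cx].ravel()
        y = Aop.matvec(x)                                  # distributed Y = A @ X
        Y = block_gather(y, (6, 4), comm)                  # full (6, 4) result
        Xadj = block_gather(Aop.rmatvec(y), (5, 4), comm)  # full (5, 4) A^H @ Y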
"""
if kind == "summa":
return _MPISummaMatrixMult(A, M, saveAt, base_comm, dtype)
elif kind == "block":
return _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype)
else:
raise NotImplementedError("kind must be summa or block")