Source code for pylops_mpi.basicoperators.MatrixMult

import math
import numpy as np
from mpi4py import MPI
from pylops.utils.backend import get_module
from pylops.utils.typing import DTypeLike, NDArray

from pylops_mpi import (
    DistributedArray,
    MPILinearOperator,
    Partition
)


class MPIMatrixMult(MPILinearOperator):
    r"""MPI Matrix multiplication

    Implement distributed matrix-matrix multiplication between a matrix
    :math:`\mathbf{A}` blocked over rows (i.e., blocks of rows are stored
    over different ranks) and the input model and data vector, which are
    both to be interpreted as matrices blocked over columns.

    Parameters
    ----------
    A : :obj:`numpy.ndarray`
        Local block of the matrix of shape :math:`[N_{loc} \times K]`,
        where :math:`N_{loc}` is the number of rows stored on this MPI rank
        and ``K`` is the global number of columns.
    M : :obj:`int`
        Global leading dimension (i.e., number of columns) of the matrices
        representing the input model and data vectors.
    saveAt : :obj:`bool`, optional
        Save ``A`` and ``A.H`` to speed up the computation of the adjoint
        (``True``) or create ``A.H`` on-the-fly (``False``).
        Note that ``saveAt=True`` will double the amount of required memory.
        Default is ``False``.
    base_comm : :obj:`mpi4py.MPI.Comm`, optional
        MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
    dtype : :obj:`str`, optional
        Type of elements in input array.

    Attributes
    ----------
    shape : :obj:`tuple`
        Operator shape

    Raises
    ------
    Exception
        If the operator is created with a non-square number of MPI ranks.
    ValueError
        If the input vector does not have the correct partition type.

    Notes
    -----
    This operator performs a matrix-matrix multiplication, whose forward
    operation can be described as
    :math:`\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}` where:

    - :math:`\mathbf{A}` is the distributed matrix operator of shape
      :math:`[N \times K]`
    - :math:`\mathbf{X}` is the distributed operand matrix of shape
      :math:`[K \times M]`
    - :math:`\mathbf{Y}` is the resulting distributed matrix of shape
      :math:`[N \times M]`

    whilst the adjoint operation is represented by
    :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}`, where
    :math:`\mathbf{A}^H` is the conjugate transpose of :math:`\mathbf{A}`.

    This implementation is based on a 1D block-row distribution of the
    operator matrix, whilst the reshaped model and data vectors are
    distributed in block-column fashion and replicated :math:`\sqrt{P}` times
    across a square process grid (:math:`\sqrt{P}\times\sqrt{P}`).
    More specifically:

    - The matrix ``A`` is distributed across MPI processes in a block-row
      fashion and each process holds a local block of ``A`` with shape
      :math:`[N_{loc} \times K]`
    - The operand matrix ``X`` is distributed in a block-column fashion and
      each process holds a local block of ``X`` with shape
      :math:`[K \times M_{loc}]`
    - Communication is minimized by using a 2D process grid layout

    **Forward Operation step-by-step**

    1. **Input Preparation**: The input vector ``x`` (flattened from matrix
       ``X`` of shape ``(K, M)``) is reshaped to ``(K, M_local)``, where
       ``M_local`` is the number of columns assigned to the current process.

    2. **Local Computation**: Each process computes ``A_local @ X_local``
       where:

       - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``)
       - ``X_local`` is the broadcasted operand (shape ``K x M_local``)

    3. **Row-wise Gather**: Results from all processes in each row are
       gathered using ``allgather`` to ensure that each rank holds a
       block-column of the output matrix.

    **Adjoint Operation step-by-step**

    The adjoint operation performs the conjugate transpose multiplication:

    1. **Input Reshaping**: The input vector ``x`` is reshaped to
       ``(N, M_local)``, representing the local columns of the input matrix.

    2. **Local Adjoint Computation**: Each process computes
       ``A_local.H @ X_tile``, where ``A_local.H`` is either i) pre-computed
       and stored in ``At`` (if ``saveAt=True``), or ii) computed on-the-fly
       as ``A.T.conj()`` (if ``saveAt=False``). Each process multiplies its
       transposed local ``A`` block ``A_local^H`` (shape ``K x N_block``)
       with the extracted ``X_tile`` (shape ``N_block x M_local``),
       producing a partial result of shape ``(K, M_local)``. This computes
       the local contribution of the columns of ``A^H`` to the final result.

    3. **Row-wise Reduction**: Since the full result
       :math:`\mathbf{Y} = \mathbf{A}^H \cdot \mathbf{X}` is the sum of the
       contributions from all column blocks of ``A^H``, processes in the
       same row perform an ``allreduce`` sum to combine their partial
       results. This gives the complete ``(K, M_local)`` result for their
       assigned column block.

    """
    def __init__(
        self,
        A: NDArray,
        M: int,
        saveAt: bool = False,
        base_comm: MPI.Comm = MPI.COMM_WORLD,
        dtype: DTypeLike = "float64",
    ) -> None:
        rank = base_comm.Get_rank()
        size = base_comm.Get_size()

        # Determine the square grid dimensions (P_prime x C) such that
        # P_prime * C == size
        self._P_prime = math.isqrt(size)
        self._C = self._P_prime
        if self._P_prime * self._C != size:
            raise Exception(f"Number of processes must be a square number, "
                            f"provided {size} instead...")

        self._col_id = rank % self._P_prime
        self._row_id = rank // self._P_prime

        self.base_comm = base_comm
        self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id)
        self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id)

        self.A = A.astype(np.dtype(dtype))
        if saveAt:
            self.At = A.T.conj()

        self.N = self._row_comm.allreduce(self.A.shape[0], op=MPI.SUM)
        self.K = A.shape[1]
        self.M = M

        # Block sizes and local row/column extents owned by this rank
        block_cols = int(math.ceil(self.M / self._P_prime))
        blk_rows = int(math.ceil(self.N / self._P_prime))

        self._row_start = self._col_id * blk_rows
        self._row_end = min(self.N, self._row_start + blk_rows)

        self._col_start = self._row_id * block_cols
        self._col_end = min(self.M, self._col_start + block_cols)

        self._local_ncols = max(0, self._col_end - self._col_start)
        self._rank_col_lens = self.base_comm.allgather(self._local_ncols)
        total_ncols = np.sum(self._rank_col_lens)

        self.dims = (self.K, total_ncols)
        self.dimsd = (self.N, total_ncols)
        shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
        super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)

    @staticmethod
    def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
        r"""Configure active grid

        Configure a square process grid from a parent MPI communicator and
        select a subset of "active" processes. Each process in ``base_comm``
        is assigned to a logical 2D grid of size :math:`P' \times P'`, where
        :math:`P' = \bigl\lfloor \sqrt{P} \bigr\rfloor`. Only the first
        ``active_dim x active_dim`` processes (in row-major order) are
        considered "active". Inactive ranks return immediately with no new
        communicator.

        Parameters
        ----------
        base_comm : :obj:`mpi4py.MPI.Comm`
            MPI Parent Communicator (e.g., ``mpi4py.MPI.COMM_WORLD``).
        N : :obj:`int`
            Number of rows of the global data domain.
        M : :obj:`int`
            Number of columns of the global data domain.

        Returns
        -------
        comm : :obj:`mpi4py.MPI.Comm`
            Sub-communicator including only active ranks
            (``None`` if this rank is inactive).
        rank : :obj:`int`
            Rank within the new sub-communicator
            (or the original rank if inactive).
        row : :obj:`int`
            Grid row index of this process in the active grid
            (or its row index in the full grid if inactive).
        col : :obj:`int`
            Grid column index of this process in the active grid
            (or its column index in the full grid if inactive).
        is_active : :obj:`bool`
            Flag indicating whether this rank is in the active sub-grid.

        """
        rank = base_comm.Get_rank()
        size = base_comm.Get_size()
        p_prime = math.isqrt(size)
        row, col = divmod(rank, p_prime)
        active_dim = min(N, M, p_prime)
        is_active = (row < active_dim and col < active_dim)

        if not is_active:
            return None, rank, row, col, False

        # Build a sub-communicator containing only the active ranks
        active_ranks = [r for r in range(size)
                        if (r // p_prime) < active_dim
                        and (r % p_prime) < active_dim]
        new_group = base_comm.Get_group().Incl(active_ranks)
        new_comm = base_comm.Create_group(new_group)
        p_prime_new = math.isqrt(len(active_ranks))
        new_rank = new_comm.Get_rank()
        new_row, new_col = divmod(new_rank, p_prime_new)

        return new_comm, new_rank, new_row, new_col, True

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER}. "
                             f"Got {x.partition} instead...")

        y = DistributedArray(
            global_shape=(self.N * self.dimsd[1]),
            local_shapes=[(self.N * c) for c in self._rank_col_lens],
            mask=x.mask,
            partition=Partition.SCATTER,
            dtype=self.dtype,
            base_comm=self.base_comm
        )

        # Reshape the flattened input to its local column block (K, M_local)
        my_own_cols = self._rank_col_lens[self.rank]
        x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
        X_local = x_arr.astype(self.dtype)

        # Local multiplication followed by a row-wise gather of the row blocks
        Y_local = ncp.vstack(
            self._row_comm.allgather(
                ncp.matmul(self.A, X_local)
            )
        )
        y[:] = Y_local.flatten()
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER}. "
                             f"Got {x.partition} instead.")

        y = DistributedArray(
            global_shape=(self.K * self.dimsd[1]),
            local_shapes=[self.K * c for c in self._rank_col_lens],
            mask=x.mask,
            partition=Partition.SCATTER,
            dtype=self.dtype,
            base_comm=self.base_comm
        )

        # Extract the row tile matching this rank's block of A^H
        x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype)
        X_tile = x_arr[self._row_start:self._row_end, :]

        # Local adjoint multiplication followed by a row-wise reduction of
        # the partial contributions
        A_local = self.At if hasattr(self, "At") else self.A.T.conj()
        Y_local = ncp.matmul(A_local, X_tile)
        y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM)
        y[:] = y_layer.flatten()
        return y
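

# Editorial sketch (not part of the library): a single-process NumPy
# illustration of the block algebra behind the forward/adjoint steps above.
# Stacking the row-block products A_i @ X_cols reproduces A @ X_cols (the
# role of ``allgather`` in ``_matvec``), and summing the per-block adjoint
# contributions A_i^H @ Y_i reproduces A^H @ Y_cols (the role of
# ``allreduce`` in ``_rmatvec``). All sizes below are arbitrary assumptions.
def _block_algebra_sketch():
    rng = np.random.default_rng(0)
    N, K, M_local, nblocks = 8, 6, 3, 2  # hypothetical sizes, N divisible by nblocks
    A = rng.standard_normal((N, K))
    X_cols = rng.standard_normal((K, M_local))  # one column block of X

    # Forward: each "rank" holds a row block A_i and the same X_cols;
    # vertically stacking the local products equals A @ X_cols
    A_blocks = np.split(A, nblocks, axis=0)
    Y_cols = np.vstack([A_i @ X_cols for A_i in A_blocks])
    assert np.allclose(Y_cols, A @ X_cols)

    # Adjoint: each "rank" takes the row tile of Y_cols matching its A block
    # and applies A_i^H; summing the partial results equals A^H @ Y_cols
    Y_tiles = np.split(Y_cols, nblocks, axis=0)
    X_adj = sum(A_i.conj().T @ Y_i for A_i, Y_i in zip(A_blocks, Y_tiles))
    assert np.allclose(X_adj, A.conj().T @ Y_cols)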
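

# Editorial usage sketch (not part of the library): how the operator above
# might be driven end-to-end, applied through its ``matvec``/``rmatvec``
# interface. The global sizes N, K, M and the block-slicing helpers are
# illustrative assumptions; only ``MPIMatrixMult``, ``active_grid_comm`` and
# ``DistributedArray`` come from the code above. Run with a square number of
# ranks, e.g. ``mpirun -n 4 python MatrixMult.py``.
if __name__ == "__main__":
    N, K, M = 8, 6, 4  # global shapes of A (N x K) and X (K x M)

    comm, rank, row_id, col_id, is_active = \
        MPIMatrixMult.active_grid_comm(MPI.COMM_WORLD, N, M)
    if is_active:
        p_prime = math.isqrt(comm.Get_size())

        # Local row block of A: rows are split over the grid-column index and
        # replicated across grid rows (the layout the operator assumes)
        blk_rows = int(math.ceil(N / p_prime))
        rs, re = col_id * blk_rows, min(N, (col_id + 1) * blk_rows)
        A_glob = np.arange(N * K, dtype="float64").reshape(N, K)
        Aop = MPIMatrixMult(A_glob[rs:re, :], M, base_comm=comm, dtype="float64")

        # Local column block of X: columns are split over the grid-row index
        blk_cols = int(math.ceil(M / p_prime))
        cs, ce = row_id * blk_cols, min(M, (row_id + 1) * blk_cols)
        X_glob = np.ones((K, M), dtype="float64")
        X_local = X_glob[:, cs:ce]

        # Distributed model vector: each rank stores its flattened column block
        local_shapes = comm.allgather(K * X_local.shape[1])
        x = DistributedArray(
            global_shape=int(np.sum(local_shapes)),
            local_shapes=local_shapes,
            partition=Partition.SCATTER,
            base_comm=comm,
            dtype="float64",
        )
        x[:] = X_local.flatten()

        y = Aop.matvec(x)      # each rank holds a column block of Y = A @ X
        xadj = Aop.rmatvec(y)  # each rank holds a column block of A^H @ Y
        Y_local = y.local_array.reshape(N, ce - cs)  # local (N, M_local) block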