pylops_mpi.basicoperators.MPIMatrixMult
- class pylops_mpi.basicoperators.MPIMatrixMult(A, M, saveAt=False, base_comm=<mpi4py.MPI.Intracomm object>, dtype='float64')
MPI Matrix multiplication
Implement distributed matrix-matrix multiplication between a matrix \(\mathbf{A}\) blocked over rows (i.e., blocks of rows are stored over different ranks) and the input model and data vector, which are both to be interpreted as matrices blocked over columns.
- Parameters:
- A
numpy.ndarray
Local block of the matrix of shape \([N_{loc} \times K]\), where \(N_{loc}\) is the number of rows stored on this MPI rank and \(K\) is the global number of columns.
- M
int
Global leading dimension (i.e., number of columns) of the matrices representing the input model and data vectors.
- saveAt
bool, optional
Save A and A.H to speed up the computation of the adjoint (True) or create A.H on-the-fly (False). Note that saveAt=True will double the amount of required memory. Default is False.
- base_comm
mpi4py.MPI.Comm, optional
MPI Base Communicator. Defaults to mpi4py.MPI.COMM_WORLD.
- dtype
str, optional
Type of elements in input array.
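The trade-off behind saveAt can be sketched in plain NumPy. The class below is a hypothetical stand-in for the per-rank storage (LocalBlock, apply_adjoint and nbytes are illustrative names, not part of pylops_mpi), and it assumes that the stored adjoint is a separate contiguous copy of the local block.

```python
import numpy as np


class LocalBlock:
    """Hypothetical per-rank container; not the pylops_mpi implementation."""

    def __init__(self, A, saveAt=False):
        self.A = A
        # Assumption: with saveAt=True the adjoint is kept as its own copy.
        self.At = np.ascontiguousarray(A.conj().T) if saveAt else None

    def apply_adjoint(self, Y):
        # Use the stored adjoint if present, otherwise build it on the fly.
        At = self.At if self.At is not None else self.A.conj().T
        return At @ Y

    def nbytes(self):
        # Bytes held by this container for the operator block(s).
        return self.A.nbytes + (0 if self.At is None else self.At.nbytes)


rng = np.random.default_rng(0)
A_loc = rng.standard_normal((3, 4)) + 1j * rng.standard_normal((3, 4))
Y_loc = rng.standard_normal((3, 2)) + 1j * rng.standard_normal((3, 2))

lazy = LocalBlock(A_loc, saveAt=False)
cached = LocalBlock(A_loc, saveAt=True)

# Both variants give the same adjoint product; only the storage differs.
same_result = np.allclose(lazy.apply_adjoint(Y_loc), cached.apply_adjoint(Y_loc))
memory_ratio = cached.nbytes() / lazy.nbytes()
```

In this sketch the cached variant holds twice as many bytes as the lazy one, which is the cost the saveAt note refers to.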
- Attributes:
- shape
tuple
Operator shape
- Raises:
- Exception
If the operator is created with a non-square number of MPI ranks.
- ValueError
If input vector does not have the correct partition type.
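A minimal sketch of the kind of check that the first condition implies, assuming the grid side is obtained as the integer square root of the number of ranks. The helper grid_side is hypothetical and only mirrors the documented behaviour; it is not the library's code, and it raises ValueError (a subclass of Exception) purely as an illustrative choice.

```python
import math


def grid_side(n_ranks: int) -> int:
    """Return sqrt(n_ranks) if n_ranks is a perfect square, else raise."""
    p = math.isqrt(n_ranks)
    if p * p != n_ranks:
        raise ValueError(
            f"Number of ranks must be a perfect square, got {n_ranks}"
        )
    return p


side_for_nine = grid_side(9)  # 9 ranks form a 3 x 3 grid
```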
Notes
This operator performs a matrix-matrix multiplication, whose forward operation can be described as \(\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}\) where:
- \(\mathbf{A}\) is the distributed matrix operator of shape \([N \times K]\)
- \(\mathbf{X}\) is the distributed operand matrix of shape \([K \times M]\)
- \(\mathbf{Y}\) is the resulting distributed matrix of shape \([N \times M]\)
whilst the adjoint operation is represented by \(\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}\) where \(\mathbf{A}^H\) is the complex conjugate and transpose of \(\mathbf{A}\).
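The forward and adjoint definitions can be checked against each other with a dot test on small dense matrices. The sketch below uses plain NumPy on a single process, with arbitrary sizes chosen for illustration; it does not involve the distributed operator itself.

```python
import numpy as np

rng = np.random.default_rng(42)
N, K, M = 4, 3, 2  # illustrative sizes


def crandn(shape):
    """Complex array with standard-normal real and imaginary parts."""
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)


A = crandn((N, K))
X = crandn((K, M))
Y = crandn((N, M))

forward = A @ X            # Y = A X, shape (N, M)
adjoint = A.conj().T @ Y   # X_adj = A^H Y, shape (K, M)

# Dot test: <Y, A X> must equal <A^H Y, X> (np.vdot conjugates its first argument).
lhs = np.vdot(Y, forward)
rhs = np.vdot(adjoint, X)
passes_dot_test = np.isclose(lhs, rhs)
```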
This implementation is based on a 1D block distribution of the operator matrix and of the reshaped model and data vectors, each of which is replicated by a factor of \(\sqrt{P}\) across the \(P\) processes of a square process grid (\(\sqrt{P}\times\sqrt{P}\)). More specifically:
- The matrix A is distributed across MPI processes in a block-row fashion and each process holds a local block of A with shape \([N_{loc} \times K]\)
- The operand matrix X is distributed in a block-column fashion and each process holds a local block of X with shape \([K \times M_{loc}]\)
- Communication is minimized by using a 2D process grid layout
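The layout described above can be illustrated without MPI. The sketch below emulates the ranks in a loop and makes two assumptions that the text does not spell out: blocks are cut with np.array_split (the library may size its blocks differently), and rank r sits at grid position (i, j) = divmod(r, sqrt(P)), holding block-column i of X and block-row j of A.

```python
import math

import numpy as np

P = 9                      # emulated number of ranks (perfect square)
p = math.isqrt(P)          # side of the square process grid
N, K, M = 10, 4, 7         # illustrative global sizes

A = np.arange(N * K, dtype=float).reshape(N, K)
X = np.arange(K * M, dtype=float).reshape(K, M)

A_blocks = np.array_split(A, p, axis=0)  # block-rows of A
X_blocks = np.array_split(X, p, axis=1)  # block-columns of X

# Local shapes per emulated rank: (A_local.shape, X_local.shape)
local_shapes = {}
for rank in range(P):
    i, j = divmod(rank, p)  # assumed rank-to-grid mapping
    local_shapes[rank] = (A_blocks[j].shape, X_blocks[i].shape)

# Each block of A (and of X) ends up on p = sqrt(P) ranks.
holders_of_first_A_block = [r for r in range(P) if r % p == 0]
```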
Forward Operation step-by-step
- Input Preparation: The input vector x (flattened from matrix X of shape (K, M)) is reshaped to (K, M_local), where M_local is the number of columns assigned to the current process.
- Local Computation: Each process computes A_local @ X_local where:
  - A_local is the local block of matrix A (shape N_local x K)
  - X_local is the broadcasted operand (shape K x M_local)
- Row-wise Gather: Results from all processes in each row are gathered using allgather to ensure that each rank has a block-column of the output matrix.
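The three forward steps can be emulated serially to see why they reproduce the full product. This is a sketch under stated assumptions, not the library's implementation: the ranks of a 2 x 2 grid are looped over, the allgather is replaced by stacking the partial results of one grid row, and rank (i, j) is assumed to hold block-column i of X and block-row j of A.

```python
import numpy as np

rng = np.random.default_rng(0)
p = 2                      # grid side, i.e. P = 4 emulated ranks
N, K, M = 7, 5, 5          # illustrative global sizes

A = rng.standard_normal((N, K))
X = rng.standard_normal((K, M))

A_blocks = np.array_split(A, p, axis=0)  # block-rows of A
X_blocks = np.array_split(X, p, axis=1)  # block-columns of X

# Local computation: rank (i, j) multiplies its A block by its X block.
partial = {
    (i, j): A_blocks[j] @ X_blocks[i]
    for i in range(p)
    for j in range(p)
}

# Row-wise gather: stacking the results of grid row i emulates the allgather
# and yields block-column i of the output, of shape (N, M_local).
Y_cols = [np.vstack([partial[(i, j)] for j in range(p)]) for i in range(p)]

# Concatenating the block-columns gives the global result for comparison.
Y = np.hstack(Y_cols)
matches_dense = np.allclose(Y, A @ X)
```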
Adjoint Operation step-by-step
The adjoint operation performs the conjugate transpose multiplication:
- Input Reshaping: The input vector x is reshaped to (N, M_local), representing the local columns of the input matrix.
- Local Adjoint Computation: Each process computes A_local.H @ X_tile, where A_local.H is either i) pre-computed and stored in At (if saveAt=True), or ii) computed on-the-fly as A.T.conj() (if saveAt=False). Each process multiplies its transposed local A block A_local^H (shape K x N_block) with the extracted X_tile (shape N_block x M_local), producing a partial result of shape (K, M_local). This computes the local contribution of columns of A^H to the final result.
- Row-wise Reduction: Since the full result Y = A^H \cdot X is the sum of the contributions from all column blocks of A^H, processes in the same row perform an allreduce sum to combine their partial results. This gives the complete (K, M_local) result for their assigned column.
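The adjoint steps can be emulated in the same serial fashion. The sketch below again assumes that rank (i, j) of a 2 x 2 grid holds block-column i of the input and block-row j of A; the allreduce is replaced by a sum over the partial results of one grid row, and the variable names (Y_in, X_adj) follow the adjoint formula given earlier in the Notes rather than the library's internals.

```python
import numpy as np

rng = np.random.default_rng(1)
p = 2                      # grid side, i.e. P = 4 emulated ranks
N, K, M = 7, 5, 5          # illustrative global sizes

A = rng.standard_normal((N, K)) + 1j * rng.standard_normal((N, K))
Y_in = rng.standard_normal((N, M)) + 1j * rng.standard_normal((N, M))

A_blocks = np.array_split(A, p, axis=0)      # block-rows of A
Y_cols = np.array_split(Y_in, p, axis=1)     # block-columns of the input

X_cols = []
for i in range(p):
    # Every rank in grid row i holds the (N, M_local) block-column i and
    # extracts the row tile matching its own block of A.
    tiles = np.array_split(Y_cols[i], p, axis=0)
    # Local adjoint computation: (K, N_block) @ (N_block, M_local).
    partials = [A_blocks[j].conj().T @ tiles[j] for j in range(p)]
    # Row-wise reduction: the sum emulates the allreduce over grid row i.
    X_cols.append(np.sum(partials, axis=0))

X_adj = np.hstack(X_cols)
matches_dense = np.allclose(X_adj, A.conj().T @ Y_in)
```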
Methods
- __init__(A, M[, saveAt, base_comm, dtype])
- active_grid_comm(base_comm, N, M): Configure active grid
- adjoint(): Adjoint MPI LinearOperator
- conj(): Complex conjugate operator
- dot(x): Matrix Vector Multiplication
- matvec(x): Matrix-vector multiplication.
- rmatvec(x): Adjoint Matrix-vector multiplication.
- transpose(): Transposition of MPI LinearOperator