from __future__ import annotations
import numpy as np
from mpi4py import MPI
from typing import Callable, Optional
from scipy.sparse._sputils import isintlike, isshape
from scipy.sparse.linalg._interface import _get_dtype
from pylops import LinearOperator
from pylops.utils import DTypeLike, ShapeLike
from pylops_mpi import DistributedArray
class MPILinearOperator:
    """MPI-enabled PyLops Linear Operator

    Common interface for performing matrix-vector products in distributed fashion.

    In practice, this class provides methods to perform matrix-vector and
    adjoint matrix-vector products between any :obj:`pylops.LinearOperator`
    (which must be the same across ranks) and a :class:`pylops_mpi.DistributedArray`
    with ``Partition.BROADCAST`` and ``Partition.UNSAFE_BROADCAST`` partition. It
    internally handles the extraction of the local array from the distributed array
    and the creation of the output :class:`pylops_mpi.DistributedArray`.

    Note that whilst this operator could also be used with different
    :obj:`pylops.LinearOperator` across ranks, and with a
    :class:`pylops_mpi.DistributedArray` with ``Partition.SCATTER``, it is however
    recommended to use the :class:`pylops_mpi.basicoperators.MPIBlockDiag` operator
    instead as this can also handle distributed arrays with subcommunicators.

    Parameters
    ----------
    Op : :obj:`pylops.LinearOperator`, optional
        PyLops Linear Operator applied to the local array on each rank. Its
        ``shape``, ``dims``, ``dimsd`` and ``dtype`` are used by default; any
        of the other arguments that is provided overrides the corresponding
        value obtained from ``Op``. Defaults to ``None``.
    shape : :obj:`tuple(int, int)`, optional
        Shape of the MPI Linear Operator. If not provided, obtained from
        ``dims`` and ``dimsd``.
    dims : :obj:`tuple(int, ..., int)`, optional
        Dimensions of model. If not provided, ``(self.shape[1],)`` is used.
    dimsd : :obj:`tuple(int, ..., int)`, optional
        Dimensions of data. If not provided, ``(self.shape[0],)`` is used.
    dtype : :obj:`str`, optional
        Type of elements in input array. Defaults to ``None``.
    base_comm : :obj:`mpi4py.MPI.Comm`, optional
        MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.

    """

    def __init__(
        self,
        Op: Optional[LinearOperator] = None,
        shape: Optional[ShapeLike] = None,
        dims: Optional[ShapeLike] = None,
        dimsd: Optional[ShapeLike] = None,
        dtype: Optional[DTypeLike] = None,
        base_comm: MPI.Comm = MPI.COMM_WORLD
    ):
        if Op is not None:
            self.Op = Op
            # Explicit arguments take precedence over the attributes of ``Op``
            dtype = Op.dtype if dtype is None else dtype
            shape = Op.shape if shape is None else shape
            # ``dims``/``dimsd`` are optional on ``Op``: fall back to 1-d shapes
            dims = getattr(Op, "dims", (Op.shape[1], )) if dims is None else dims
            dimsd = getattr(Op, "dimsd", (Op.shape[0], )) if dimsd is None else dimsd
        # ``shape`` is set first so that the ``dims``/``dimsd`` setters can
        # validate against it
        if shape is not None:
            self.shape = shape
        if dims is not None:
            self.dims = dims
        if dimsd is not None:
            self.dimsd = dimsd
        if dtype is not None:
            self.dtype = dtype
        # For MPI
        self.base_comm = base_comm
        self.size = base_comm.Get_size()
        self.rank = base_comm.Get_rank()

    @property
    def shape(self):
        """Shape ``(M, N)``, derived from ``dimsd``/``dims`` when not set."""
        _shape = getattr(self, "_shape", None)
        if _shape is None:  # Cannot find shape, falling back on dims and dimsd
            dims = getattr(self, "_dims", None)
            dimsd = getattr(self, "_dimsd", None)
            if dims is None or dimsd is None:  # Cannot find both dims and dimsd, error
                msg = (
                    f"'{self.__class__.__name__}' object has no attribute 'shape' "
                    "nor both fallback attributes ('dims', 'dimsd')"
                )
                raise AttributeError(msg)
            _shape = (int(np.prod(dimsd)), int(np.prod(dims)))
            self._shape = _shape  # Update to not redo everything above on next call
        return _shape

    @shape.setter
    def shape(self, new_shape: ShapeLike) -> None:
        new_shape = tuple(new_shape)
        if not isshape(new_shape):
            msg = f"Invalid shape; must be 2-d tuple of integers, got {new_shape}"
            raise ValueError(msg)
        dims = getattr(self, "_dims", None)
        dimsd = getattr(self, "_dimsd", None)
        if dims is not None and dimsd is not None:  # Found dims and dimsd
            if np.prod(dimsd) != new_shape[0] and np.prod(dims) != new_shape[1]:
                msg = "New shape incompatible with dims and dimsd"
                raise ValueError(msg)
            elif np.prod(dimsd) != new_shape[0]:
                msg = "New shape incompatible with dimsd"
                raise ValueError(msg)
            elif np.prod(dims) != new_shape[1]:
                msg = "New shape incompatible with dims"
                raise ValueError(msg)
        self._shape = new_shape

    @property
    def dims(self):
        """Model dimensions; ``(shape[1],)`` when not explicitly set."""
        _dims = getattr(self, "_dims", None)
        if _dims is None:
            shape = getattr(self, "_shape", None)
            if shape is None:
                msg = (
                    f"'{self.__class__.__name__}' object has no "
                    "attributes 'dims' or 'shape'"
                )
                raise AttributeError(msg)
            _dims = (shape[1],)
        return _dims

    @dims.setter
    def dims(self, new_dims: ShapeLike) -> None:
        new_dims = tuple(new_dims)
        shape = getattr(self, "_shape", None)
        if shape is None:  # shape not set yet
            self._dims = new_dims
        else:
            if np.prod(new_dims) == self.shape[1]:
                self._dims = new_dims
            else:
                msg = "dims incompatible with shape[1]"
                raise ValueError(msg)

    @property
    def dimsd(self):
        """Data dimensions; ``(shape[0],)`` when not explicitly set."""
        _dimsd = getattr(self, "_dimsd", None)
        if _dimsd is None:
            shape = getattr(self, "_shape", None)
            if shape is None:
                msg = (
                    f"'{self.__class__.__name__}' object has "
                    "no attributes 'dimsd' or 'shape'"
                )
                raise AttributeError(msg)
            _dimsd = (shape[0],)
        return _dimsd

    @dimsd.setter
    def dimsd(self, new_dimsd: ShapeLike) -> None:
        new_dimsd = tuple(new_dimsd)
        shape = getattr(self, "_shape", None)
        if shape is None:  # shape not set yet
            self._dimsd = new_dimsd
        else:
            if np.prod(new_dimsd) == self.shape[0]:
                self._dimsd = new_dimsd
            else:
                msg = "dimsd incompatible with shape[0]"
                raise ValueError(msg)

    def matvec(self, x: DistributedArray) -> DistributedArray:
        """Matrix-vector multiplication.

        Modified version of pylops matvec.
        This method makes use of :class:`pylops_mpi.DistributedArray` to calculate
        matrix vector multiplication in a distributed fashion.

        Parameters
        ----------
        x : :obj:`pylops_mpi.DistributedArray`
            A DistributedArray of global shape (N, ).

        Returns
        -------
        y : :obj:`pylops_mpi.DistributedArray`
            DistributedArray of global shape (M, )

        Raises
        ------
        ValueError
            If ``x.global_shape`` is not ``(N,)``.

        """
        N = self.shape[1]
        if x.global_shape != (N,):
            raise ValueError("dimension mismatch")
        return self._matvec(x)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        return self._apply_wrapped_op(x, self.shape[0], adjoint=False)

    def rmatvec(self, x: DistributedArray) -> DistributedArray:
        """Adjoint Matrix-vector multiplication.

        Modified version of pylops rmatvec.
        This method makes use of :class:`pylops_mpi.DistributedArray` to
        calculate adjoint matrix vector multiplication in a distributed fashion.

        Parameters
        ----------
        x : :obj:`pylops_mpi.DistributedArray`
            A DistributedArray of global shape (M, ).

        Returns
        -------
        y : :obj:`pylops_mpi.DistributedArray`
            DistributedArray of global shape (N, )

        Raises
        ------
        ValueError
            If ``x.global_shape`` is not ``(M,)``.

        """
        M = self.shape[0]
        if x.global_shape != (M,):
            raise ValueError("dimension mismatch")
        return self._rmatvec(x)

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        return self._apply_wrapped_op(x, self.shape[1], adjoint=True)

    def _apply_wrapped_op(self, x: DistributedArray, global_shape,
                          adjoint: bool) -> DistributedArray:
        """Apply the wrapped ``Op`` (forward or adjoint) to ``x.local_array``.

        The output array takes partition, axis, engine and NCCL communicator
        from ``x``, and ``base_comm``/``dtype`` from this operator.

        Raises
        ------
        NotImplementedError
            If no ``Op`` is wrapped (a subclass that neither passes ``Op``
            nor overrides ``_matvec``/``_rmatvec``).

        """
        # Explicit ``is None`` check instead of truthiness; also covers
        # instances where ``Op`` was never assigned
        Op = getattr(self, "Op", None)
        if Op is None:
            raise NotImplementedError(
                f"'{self.__class__.__name__}' does not wrap a pylops "
                "LinearOperator (Op); pass Op or override _matvec/_rmatvec"
            )
        y = DistributedArray(global_shape=global_shape,
                             base_comm=self.base_comm,
                             base_comm_nccl=x.base_comm_nccl,
                             partition=x.partition,
                             axis=x.axis,
                             engine=x.engine,
                             dtype=self.dtype)
        apply = Op._rmatvec if adjoint else Op._matvec
        y[:] = apply(x.local_array)
        return y

    def dot(self, x):
        """Matrix Vector Multiplication

        Composes with another operator (product), scales by a scalar, or
        applies the operator to a 1-d distributed array.

        Parameters
        ----------
        x : :obj:`pylops_mpi.DistributedArray` or :obj:`pylops_mpi.MPILinearOperator` or scalar
            DistributedArray, MPILinearOperator or scalar.

        Returns
        -------
        y : :obj:`pylops_mpi.DistributedArray` or :obj:`pylops_mpi.MPILinearOperator`
            DistributedArray or a MPILinearOperator.

        Raises
        ------
        ValueError
            If ``x`` is a distributed array that is not 1-d.

        """
        if isinstance(x, MPILinearOperator):
            Op = _ProductLinearOperator(self, x)
            # dimsd comes from the left operand, dims from the right one
            self._copy_attributes(
                Op,
                exclude=['dims']
            )
            Op.dims = x.dims
            return Op
        elif np.isscalar(x):
            Op = _ScaledLinearOperator(self, x)
            self._copy_attributes(
                Op
            )
            return Op
        else:
            if x is None or x.ndim == 1:
                return self.matvec(x)
            else:
                raise ValueError('expected 1-d DistributedArray, got %r'
                                 % x.global_shape)

    def adjoint(self):
        """Adjoint MPI LinearOperator

        Returns
        -------
        op : :obj:`pylops_mpi.MPILinearOperator`
            Adjoint of Operator

        """
        return self._adjoint()

    H = property(adjoint)

    def transpose(self):
        """Transposition of MPI LinearOperator

        Returns
        -------
        op : :obj:`pylops_mpi.MPILinearOperator`
            Transpose Linear Operator

        """
        return self._transpose()

    T = property(transpose)

    def __mul__(self, x):
        return self.dot(x)

    def __rmul__(self, x):
        if np.isscalar(x):
            Op = _ScaledLinearOperator(self, x)
            self._copy_attributes(
                Op
            )
            return Op
        else:
            return NotImplemented

    def __matmul__(self, x):
        if np.isscalar(x):
            raise ValueError("Scalar not allowed, use * instead")
        return self.__mul__(x)

    def __rmatmul__(self, x):
        if np.isscalar(x):
            raise ValueError("Scalar not allowed, use * instead")
        return self.__rmul__(x)

    def __pow__(self, p):
        Op = _PowerLinearOperator(self, p)
        self._copy_attributes(
            Op
        )
        return Op

    def __add__(self, x):
        Op = _SumLinearOperator(self, x)
        self._copy_attributes(
            Op
        )
        return Op

    def __neg__(self):
        Op = _ScaledLinearOperator(self, -1)
        self._copy_attributes(
            Op
        )
        return Op

    def __sub__(self, x):
        return self.__add__(-x)

    def _adjoint(self):
        Op = _AdjointLinearOperator(self)
        # Model and data spaces swap under the adjoint
        self._copy_attributes(
            Op,
            exclude=['dims', 'dimsd']
        )
        Op.dims = self.dimsd
        Op.dimsd = self.dims
        return Op

    def _transpose(self):
        Op = _TransposedLinearOperator(self)
        # Model and data spaces swap under the transpose
        self._copy_attributes(
            Op,
            exclude=['dims', 'dimsd']
        )
        Op.dims = self.dimsd
        Op.dimsd = self.dims
        return Op

    def conj(self):
        """Complex conjugate operator

        Returns
        -------
        conjop : :obj:`pylops_mpi.MPILinearOperator`
            Complex conjugate operator (keeps ``dims`` and ``dimsd``)

        """
        Op = _ConjLinearOperator(self)
        # Conjugation does not change model/data spaces: carry dims/dimsd
        # over like the other operator compositions do
        self._copy_attributes(
            Op
        )
        return Op

    def _copy_attributes(
        self,
        dest: MPILinearOperator,
        exclude: list[str] | None = None,
    ) -> None:
        """Copy ``dims``/``dimsd`` (minus ``exclude``) from self to ``dest``"""
        attrs = ["dims", "dimsd"]
        if exclude is not None:
            for item in exclude:
                attrs.remove(item)
        for attr in attrs:
            if hasattr(self, attr):
                setattr(dest, attr, getattr(self, attr))

    def __repr__(self):
        M, N = self.shape
        # ``dtype`` is only assigned in __init__ when one was provided
        dtype = getattr(self, "dtype", None)
        if dtype is None:
            dt = "unspecified dtype"
        else:
            dt = f"dtype={dtype}"
        return f"<{M}x{N} {self.__class__.__name__} with {dt}>"
class _AdjointLinearOperator(MPILinearOperator):
    """Adjoint of MPI Linear Operator

    Applies ``A.rmatvec`` in the forward pass and ``A.matvec`` in the adjoint
    pass. It uses the same base communicator as ``A``.
    """

    def __init__(self, A: MPILinearOperator):
        self.A = A
        self.args = (A,)
        # Inherit A's communicator instead of forcing COMM_WORLD, so that
        # operators built on sub-communicators keep correct size/rank
        super().__init__(shape=(A.shape[1], A.shape[0]), dtype=A.dtype,
                         base_comm=A.base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        return self.A.rmatvec(x)

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        return self.A.matvec(x)
class _TransposedLinearOperator(MPILinearOperator):
    """Transposition of MPI Linear Operator

    Implemented via the adjoint: ``A.T x = conj(A.H conj(x))`` in the forward
    pass and ``conj(A conj(x))`` in the adjoint pass. It uses the same base
    communicator as ``A``.
    """

    def __init__(self, A: MPILinearOperator):
        self.A = A
        self.args = (A,)
        # Inherit A's communicator instead of forcing COMM_WORLD
        super().__init__(shape=(A.shape[1], A.shape[0]), dtype=A.dtype,
                         base_comm=A.base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        x = x.conj()
        y = self.A.rmatvec(x)
        y = y.conj()
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        x = x.conj()
        y = self.A.matvec(x)
        y = y.conj()
        return y
class _ProductLinearOperator(MPILinearOperator):
    """Product of MPI LinearOperators

    Represents ``A @ B``: the forward pass applies ``B`` then ``A``; the
    adjoint pass applies ``A.H`` then ``B.H``. It uses the same base
    communicator as ``A``.
    """

    def __init__(self, A: MPILinearOperator, B: MPILinearOperator):
        if not isinstance(A, MPILinearOperator) or not isinstance(B, MPILinearOperator):
            raise ValueError('both operands have to be a LinearOperator')
        if A.shape[1] != B.shape[0]:
            raise ValueError('cannot multiply %r and %r: shape mismatch' % (A, B))
        self.args = (A, B)
        # dtype is promoted across both operands; the communicator is
        # inherited from A instead of being forced to COMM_WORLD
        super().__init__(shape=(A.shape[0], B.shape[1]), dtype=_get_dtype([A, B]),
                         base_comm=A.base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        return self.args[0].matvec(self.args[1].matvec(x))

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        return self.args[1].rmatvec(self.args[0].rmatvec(x))

    def _adjoint(self) -> MPILinearOperator:
        A, B = self.args
        return B.H * A.H
class _ScaledLinearOperator(MPILinearOperator):
    """Scaled MPI Linear Operator

    Represents ``alpha * A`` for a scalar ``alpha``. The adjoint pass scales
    by ``conj(alpha)``. It uses the same base communicator as ``A``.
    """

    def __init__(self, A: MPILinearOperator, alpha):
        if not isinstance(A, MPILinearOperator):
            raise ValueError('MPILinearOperator expected as A')
        if not np.isscalar(alpha):
            raise ValueError('scalar expected as alpha')
        self.args = (A, alpha)
        # dtype combines A.dtype with the scalar type; the communicator is
        # inherited from A instead of being forced to COMM_WORLD
        super().__init__(shape=A.shape, dtype=_get_dtype([A], [type(alpha)]),
                         base_comm=A.base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        y = self.args[0].matvec(x)
        if y is not None:
            # NOTE(review): if y[:] is a NumPy view of y's local array, the
            # in-place product keeps A's output dtype; a complex alpha on a
            # real A may then fail to cast - confirm
            y[:] *= self.args[1]
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        y = self.args[0].rmatvec(x)
        if y is not None:
            y[:] *= np.conj(self.args[1])
        return y

    def _adjoint(self) -> MPILinearOperator:
        A, alpha = self.args
        return A.H * np.conj(alpha)
class _SumLinearOperator(MPILinearOperator):
    """Sum of MPI LinearOperators

    Represents ``A + B`` for two operators of identical shape. The adjoint is
    ``A.H + B.H``. It uses the same base communicator as ``A``.
    """

    def __init__(self, A: MPILinearOperator, B: MPILinearOperator):
        if not isinstance(A, MPILinearOperator) or not isinstance(B, MPILinearOperator):
            raise ValueError('both operands have to be a MPILinearOperator')
        # Both operands must map between spaces of the same size
        if A.shape != B.shape:
            raise ValueError("cannot add %r and %r: shape mismatch" % (A, B))
        self.args = (A, B)
        # Promote dtype across both operands (e.g. real + complex -> complex)
        # rather than taking A's; inherit A's communicator instead of
        # forcing COMM_WORLD
        super().__init__(shape=A.shape, dtype=_get_dtype([A, B]),
                         base_comm=A.base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        arr1 = self.args[0].matvec(x)
        arr2 = self.args[1].matvec(x)
        return arr1 + arr2

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        arr1 = self.args[0].rmatvec(x)
        arr2 = self.args[1].rmatvec(x)
        return arr1 + arr2

    def _adjoint(self) -> MPILinearOperator:
        A, B = self.args
        return A.H + B.H
class _PowerLinearOperator(MPILinearOperator):
    """Power of MPI Linear Operator

    Represents ``A ** p`` for a square operator ``A`` and a non-negative
    integer ``p``: ``A`` (or its adjoint) is applied ``p`` times in sequence.
    With ``p == 0`` the result is a copy of the input.
    """

    def __init__(self, A: MPILinearOperator, p: int) -> None:
        if not isinstance(A, MPILinearOperator):
            raise ValueError("LinearOperator expected as A")
        is_square = A.shape[0] == A.shape[1]
        if not is_square:
            raise ValueError("square LinearOperator expected, got %r" % A)
        valid_exponent = isintlike(p) and p >= 0
        if not valid_exponent:
            raise ValueError("non-negative integer expected as p")
        super().__init__(shape=A.shape, dtype=A.dtype, base_comm=A.base_comm)
        self.args = (A, p)

    def _power(self, fun: Callable, x: DistributedArray) -> DistributedArray:
        """Apply ``fun`` repeatedly to a copy of ``x``, ``self.args[1]`` times."""
        exponent = self.args[1]
        acc = x.copy()
        for _step in range(exponent):
            # Write back into the working copy's local buffer at each step
            acc[:] = fun(acc).local_array
        return acc

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        base_op = self.args[0]
        return self._power(base_op.matvec, x)

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        base_op = self.args[0]
        return self._power(base_op.rmatvec, x)
class _ConjLinearOperator(MPILinearOperator):
    """Complex conjugate MPI Linear Operator

    Represents ``conj(A)``. The input is conjugated, ``A`` (or its adjoint)
    is applied, and the result is conjugated back. It uses the same base
    communicator as ``A``.
    """

    def __init__(self, A: MPILinearOperator):
        if not isinstance(A, MPILinearOperator):
            raise TypeError('A must be a MPILinearOperator')
        self.A = A
        self.args = (A,)
        # Inherit A's communicator instead of forcing COMM_WORLD
        super().__init__(shape=A.shape, dtype=A.dtype, base_comm=A.base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        x = x.conj()
        y = self.A.matvec(x)
        if y is not None:
            y = y.conj()
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        x = x.conj()
        y = self.A.rmatvec(x)
        if y is not None:
            y = y.conj()
        return y

    def _adjoint(self) -> MPILinearOperator:
        AH = self.A.H
        Op = _ConjLinearOperator(AH)
        # Carry over the (already swapped) dims/dimsd of A.H so the adjoint
        # does not fall back to 1-d (shape-derived) dims
        AH._copy_attributes(Op)
        return Op
def asmpilinearoperator(Op, base_comm: MPI.Comm = MPI.COMM_WORLD):
    """Return Op as a MPI LinearOperator.

    Converts a :class:`pylops.LinearOperator` to a
    :class:`pylops_mpi.MPILinearOperator`. If ``Op`` already is a
    :class:`pylops_mpi.MPILinearOperator`, it is returned unchanged (and
    ``base_comm`` is ignored).

    Parameters
    ----------
    Op : :obj:`pylops.LinearOperator`
        PyLops LinearOperator
    base_comm : :obj:`mpi4py.MPI.Comm`, optional
        MPI Base Communicator of the wrapping operator. Defaults to
        ``mpi4py.MPI.COMM_WORLD``.

    Returns
    -------
    Op : :obj:`pylops_mpi.MPILinearOperator`
        Operator of type :obj:`pylops_mpi.MPILinearOperator`

    """
    if isinstance(Op, MPILinearOperator):
        return Op
    return MPILinearOperator(Op=Op, base_comm=base_comm)