from __future__ import annotations
from typing import Optional, Union, Callable
from abc import abstractmethod, ABC
import numpy as np
from mpi4py import MPI
from pylops.utils import ShapeLike, DTypeLike
from scipy.sparse._sputils import isintlike, isshape
from scipy.sparse.linalg._interface import _get_dtype
from pylops_mpi.DistributedArray import DistributedArray, StackedDistributedArray
[docs]
class MPIStackedLinearOperator(ABC):
"""Stack of MPI-enabled PyLops Linear Operators
Common interface for performing matrix-vector products in distributed fashion
for a stack of :class:`pylops_mpi.MPILinearOperator` operators.
In practice, this class provides methods to perform matrix-vector and adjoint
matrix-vector products on a stack of :class:`pylops_mpi.MPILinearOperator`
operators, allowing the actual execution of each operator to be distributed,
whilst dispatching the execution of the different operators in sequential order.
.. note:: End users of pylops-mpi should not use this class directly but simply
use operators that are already implemented as extensions of this class.
This class is meant for developers only, it has to be used as the parent
class of any new stacked operator developed within pylops-mpi.
Parameters
----------
shape : :obj:`tuple(int, int)`, optional
Shape of the MPIStackedLinearOperator. If not provided, obtained from ``dims`` and ``dimsd``.
dims : :obj:`tuple(int, ..., int)`, optional
Dimensions of model. If not provided, ``(self.shape[1],)`` is used.
dimsd : :obj:`tuple(int, ..., int)`, optional
Dimensions of data. If not provided, ``(self.shape[0],)`` is used.
dtype : :obj:`str`, optional
Type of elements in input array. Defaults to ``None``.
base_comm : :obj:`mpi4py.MPI.Comm`, optional
MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
"""
def __init__(
self,
shape: Optional[ShapeLike] = None,
dims: Optional[ShapeLike] = None,
dimsd: Optional[ShapeLike] = None,
dtype: Optional[DTypeLike] = None,
base_comm: MPI.Comm = MPI.COMM_WORLD
):
# Infer shape from dims/dimsd if not provided
if shape is not None:
self.shape = shape
if dims is not None:
self.dims = dims
if dimsd is not None:
self.dimsd = dimsd
if dtype is not None:
self.dtype = dtype
# For MPI
self.base_comm = base_comm
self.size = base_comm.Get_size()
self.rank = base_comm.Get_rank()
@property
def shape(self):
_shape = getattr(self, "_shape", None)
if _shape is None: # Cannot find shape, falling back on dims and dimsd
dims = getattr(self, "_dims", None)
dimsd = getattr(self, "_dimsd", None)
if dims is None or dimsd is None: # Cannot find both dims and dimsd, error
msg = (
f"'{self.__class__.__name__}' object has no attribute 'shape' "
"nor both fallback attributes ('dims', 'dimsd')"
)
raise AttributeError(msg)
_shape = (int(np.prod(dimsd)), int(np.prod(dims)))
self._shape = _shape # Update to not redo everything above on next call
return _shape
@shape.setter
def shape(self, new_shape: ShapeLike) -> None:
new_shape = tuple(new_shape)
if not isshape(new_shape):
msg = f"Invalid shape; must be 2-d tuple of integers, got {new_shape}"
raise ValueError(msg)
dims = getattr(self, "_dims", None)
dimsd = getattr(self, "_dimsd", None)
if dims is not None and dimsd is not None: # Found dims and dimsd
if np.prod(dimsd) != new_shape[0] and np.prod(dims) != new_shape[1]:
msg = "New shape incompatible with dims and dimsd"
raise ValueError(msg)
elif np.prod(dimsd) != new_shape[0]:
msg = "New shape incompatible with dimsd"
raise ValueError(msg)
elif np.prod(dims) != new_shape[1]:
msg = "New shape incompatible with dims"
raise ValueError(msg)
self._shape = new_shape
@property
def dims(self):
_dims = getattr(self, "_dims", None)
if _dims is None:
shape = getattr(self, "_shape", None)
if shape is None:
msg = (
f"'{self.__class__.__name__}' object has no "
"attributes 'dims' or 'shape'"
)
raise AttributeError(msg)
_dims = (shape[1],)
return _dims
@dims.setter
def dims(self, new_dims: ShapeLike) -> None:
new_dims = tuple(new_dims)
shape = getattr(self, "_shape", None)
if shape is None: # shape not set yet
self._dims = new_dims
else:
if np.prod(new_dims) == self.shape[1]:
self._dims = new_dims
else:
msg = "dims incompatible with shape[1]"
raise ValueError(msg)
@property
def dimsd(self):
_dimsd = getattr(self, "_dimsd", None)
if _dimsd is None:
shape = getattr(self, "_shape", None)
if shape is None:
msg = (
f"'{self.__class__.__name__}' object has "
"no attributes 'dimsd' or 'shape'"
)
raise AttributeError(msg)
_dimsd = (shape[0],)
return _dimsd
@dimsd.setter
def dimsd(self, new_dimsd: ShapeLike) -> None:
new_dimsd = tuple(new_dimsd)
shape = getattr(self, "_shape", None)
if shape is None: # shape not set yet
self._dimsd = new_dimsd
else:
if np.prod(new_dimsd) == self.shape[0]:
self._dimsd = new_dimsd
else:
msg = "dimsd incompatible with shape[0]"
raise ValueError(msg)
def matvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
"""Matrix-vector multiplication.
Modified version of :class:`pylops_mpi.MPILinearOperator` matvec.
This method makes use of either :class:`pylops_mpi.DistributedArray` or
:class:`pylops_mpi.StackedDistributedArray` to calculate matrix vector multiplication
in a distributed fashion.
Parameters
----------
x : :obj:`pylops_mpi.DistributedArray` or :obj:`pylops_mpi.DistributedArray`
A StackedDistributedArray or a DistributedArray of global shape (N, )
Returns
-------
y : :obj:`pylops_mpi.DistributedArray` or :obj:`pylops_mpi.DistributedArray`
A StackedDistributedArray or a DistributedArray of global shape (M, )
"""
M, N = self.shape
if isinstance(x, StackedDistributedArray):
stacked_shape = (np.sum([a.global_shape for a in x.distarrays]), )
if stacked_shape != (N, ):
raise ValueError("dimension mismatch")
if isinstance(x, DistributedArray) and x.global_shape != (N,):
raise ValueError("dimension mismatch")
return self._matvec(x)
@abstractmethod
def _matvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
pass
def rmatvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
"""Adjoint Matrix-vector multiplication.
Modified version of :class:`pylops_mpi.MPILinearOperator` rmatvec
This method makes use of either :class:`pylops_mpi.DistributedArray` or
:class:`pylops_mpi.StackedDistributedArray` to calculate adjoint matrix vector multiplication
in a distributed fashion.
Parameters
----------
x : :obj:`pylops_mpi.DistributedArray` or :obj:`pylops_mpi.DistributedArray`
A StackedDistributedArray or a DistributedArray of global shape (M, )
Returns
-------
y : :obj:`pylops_mpi.DistributedArray` or :obj:`pylops_mpi.DistributedArray`
A StackedDistributedArray or a DistributedArray of global shape (N, )
"""
M, N = self.shape
if isinstance(x, StackedDistributedArray):
stacked_shape = (np.sum([a.global_shape for a in x.distarrays]), )
if stacked_shape != (M, ):
raise ValueError("dimension mismatch")
if isinstance(x, DistributedArray) and x.global_shape != (M,):
raise ValueError("dimension mismatch")
return self._rmatvec(x)
@abstractmethod
def _rmatvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
pass
def dot(self, x):
"""Matrix Vector Multiplication
Parameters
----------
x : :obj:`pylops_mpi.DistributedArray` or :obj:`pylops_mpi.StackedDistributedArray` or
:obj:`pylops_mpi.StackedMPILinearOperator
StackedDistributedArray, DistributedArray or StackedMPILinearOperator.
Returns
-------
y : :obj:`pylops_mpi.DistributedArray` or :obj:`pylops_mpi.StackedDistributedArray` or
:obj:`pylops_mpi.StackedMPILinearOperator
StackedDistributedArray, DistributedArray or a StackedMPILinearOperator.
"""
if isinstance(x, MPIStackedLinearOperator):
Op = _ProductStackedLinearOperator(self, x)
self._copy_attributes(
Op,
exclude=['dims']
)
Op.dims = x.dims
return Op
elif np.isscalar(x):
Op = _ScaledStackedLinearOperator(self, x)
self._copy_attributes(
Op
)
return Op
else:
if x is None or (isinstance(x, DistributedArray) and x.ndim == 1):
return self.matvec(x)
elif isinstance(x, StackedDistributedArray):
ndims = np.unique([dis.ndim for dis in x.distarrays])
if len(ndims) == 1 and ndims[0] == 1:
return self.matvec(x)
else:
raise ValueError('expected 1-d DistributedArray or StackedDistributedArray')
def adjoint(self):
"""Adjoint MPIStackedLinearOperator
Returns
-------
op : :obj:`pylops_mpi.MPIStackedLinearOperator`
Adjoint of Operator
"""
return self._adjoint()
H = property(adjoint)
def transpose(self):
"""Transposition of MPIStackedLinearOperator
Returns
-------
op : :obj:`pylops_mpi.MPIStackedLinearOperator`
Transpose MPIStackedLinearOperator
"""
return self._transpose()
T = property(transpose)
def __mul__(self, x):
return self.dot(x)
def __rmul__(self, x):
if np.isscalar(x):
Op = _ScaledStackedLinearOperator(self, x)
self._copy_attributes(
Op
)
return Op
else:
return NotImplemented
def __matmul__(self, x):
if np.isscalar(x):
raise ValueError("Scalar not allowed, use * instead")
return self.__mul__(x)
def __rmatmul__(self, x):
if np.isscalar(x):
raise ValueError("Scalar not allowed, use * instead")
return self.__rmul__(x)
def __pow__(self, p):
Op = _PowerLinearOperator(self, p)
self._copy_attributes(
Op
)
return Op
def __add__(self, x):
Op = _SumStackedLinearOperator(self, x)
self._copy_attributes(
Op
)
return Op
def __neg__(self):
Op = _ScaledStackedLinearOperator(self, -1)
self._copy_attributes(
Op
)
return Op
def __sub__(self, x):
return self.__add__(-x)
def _adjoint(self):
Op = _AdjointStackedLinearOperator(self)
self._copy_attributes(
Op,
exclude=['dims', 'dimsd']
)
Op.dims = self.dimsd
Op.dimsd = self.dims
return Op
def _transpose(self):
Op = _TransposedStackedLinearOperator(self)
self._copy_attributes(
Op,
exclude=['dims', 'dimsd']
)
Op.dims = self.dimsd
Op.dimsd = self.dims
return Op
def conj(self):
"""Complex conjugate operator
Returns
-------
conjop : :obj:`pylops_mpi.MPIStackedLinearOperator`
Complex conjugate operator
"""
return _ConjLinearOperator(self)
def __repr__(self):
M, N = self.shape
if self.dtype is None:
dt = "unspecified dtype"
else:
dt = f"dtype={self.dtype}"
return f"<{M}x{N} {self.__class__.__name__} with {dt}>"
def _copy_attributes(
self,
dest: MPIStackedLinearOperator,
exclude: list[str] | None = None,
) -> None:
"""Copy attributes from one MPIStackedLinearOperator to another"""
attrs = ["dims", "dimsd"]
if exclude is not None:
for item in exclude:
attrs.remove(item)
for attr in attrs:
if hasattr(self, attr):
setattr(dest, attr, getattr(self, attr))
class _AdjointStackedLinearOperator(MPIStackedLinearOperator):
"""Adjoint of MPIStackedLinearOperator"""
def __init__(self, A: MPIStackedLinearOperator):
self.A = A
self.args = (A,)
super().__init__(shape=(A.shape[1], A.shape[0]), dtype=A.dtype,
base_comm=MPI.COMM_WORLD)
def _matvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
return self.A.rmatvec(x)
def _rmatvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
return self.A.matvec(x)
class _TransposedStackedLinearOperator(MPIStackedLinearOperator):
"""Transpose of MPIStackedLinearOperator"""
def __init__(self, A: MPIStackedLinearOperator):
self.A = A
self.args = (A,)
super().__init__(shape=(A.shape[1], A.shape[0]), dtype=A.dtype,
base_comm=MPI.COMM_WORLD)
def _matvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
x = x.conj()
y = self.A.rmatvec(x)
y = y.conj()
return y
def _rmatvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
x = x.conj()
y = self.A.matvec(x)
y = y.conj()
return y
class _ProductStackedLinearOperator(MPIStackedLinearOperator):
"""Product of MPI Stacked Linear Operators"""
def __init__(self, A: MPIStackedLinearOperator, B: MPIStackedLinearOperator):
from pylops_mpi.basicoperators.VStack import MPIStackedVStack
from pylops_mpi.basicoperators.BlockDiag import MPIStackedBlockDiag
if not isinstance(A, MPIStackedLinearOperator) or not isinstance(B, MPIStackedLinearOperator):
raise ValueError('both operands have to be a MPIStackedLinearOperator')
if isinstance(A, MPIStackedVStack) and isinstance(B, MPIStackedVStack):
raise ValueError('both operands cannot be MPIStackedVStack')
if isinstance(A, MPIStackedBlockDiag) and isinstance(B, MPIStackedBlockDiag) and len(A.ops) != len(B.ops):
raise ValueError(f'both MPIStackedBlockDiag cannot have different number of ops, {A.ops} != {B.ops}')
if A.shape[1] != B.shape[0]:
raise ValueError('cannot multiply %r and %r: shape mismatch' % (A, B))
self.args = (A, B)
super().__init__(shape=(A.shape[0], B.shape[1]), dtype=_get_dtype([A, B]),
base_comm=MPI.COMM_WORLD)
def _matvec(self, x: Union[StackedDistributedArray, DistributedArray]) -> Union[StackedDistributedArray, DistributedArray]:
return self.args[0].matvec(self.args[1].matvec(x))
def _rmatvec(self, x: Union[StackedDistributedArray, DistributedArray]) -> Union[StackedDistributedArray, DistributedArray]:
return self.args[1].rmatvec(self.args[0].rmatvec(x))
def _adjoint(self) -> MPIStackedLinearOperator:
A, B = self.args
return B.H * A.H
class _ScaledStackedLinearOperator(MPIStackedLinearOperator):
"""Scaled MPI StackedLinearOperator
"""
def __init__(self, A: MPIStackedLinearOperator, alpha):
if not isinstance(A, MPIStackedLinearOperator):
raise ValueError('MPILinearOperator expected as A')
if not np.isscalar(alpha):
raise ValueError('scalar expected as alpha')
self.args = (A, alpha)
super().__init__(shape=A.shape, dtype=_get_dtype([A], [type(alpha)]),
base_comm=MPI.COMM_WORLD)
def _matvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
y = self.args[0].matvec(x)
if y is not None:
y *= self.args[1]
return y
def _rmatvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
y = self.args[0].rmatvec(x)
if y is not None:
y *= np.conj(self.args[1])
return y
def _adjoint(self) -> MPIStackedLinearOperator:
A, alpha = self.args
return A.H * np.conj(alpha)
class _SumStackedLinearOperator(MPIStackedLinearOperator):
"""Sum of MPI StackedLinearOperators
"""
def __init__(self, A: MPIStackedLinearOperator, B: MPIStackedLinearOperator):
if not isinstance(A, MPIStackedLinearOperator) or not isinstance(B, MPIStackedLinearOperator):
raise ValueError('both operands have to be a MPIStackedLinearOperator')
if type(A) != type(B): # noqa: E721
raise ValueError(f'both operands have to be of same type, {A} != {B}')
if A.shape != B.shape:
raise ValueError("cannot add %r and %r: shape mismatch" % (A, B))
self.args = (A, B)
super().__init__(shape=A.shape, dtype=A.dtype, base_comm=MPI.COMM_WORLD)
def _matvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
arr1 = self.args[0].matvec(x)
arr2 = self.args[1].matvec(x)
return arr1 + arr2
def _rmatvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
arr1 = self.args[0].rmatvec(x)
arr2 = self.args[1].rmatvec(x)
return arr1 + arr2
def _adjoint(self) -> MPIStackedLinearOperator:
A, B = self.args
return A.H + B.H
class _PowerLinearOperator(MPIStackedLinearOperator):
"""Power of MPI StackedLinearOperator
"""
def __init__(self, A: MPIStackedLinearOperator, p: int) -> None:
if not isinstance(A, MPIStackedLinearOperator):
raise ValueError("MPIStackedLinearOperator expected as A")
if A.shape[0] != A.shape[1]:
raise ValueError("square MPIStackedLinearOperator expected, got %r" % A)
if not isintlike(p) or p < 0:
raise ValueError("non-negative integer expected as p")
super(_PowerLinearOperator, self).__init__(shape=A.shape, dtype=A.dtype, base_comm=A.base_comm)
self.args = (A, p)
def _power(self, fun: Callable, x: Union[StackedDistributedArray, DistributedArray]) -> Union[StackedDistributedArray, DistributedArray]:
res = x.copy()
for _ in range(self.args[1]):
res[:] = fun(res)[:]
return res
def _matvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
return self._power(self.args[0].matvec, x)
def _rmatvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
return self._power(self.args[0].rmatvec, x)
class _ConjLinearOperator(MPIStackedLinearOperator):
"""Complex conjugate MPI StackedLinearOperator
"""
def __init__(self, A: MPIStackedLinearOperator):
if not isinstance(A, MPIStackedLinearOperator):
raise TypeError('A must be a MPIStackedLinearOperator')
self.A = A
super().__init__(shape=A.shape, dtype=A.dtype, base_comm=MPI.COMM_WORLD)
def _matvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
x = x.conj()
y = self.A.matvec(x)
if y is not None:
y = y.conj()
return y
def _rmatvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[DistributedArray, StackedDistributedArray]:
x = x.conj()
y = self.A.rmatvec(x)
if y is not None:
y = y.conj()
return y
def _adjoint(self) -> MPIStackedLinearOperator:
return _ConjLinearOperator(self.A.H)