# Source code for pylops_mpi.LinearOperator

from __future__ import annotations

import numpy as np
from mpi4py import MPI
from typing import Callable, Optional

from scipy.sparse._sputils import isintlike, isshape
from scipy.sparse.linalg._interface import _get_dtype

from pylops import LinearOperator
from pylops.utils import DTypeLike, ShapeLike

from pylops_mpi import DistributedArray


class MPILinearOperator:
    """MPI-enabled PyLops Linear Operator

    Common interface for performing matrix-vector products in distributed fashion.

    In practice, this class provides methods to perform matrix-vector and
    adjoint matrix-vector products between any :obj:`pylops.LinearOperator`
    (which must be the same across ranks) and a :class:`pylops_mpi.DistributedArray`
    with ``Partition.BROADCAST`` and ``Partition.UNSAFE_BROADCAST`` partition.
    It internally handles the extraction of the local array from the distributed
    array and the creation of the output :class:`pylops_mpi.DistributedArray`.

    Note that whilst this operator could also be used with different
    :obj:`pylops.LinearOperator` across ranks, and with a
    :class:`pylops_mpi.DistributedArray` with ``Partition.SCATTER``, it is however
    recommended to use the :class:`pylops_mpi.basicoperators.MPIBlockDiag`
    operator instead as this can also handle distributed arrays with
    subcommunicators.

    Parameters
    ----------
    Op : :obj:`pylops.LinearOperator`, optional
        PyLops Linear Operator to wrap. If other arguments are provided,
        they will overwrite those obtained from ``Op``. Defaults to ``None``.
    shape : :obj:`tuple(int, int)`, optional
        Shape of the MPI Linear Operator. If not provided, obtained from
        ``dims`` and ``dimsd``.
    dims : :obj:`tuple(int, ..., int)`, optional
        Dimensions of model. If not provided, ``(self.shape[1],)`` is used.
    dimsd : :obj:`tuple(int, ..., int)`, optional
        Dimensions of data. If not provided, ``(self.shape[0],)`` is used.
    dtype : :obj:`str`, optional
        Type of elements in input array. Defaults to ``None``.
    base_comm : :obj:`mpi4py.MPI.Comm`, optional
        MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.

    """

    def __init__(
        self,
        Op: Optional[LinearOperator] = None,
        shape: Optional[ShapeLike] = None,
        dims: Optional[ShapeLike] = None,
        dimsd: Optional[ShapeLike] = None,
        dtype: Optional[DTypeLike] = None,
        base_comm: MPI.Comm = MPI.COMM_WORLD,
    ):
        if Op is not None:
            self.Op = Op
            # Explicit arguments take precedence over the wrapped operator
            dtype = Op.dtype if dtype is None else dtype
            shape = Op.shape if shape is None else shape
            # Optional arguments
            dims = getattr(Op, "dims", (Op.shape[1],)) if dims is None else dims
            dimsd = getattr(Op, "dimsd", (Op.shape[0],)) if dimsd is None else dimsd
        # shape is set first so that the dims/dimsd setters can validate against it
        if shape is not None:
            self.shape = shape
        if dims is not None:
            self.dims = dims
        if dimsd is not None:
            self.dimsd = dimsd
        if dtype is not None:
            self.dtype = dtype
        # For MPI
        self.base_comm = base_comm
        self.size = base_comm.Get_size()
        self.rank = base_comm.Get_rank()

    @property
    def shape(self):
        """Shape ``(M, N)`` of the operator, derived from ``dims``/``dimsd`` if unset.

        Raises
        ------
        AttributeError
            If neither ``shape`` nor both ``dims`` and ``dimsd`` are available.
        """
        _shape = getattr(self, "_shape", None)
        if _shape is None:
            # Cannot find shape, falling back on dims and dimsd
            dims = getattr(self, "_dims", None)
            dimsd = getattr(self, "_dimsd", None)
            if dims is None or dimsd is None:
                # Cannot find both dims and dimsd, error
                msg = (
                    f"'{self.__class__.__name__}' object has no attribute 'shape' "
                    "nor both fallback attributes ('dims', 'dimsd')"
                )
                raise AttributeError(msg)
            _shape = (int(np.prod(dimsd)), int(np.prod(dims)))
            self._shape = _shape  # Update to not redo everything above on next call
        return _shape

    @shape.setter
    def shape(self, new_shape: ShapeLike) -> None:
        """Set the shape, validating it against ``dims``/``dimsd`` when both exist.

        Raises
        ------
        ValueError
            If ``new_shape`` is not a 2-d shape or is incompatible with the
            already-set ``dims``/``dimsd``.
        """
        new_shape = tuple(new_shape)
        if not isshape(new_shape):
            msg = f"Invalid shape; must be 2-d tuple of integers, got {new_shape}"
            raise ValueError(msg)
        dims = getattr(self, "_dims", None)
        dimsd = getattr(self, "_dimsd", None)
        if dims is not None and dimsd is not None:  # Found dims and dimsd
            if np.prod(dimsd) != new_shape[0] and np.prod(dims) != new_shape[1]:
                msg = "New shape incompatible with dims and dimsd"
                raise ValueError(msg)
            elif np.prod(dimsd) != new_shape[0]:
                msg = "New shape incompatible with dimsd"
                raise ValueError(msg)
            elif np.prod(dims) != new_shape[1]:
                msg = "New shape incompatible with dims"
                raise ValueError(msg)
        self._shape = new_shape

    @property
    def dims(self):
        """Model dimensions; falls back to ``(shape[1],)`` when not set explicitly."""
        _dims = getattr(self, "_dims", None)
        if _dims is None:
            shape = getattr(self, "_shape", None)
            if shape is None:
                msg = (
                    f"'{self.__class__.__name__}' object has no "
                    "attributes 'dims' or 'shape'"
                )
                raise AttributeError(msg)
            _dims = (shape[1],)
        return _dims

    @dims.setter
    def dims(self, new_dims: ShapeLike) -> None:
        """Set model dimensions; their product must equal ``shape[1]`` if shape is set."""
        new_dims = tuple(new_dims)
        shape = getattr(self, "_shape", None)
        if shape is None:  # shape not set yet
            self._dims = new_dims
        else:
            if np.prod(new_dims) == self.shape[1]:
                self._dims = new_dims
            else:
                msg = "dims incompatible with shape[1]"
                raise ValueError(msg)

    @property
    def dimsd(self):
        """Data dimensions; falls back to ``(shape[0],)`` when not set explicitly."""
        _dimsd = getattr(self, "_dimsd", None)
        if _dimsd is None:
            shape = getattr(self, "_shape", None)
            if shape is None:
                msg = (
                    f"'{self.__class__.__name__}' object has "
                    "no attributes 'dimsd' or 'shape'"
                )
                raise AttributeError(msg)
            _dimsd = (shape[0],)
        return _dimsd

    @dimsd.setter
    def dimsd(self, new_dimsd: ShapeLike) -> None:
        """Set data dimensions; their product must equal ``shape[0]`` if shape is set."""
        new_dimsd = tuple(new_dimsd)
        shape = getattr(self, "_shape", None)
        if shape is None:  # shape not set yet
            self._dimsd = new_dimsd
        else:
            if np.prod(new_dimsd) == self.shape[0]:
                self._dimsd = new_dimsd
            else:
                msg = "dimsd incompatible with shape[0]"
                raise ValueError(msg)

    def matvec(self, x: DistributedArray) -> DistributedArray:
        """Matrix-vector multiplication.

        Modified version of pylops matvec. This method makes use of
        :class:`pylops_mpi.DistributedArray` to calculate matrix vector
        multiplication in a distributed fashion.

        Parameters
        ----------
        x : :obj:`pylops_mpi.DistributedArray`
            A DistributedArray of global shape (N, ).

        Returns
        -------
        y : :obj:`pylops_mpi.DistributedArray`
            DistributedArray of global shape (M, )

        Raises
        ------
        ValueError
            If ``x.global_shape`` is not ``(N,)``.

        """
        M, N = self.shape
        if x.global_shape != (N,):
            raise ValueError("dimension mismatch")
        return self._matvec(x)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        # Apply the wrapped pylops operator to the local portion of x. When no
        # operator was supplied, None is returned; the composite helper
        # operators explicitly check for that value.
        Op = getattr(self, "Op", None)
        if Op is not None:
            y = DistributedArray(global_shape=self.shape[0],
                                 base_comm=self.base_comm,
                                 base_comm_nccl=x.base_comm_nccl,
                                 partition=x.partition,
                                 axis=x.axis,
                                 engine=x.engine,
                                 dtype=self.dtype)
            y[:] = Op._matvec(x.local_array)
            return y
        return None

    def rmatvec(self, x: DistributedArray) -> DistributedArray:
        """Adjoint Matrix-vector multiplication.

        Modified version of pylops rmatvec. This method makes use of
        :class:`pylops_mpi.DistributedArray` to calculate adjoint matrix vector
        multiplication in a distributed fashion.

        Parameters
        ----------
        x : :obj:`pylops_mpi.DistributedArray`
            A DistributedArray of global shape (M, ).

        Returns
        -------
        y : :obj:`pylops_mpi.DistributedArray`
            DistributedArray of global shape (N, )

        Raises
        ------
        ValueError
            If ``x.global_shape`` is not ``(M,)``.

        """
        M, N = self.shape
        if x.global_shape != (M,):
            raise ValueError("dimension mismatch")
        return self._rmatvec(x)

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        # Adjoint counterpart of _matvec: returns None when no operator is wrapped.
        Op = getattr(self, "Op", None)
        if Op is not None:
            y = DistributedArray(global_shape=self.shape[1],
                                 base_comm=self.base_comm,
                                 base_comm_nccl=x.base_comm_nccl,
                                 partition=x.partition,
                                 axis=x.axis,
                                 engine=x.engine,
                                 dtype=self.dtype)
            y[:] = Op._rmatvec(x.local_array)
            return y
        return None

    def dot(self, x):
        """Matrix Vector Multiplication

        Parameters
        ----------
        x : :obj:`pylops_mpi.DistributedArray` or :obj:`pylops_mpi.MPILinearOperator`
            DistributedArray or a MPILinearOperator.

        Returns
        -------
        y : :obj:`pylops_mpi.DistributedArray` or :obj:`pylops_mpi.MPILinearOperator`
            DistributedArray (for an array input) or a MPILinearOperator
            (for an operator or scalar input).

        Raises
        ------
        ValueError
            If ``x`` is an array that is not 1-dimensional.

        """
        if isinstance(x, MPILinearOperator):
            Op = _ProductLinearOperator(self, x)
            # The product maps x's model space onto self's data space
            self._copy_attributes(Op, exclude=['dims'])
            Op.dims = x.dims
            return Op
        elif np.isscalar(x):
            Op = _ScaledLinearOperator(self, x)
            self._copy_attributes(Op)
            return Op
        else:
            if x is None or x.ndim == 1:
                return self.matvec(x)
            else:
                # global_shape is a tuple: wrap it so %-formatting sees a
                # single argument (otherwise a 2-d shape raises TypeError)
                raise ValueError('expected 1-d DistributedArray, got %r'
                                 % (x.global_shape,))

    def adjoint(self):
        """Adjoint MPI LinearOperator

        Returns
        -------
        op : :obj:`pylops_mpi.MPILinearOperator`
            Adjoint of Operator

        """
        return self._adjoint()

    H = property(adjoint)

    def transpose(self):
        """Transposition of MPI LinearOperator

        Returns
        -------
        op : :obj:`pylops_mpi.MPILinearOperator`
            Transpose Linear Operator

        """
        return self._transpose()

    T = property(transpose)

    def __mul__(self, x):
        return self.dot(x)

    def __rmul__(self, x):
        # Only scalar * Op is supported from the right-hand side
        if np.isscalar(x):
            Op = _ScaledLinearOperator(self, x)
            self._copy_attributes(Op)
            return Op
        else:
            return NotImplemented

    def __matmul__(self, x):
        if np.isscalar(x):
            raise ValueError("Scalar not allowed, use * instead")
        return self.__mul__(x)

    def __rmatmul__(self, x):
        if np.isscalar(x):
            raise ValueError("Scalar not allowed, use * instead")
        return self.__rmul__(x)

    def __pow__(self, p):
        Op = _PowerLinearOperator(self, p)
        self._copy_attributes(Op)
        return Op

    def __add__(self, x):
        Op = _SumLinearOperator(self, x)
        self._copy_attributes(Op)
        return Op

    def __neg__(self):
        Op = _ScaledLinearOperator(self, -1)
        self._copy_attributes(Op)
        return Op

    def __sub__(self, x):
        return self.__add__(-x)

    def _adjoint(self):
        Op = _AdjointLinearOperator(self)
        # The adjoint swaps model and data spaces
        self._copy_attributes(Op, exclude=['dims', 'dimsd'])
        Op.dims = self.dimsd
        Op.dimsd = self.dims
        return Op

    def _transpose(self):
        Op = _TransposedLinearOperator(self)
        # The transpose swaps model and data spaces
        self._copy_attributes(Op, exclude=['dims', 'dimsd'])
        Op.dims = self.dimsd
        Op.dimsd = self.dims
        return Op

    def conj(self):
        """Complex conjugate operator

        Returns
        -------
        conjop : :obj:`pylops_mpi.MPILinearOperator`
            Complex conjugate operator

        """
        return _ConjLinearOperator(self)

    def _copy_attributes(
        self,
        dest: MPILinearOperator,
        exclude: list[str] | None = None,
    ) -> None:
        """Copy ``dims``/``dimsd`` (minus ``exclude``) from this operator to ``dest``.

        Attributes that are not resolvable on ``self`` (their property raises
        ``AttributeError``) are skipped.
        """
        attrs = ["dims", "dimsd"]
        if exclude is not None:
            for item in exclude:
                attrs.remove(item)
        for attr in attrs:
            if hasattr(self, attr):
                setattr(dest, attr, getattr(self, attr))

    def __repr__(self):
        M, N = self.shape
        # dtype is only set in __init__ when one was provided, so read it
        # defensively to make the "unspecified dtype" branch reachable
        dtype = getattr(self, "dtype", None)
        if dtype is None:
            dt = "unspecified dtype"
        else:
            dt = f"dtype={dtype}"
        return f"<{M}x{N} {self.__class__.__name__} with {dt}>"
class _AdjointLinearOperator(MPILinearOperator):
    """Adjoint of MPI Linear Operator

    ``matvec`` and ``rmatvec`` of the wrapped operator ``A`` are swapped.
    The communicator of ``A`` is propagated to the adjoint.
    """

    def __init__(self, A: MPILinearOperator):
        self.A = A
        self.args = (A,)
        super().__init__(shape=(A.shape[1], A.shape[0]), dtype=A.dtype,
                         base_comm=A.base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        return self.A.rmatvec(x)

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        return self.A.matvec(x)


class _TransposedLinearOperator(MPILinearOperator):
    """Transposition of MPI Linear Operator

    Uses ``A^T x = conj(A^H conj(x))`` and ``(A^T)^H y = conj(A conj(y))``.
    """

    def __init__(self, A: MPILinearOperator):
        self.A = A
        self.args = (A,)
        super().__init__(shape=(A.shape[1], A.shape[0]), dtype=A.dtype,
                         base_comm=A.base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        x = x.conj()
        y = self.A.rmatvec(x)
        y = y.conj()
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        x = x.conj()
        y = self.A.matvec(x)
        y = y.conj()
        return y


class _ProductLinearOperator(MPILinearOperator):
    """Product ``A @ B`` of MPI LinearOperators

    Raises
    ------
    ValueError
        If either operand is not an MPILinearOperator or the inner
        dimensions do not match.
    """

    def __init__(self, A: MPILinearOperator, B: MPILinearOperator):
        if not isinstance(A, MPILinearOperator) or not isinstance(B, MPILinearOperator):
            raise ValueError('both operands have to be a LinearOperator')
        if A.shape[1] != B.shape[0]:
            raise ValueError('cannot multiply %r and %r: shape mismatch' % (A, B))
        self.args = (A, B)
        # The product acts in A's data space, so it reports A's communicator
        super().__init__(shape=(A.shape[0], B.shape[1]),
                         dtype=_get_dtype([A, B]),
                         base_comm=A.base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        return self.args[0].matvec(self.args[1].matvec(x))

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        return self.args[1].rmatvec(self.args[0].rmatvec(x))

    def _adjoint(self) -> MPILinearOperator:
        # (A B)^H = B^H A^H
        A, B = self.args
        return B.H * A.H


class _ScaledLinearOperator(MPILinearOperator):
    """Scaled MPI Linear Operator ``alpha * A``

    Raises
    ------
    ValueError
        If ``A`` is not an MPILinearOperator or ``alpha`` is not a scalar.
    """

    def __init__(self, A: MPILinearOperator, alpha):
        if not isinstance(A, MPILinearOperator):
            raise ValueError('MPILinearOperator expected as A')
        if not np.isscalar(alpha):
            raise ValueError('scalar expected as alpha')
        self.args = (A, alpha)
        super().__init__(shape=A.shape, dtype=_get_dtype([A], [type(alpha)]),
                         base_comm=A.base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        y = self.args[0].matvec(x)
        # y is None when A wraps no pylops operator
        if y is not None:
            # NOTE(review): in-place scaling keeps y's dtype; a complex alpha
            # on a real-valued y may fail numpy's casting rules -- confirm.
            y[:] *= self.args[1]
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        y = self.args[0].rmatvec(x)
        if y is not None:
            y[:] *= np.conj(self.args[1])
        return y

    def _adjoint(self) -> MPILinearOperator:
        # (alpha A)^H = conj(alpha) A^H
        A, alpha = self.args
        return A.H * np.conj(alpha)


class _SumLinearOperator(MPILinearOperator):
    """Sum ``A + B`` of MPI LinearOperators

    Raises
    ------
    ValueError
        If either operand is not an MPILinearOperator or the shapes differ.
    """

    def __init__(self, A: MPILinearOperator, B: MPILinearOperator):
        if not isinstance(A, MPILinearOperator) or not isinstance(B, MPILinearOperator):
            raise ValueError('both operands have to be a MPILinearOperator')
        # Make sure it works with different kinds
        if A.shape != B.shape:
            raise ValueError("cannot add %r and %r: shape mismatch" % (A, B))
        self.args = (A, B)
        # Promote dtypes like _ProductLinearOperator does (e.g. real + complex
        # is complex) instead of inheriting A's dtype only
        super().__init__(shape=A.shape, dtype=_get_dtype([A, B]),
                         base_comm=A.base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        arr1 = self.args[0].matvec(x)
        arr2 = self.args[1].matvec(x)
        return arr1 + arr2

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        arr1 = self.args[0].rmatvec(x)
        arr2 = self.args[1].rmatvec(x)
        return arr1 + arr2

    def _adjoint(self) -> MPILinearOperator:
        A, B = self.args
        return A.H + B.H


class _PowerLinearOperator(MPILinearOperator):
    """Power ``A ** p`` of a square MPI Linear Operator (``p >= 0``)

    Raises
    ------
    ValueError
        If ``A`` is not a square MPILinearOperator or ``p`` is not a
        non-negative integer.
    """

    def __init__(self, A: MPILinearOperator, p: int) -> None:
        if not isinstance(A, MPILinearOperator):
            raise ValueError("LinearOperator expected as A")
        if A.shape[0] != A.shape[1]:
            raise ValueError("square LinearOperator expected, got %r" % A)
        if not isintlike(p) or p < 0:
            raise ValueError("non-negative integer expected as p")
        super().__init__(shape=A.shape, dtype=A.dtype, base_comm=A.base_comm)
        self.args = (A, p)

    def _power(self, fun: Callable, x: DistributedArray) -> DistributedArray:
        # Apply ``fun`` p times to a copy of x; p == 0 yields a copy (identity)
        res = x.copy()
        for _ in range(self.args[1]):
            res[:] = fun(res).local_array
        return res

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        return self._power(self.args[0].matvec, x)

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        return self._power(self.args[0].rmatvec, x)


class _ConjLinearOperator(MPILinearOperator):
    """Complex conjugate MPI Linear Operator: ``conj(A) x = conj(A conj(x))``

    Raises
    ------
    TypeError
        If ``A`` is not an MPILinearOperator.
    """

    def __init__(self, A: MPILinearOperator):
        if not isinstance(A, MPILinearOperator):
            raise TypeError('A must be a MPILinearOperator')
        self.A = A
        super().__init__(shape=A.shape, dtype=A.dtype, base_comm=A.base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        x = x.conj()
        y = self.A.matvec(x)
        if y is not None:
            y = y.conj()
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        x = x.conj()
        y = self.A.rmatvec(x)
        if y is not None:
            y = y.conj()
        return y

    def _adjoint(self) -> MPILinearOperator:
        return _ConjLinearOperator(self.A.H)
def asmpilinearoperator(Op):
    """Return ``Op`` as an MPI LinearOperator.

    Wraps a :class:`pylops.LinearOperator` into a
    :class:`pylops_mpi.MPILinearOperator` on ``MPI.COMM_WORLD``. An object
    that is already an MPILinearOperator is returned unchanged.

    Parameters
    ----------
    Op : :obj:`pylops.LinearOperator`
        PyLops LinearOperator

    Returns
    -------
    Op : :obj:`pylops_mpi.MPILinearOperator`
        Operator of type :obj:`pylops_mpi.MPILinearOperator`

    """
    # Nothing to do if the operator is already MPI-aware
    if not isinstance(Op, MPILinearOperator):
        Op = MPILinearOperator(Op=Op, base_comm=MPI.COMM_WORLD)
    return Op