Source code for pylops_mpi.utils.dottest

__all__ = ["dottest"]

from typing import Optional

import numpy as np

from pylops_mpi.DistributedArray import DistributedArray
from pylops.utils.backend import to_numpy


def dottest(
    Op,
    u: DistributedArray,
    v: DistributedArray,
    nr: Optional[int] = None,
    nc: Optional[int] = None,
    rtol: float = 1e-6,
    atol: float = 1e-21,
    raiseerror: bool = True,
    verb: bool = False,
) -> bool:
    r"""Dot test.

    Perform dot-test to verify the validity of forward and adjoint
    operators using user-provided random vectors :math:`\mathbf{u}`
    and :math:`\mathbf{v}` (whose partition must be consistent with
    the operator being tested). This test can help to detect errors
    in the operator implementation.

    Parameters
    ----------
    Op : :obj:`pylops_mpi.LinearOperator`
        Linear operator to test.
    u : :obj:`pylops_mpi.DistributedArray`
        Distributed array of size equal to the number of columns of operator
    v : :obj:`pylops_mpi.DistributedArray`
        Distributed array of size equal to the number of rows of operator
    nr : :obj:`int`, optional
        Number of rows of operator (i.e., elements in data).
        Defaults to ``Op.shape[0]``
    nc : :obj:`int`, optional
        Number of columns of operator (i.e., elements in model).
        Defaults to ``Op.shape[1]``
    rtol : :obj:`float`, optional
        Relative dottest tolerance
    atol : :obj:`float`, optional
        Absolute dottest tolerance

        .. versionadded:: 2.0.0

    raiseerror : :obj:`bool`, optional
        Raise error or simply return ``False`` when dottest fails
    verb : :obj:`bool`, optional
        Verbosity

    Returns
    -------
    passed : :obj:`bool`
        Passed flag.

    Raises
    ------
    AssertionError
        If dot-test is not verified within chosen tolerances.

    Notes
    -----
    A dot-test is a mathematical tool used in the development of numerical
    linear operators.

    More specifically, a correct implementation of forward and adjoint for
    a linear operator should verify the following *equality*
    within a numerical tolerance:

    .. math::
        (\mathbf{Op}\,\mathbf{u})^H\mathbf{v} =
        \mathbf{u}^H(\mathbf{Op}^H\mathbf{v})

    """
    if nr is None:
        nr = Op.shape[0]
    if nc is None:
        nc = Op.shape[1]

    if (nr, nc) != Op.shape:
        raise AssertionError("Provided nr and nc do not match operator shape")

    y = Op.matvec(u)  # Op * u
    x = Op.rmatvec(v)  # Op' * v

    yy = np.vdot(y.asarray(), v.asarray())  # (Op * u)' * v
    xx = np.vdot(u.asarray(), x.asarray())  # u' * (Op' * v)

    # convert back to numpy (in case cupy arrays were used), make into a numpy
    # array and extract the first element. This is ugly but allows to handle
    # complex numbers in subsequent prints also when using cupy arrays.
    xx, yy = np.array([to_numpy(xx)])[0], np.array([to_numpy(yy)])[0]

    # evaluate if dot test passed
    passed = np.isclose(xx, yy, rtol, atol)

    # verbosity or error raising
    if (not passed and raiseerror) or verb:
        passed_status = "passed" if passed else "failed"
        msg = f"Dot test {passed_status}, v^H(Opu)={yy} - u^H(Op^Hv)={xx}"
        if not passed and raiseerror:
            raise AssertionError(msg)
        else:
            print(msg)

    return passed
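The same check can be illustrated without MPI. The sketch below is a serial, NumPy-only analogue written for illustration; it is not part of the `pylops_mpi` API. The helper name `dense_dottest` is hypothetical, it assumes the operator is a dense matrix `A` with a candidate adjoint matrix `Aadj`, and it draws its own random complex vectors instead of taking `DistributedArray` inputs. As in `dottest`, it relies on `np.vdot` conjugating its first argument, so `np.vdot(y, v)` evaluates to (Op u)^H v.

```python
import numpy as np


def dense_dottest(A, Aadj, rtol=1e-6, atol=1e-21, seed=0):
    """Serial dot test for a dense operator (hypothetical helper).

    ``A`` is the forward matrix and ``Aadj`` is the matrix claimed to be
    its adjoint; both are applied to random complex vectors.
    """
    rng = np.random.default_rng(seed)
    nr, nc = A.shape
    # random complex model (size nc) and data (size nr) vectors
    u = rng.standard_normal(nc) + 1j * rng.standard_normal(nc)
    v = rng.standard_normal(nr) + 1j * rng.standard_normal(nr)

    y = A @ u     # forward: Op u
    x = Aadj @ v  # candidate adjoint: Op^H v

    # np.vdot conjugates its first argument, so these are
    # (Op u)^H v and u^H (Op^H v), respectively
    yy = np.vdot(y, v)
    xx = np.vdot(u, x)
    return bool(np.isclose(xx, yy, rtol=rtol, atol=atol))


rng = np.random.default_rng(42)
A = rng.standard_normal((4, 3)) + 1j * rng.standard_normal((4, 3))

# correct adjoint of a complex matrix: the conjugate transpose
print(dense_dottest(A, A.conj().T))
# common bug: transposing without conjugating, which should fail the test
print(dense_dottest(A, A.T))
```

The second call mimics a frequent mistake in hand-written adjoints for complex-valued operators: forgetting the complex conjugation. For real-valued operators, `A.T` would pass. Using complex random vectors therefore makes the test stricter.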