# Source code for pylops_mpi.signalprocessing.FFTND

import warnings
from typing import Sequence

from mpi4py import MPI
import numpy as np

from pylops.signalprocessing._baseffts import _FFTNorms
from pylops.utils import DTypeLike, InputDimsLike

from pylops_mpi.utils.decorators import reshaped
from pylops_mpi.DistributedArray import DistributedArray, Partition
from pylops_mpi.signalprocessing._baseffts import _MPIBaseFFTND
from pylops_mpi.utils import deps, fftshift_nd, ifftshift_nd

# Probe for the optional ``mpi4py_fft`` dependency through the project's
# ``deps`` helper and keep the outcome at module level.
# NOTE(review): the helper presumably returns ``None`` when the package can be
# imported and an explanatory message otherwise -- confirm against
# ``pylops_mpi.utils.deps``.
mpi4py_fft_message = deps.mpi4py_fft_import("mpi4py_fft")

# Only bring the mpi4py_fft names into scope when the probe reported no
# problem; otherwise ``PFFT``, ``newDistArray`` and ``Subcomm`` stay undefined.
if mpi4py_fft_message is None:
    from mpi4py_fft import PFFT, newDistArray
    from mpi4py_fft.pencil import Subcomm


class MPIFFTND(_MPIBaseFFTND):
    r"""N-dimensional Fast-Fourier Transform.

    Apply N-dimensional Fast-Fourier Transform (FFT) to any n ``axes``
    of a multidimensional array.

    When using ``real=True``, the result of the forward is also multiplied by
    :math:`\sqrt{2}` for all frequency bins except zero and Nyquist along the
    last ``axes``, and the input of the adjoint is multiplied by
    :math:`1 / \sqrt{2}` for the same frequencies.

    For a real valued input signal, it is advised to use the flag ``real=True``
    as it stores the values of the Fourier transform of the last axis in
    ``axes`` at positive frequencies only as values at negative frequencies are
    simply their complex conjugates.

    Parameters
    ----------
    dims : :obj:`tuple`
        Number of samples for each dimension
    axes : :obj:`tuple`, optional
        Axes (or axis) along which FFTND is applied
    sampling : :obj:`tuple` or :obj:`float`, optional
        Sampling steps for each direction. When supplied a single value, it is
        used for all directions.
    norm : `{"none", "1/n"}`, optional
        - "none": Does not scale the forward or the adjoint FFT transforms.
          Default is "none".
        - "1/n": Scales both the forward and adjoint FFT transforms by
          :math:`1/N_F`.
    real : :obj:`bool`, optional
        Model to which fft is applied has real numbers (``True``) or not
        (``False``). Used to enforce that the output of adjoint of a real
        model is real. Note that the real FFT is applied only to the first
        dimension to which the FFTND operator is applied (last element of
        ``axes``)
    ifftshift_before : :obj:`tuple` or :obj:`bool`, optional
        Apply ifftshift (``True``) or not (``False``) to model vector (before
        FFT). Consider using this option when the model vector's respective
        axis is symmetric with respect to the zero value sample. This will
        shift the zero value sample to coincide with the zero index sample.
        With such an arrangement, FFT will not introduce a sample-dependent
        phase-shift when compared to the continuous Fourier Transform.
        When passing a single value, the shift will be the same for every
        direction. Pass a tuple to specify which dimensions are shifted.
    fftshift_after : :obj:`tuple` or :obj:`bool`, optional
        Apply fftshift (``True``) or not (``False``) to data vector (after
        FFT). Consider using this option when you require frequencies to be
        arranged naturally, from negative to positive. When not applying
        fftshift after FFT, frequencies are arranged from zero to largest
        positive, and then from negative Nyquist to the frequency bin before
        zero. When passing a single value, the shift will be the same for
        every direction. Pass a tuple to specify which dimensions are shifted.
    dtype : :obj:`str`, optional
        Type of elements in input array. Note that the ``dtype`` of the
        operator is the corresponding complex type even when a real type is
        provided. In addition, note that the NumPy backend does not support
        returning ``dtype`` different from ``complex128``.
    base_comm : :obj:`mpi4py.MPI.Comm`, optional
        MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.

    Attributes
    ----------
    fs : :obj:`tuple`
        Each element of the tuple corresponds to the Discrete Fourier
        Transform sample frequencies along the respective direction given by
        ``axes``.
    nffts : :obj:`tuple` or :obj:`int`
        Number of samples in Fourier Transform for each axis in ``axes``.
    real : :obj:`bool`
        When ``True``, uses real fast fourier transform
    rdtype : dtype-like
        Expected input type to the forward
    cdtype : dtype-like
        Output type of the forward. Complex equivalent to ``rdtype``.
    shape : :obj:`tuple`
        Operator shape.
    clinear : :obj:`bool`
        Operator is complex-linear. Is false when either ``real=True`` or when
        ``dtype`` is not a complex type.
    fft : :obj:`mpi4py_fft.mpifft.PFFT`
        Parallel FFT operator object handling the distributed transform across
        MPI processes. Configured with the base communicator, dimension
        decomposition, transform axes, and dtype.

    See Also
    --------
    MPIFFT2D: Two-dimensional FFT

    Raises
    ------
    ValueError
        - If ``norm`` is not one of "none", or "1/n".
    ImportError
        If the optional ``mpi4py_fft`` dependency could not be imported.

    Notes
    -----
    The MPIFFTND operator applies the forward and inverse N-dimensional FFT to
    a :class:`pylops_mpi.DistributedArray`, accepted as a 1D flattened array
    and reshaped internally to the layout defined by ``dims``.

    The distributed FFT transform is performed by
    :class:`mpi4py_fft.mpifft.PFFT` via :class:`mpi4py_fft.pencil.Subcomm`.
    Since the 1D input is always distributed along ``axis=0`` after reshaping,
    PFFT is configured to distribute along ``axis=0`` by default. The exception
    is when ``axes[-1] == 0``: PFFT requires the final transform axis to be
    local on each rank, so the distribution is shifted to ``axis=1`` and the
    input is redistributed accordingly before the transform. After the
    transform, the output is flattened back to 1D.

    The class uses PFFT's two internal pencil layouts: ``pencil[False]`` for
    forward-input/backward-output and ``pencil[True]`` for
    forward-output/backward-input. During initialization, it records the
    distributed axes of these layouts as ``_pfft_in_axis`` and
    ``_pfft_out_axis``, and redistributes the input
    :class:`pylops_mpi.DistributedArray` as needed before each transform.

    In the forward pass, :meth:`PFFT.forward` is called with
    ``normalize=False``, computing:

    .. math::
        D(k_1, \ldots, k_N) = \mathscr{F} (d) =
        \int\limits_{-\infty}^\infty \cdots \int\limits_{-\infty}^\infty
        d(x_1, \ldots, x_N)
        e^{-j2\pi k_1 x_1} \cdots e^{-j 2 \pi k_N x_N}
        \,\mathrm{d}x_1 \cdots \mathrm{d}x_N

    When ``norm="1/n"``, the result is additionally scaled by :math:`1/N_F`.

    In the adjoint pass, :meth:`PFFT.backward` is called with
    ``normalize=True``, so ``PFFT`` internally divides by
    :math:`N_F = \prod_i N_i`, computing:

    .. math::
        d(x_1, \ldots, x_N) = \mathscr{F}^{-1} (D) = \frac{1}{N_F}
        \int\limits_{-\infty}^\infty \cdots \int\limits_{-\infty}^\infty
        D(k_1, \ldots, k_N)
        e^{j2\pi k_1 x_1} \cdots e^{j 2 \pi k_N x_N}
        \,\mathrm{d}k_1 \cdots \mathrm{d}k_N

    When ``norm="none"``, the adjoint multiplies by :math:`N_F` to cancel this
    internal scaling, returning a true unscaled adjoint. The result is then
    flattened back to a 1D :class:`pylops_mpi.DistributedArray`.

    All inter-rank data movement is handled internally by ``mpi4py_fft``.

    """

    def __init__(
        self,
        dims: InputDimsLike,
        axes: InputDimsLike = (0, 1, 2),
        sampling: float | Sequence[float] = 1.0,
        norm: str = "none",
        real: bool = False,
        ifftshift_before: bool = False,
        fftshift_after: bool = False,
        dtype: DTypeLike = "complex128",
        base_comm: MPI.Comm = MPI.COMM_WORLD,
    ) -> None:
        # Fail early with the dependency message instead of an opaque
        # NameError on ``Subcomm``/``PFFT`` further down: those names are only
        # imported at module level when the probe returned ``None``.
        if mpi4py_fft_message is not None:
            raise ImportError(mpi4py_fft_message)

        super().__init__(
            dims=dims,
            axes=axes,
            sampling=sampling,
            norm=norm,
            real=real,
            fftshift_after=fftshift_after,
            ifftshift_before=ifftshift_before,
            dtype=dtype,
            base_comm=base_comm,
        )
        if self.cdtype != np.complex128:
            warnings.warn(
                "numpy backend always returns complex128 dtype. "
                f"To respect the passed dtype, data will be cast to {self.cdtype}.",
                stacklevel=2,
            )

        # ``_scale`` is consumed by ``_matvec`` (norm="1/n"), ``_rmatvec``
        # (norm="none") and ``__truediv__``; it must exist for every accepted
        # norm, so anything else is rejected here rather than failing later.
        if self.norm is _FFTNorms.NONE:
            self._scale = np.prod(self.nffts)
        elif self.norm is _FFTNorms.ONE_OVER_N:
            self._scale = 1.0 / np.prod(self.nffts)
        else:
            raise ValueError(
                f"norm must be one of 'none' or '1/n', got {norm!r} instead..."
            )

        fft_dtype = self.rdtype if self.real else self.cdtype

        # Subcomm convention used here: 1 keeps an axis local, 0 lets the
        # communicator be split along it.
        # axis=0 carries the initial distribution by default; if the final
        # axis over which the FFT is applied is axis=0, the input array is
        # first redistributed over axis=1 prior to applying the FFT.
        # The stored ``self.axes``/``self.dims`` are used (the same values
        # handed to PFFT below) rather than the raw constructor arguments.
        subcomm_dims = np.ones(len(self.dims), dtype=int)
        if int(self.axes[-1]) == 0:
            subcomm_dims[1] = 0
        else:
            subcomm_dims[0] = 0
        subcomm = Subcomm(base_comm, subcomm_dims)
        self.fft = PFFT(subcomm, self.dims, axes=self.axes, dtype=fft_dtype)

        # PFFT uses two internal layouts (pencils): one before and one after
        # the transform. The two layouts can differ because PFFT may
        # redistribute data mid-transform to align the active FFT axis with
        # the distributed axis.
        #   pencil[False] is the forward-input / backward-output layout.
        #   pencil[True] is the forward-output / backward-input layout.
        # In both cases the first axis whose sub-communicator has more than
        # one rank is recorded, falling back to 0 when none is split.

        # Distributed axis in the pre-transform (pencil[False]) layout.
        self._pfft_in_axis = next(
            (
                i
                for i, s in enumerate(self.fft.pencil[False].subcomm)
                if s.Get_size() > 1
            ),
            0,
        )
        # Distributed axis in the post-transform (pencil[True]) layout.
        self._pfft_out_axis = next(
            (
                i
                for i, s in enumerate(self.fft.pencil[True].subcomm)
                if s.Get_size() > 1
            ),
            0,
        )

    @staticmethod
    def _check_input(x: DistributedArray) -> None:
        """Reject distributed arrays that the PFFT backend cannot consume.

        Parameters
        ----------
        x : DistributedArray
            Input handed to the forward or adjoint.

        Raises
        ------
        ValueError
            If ``x.engine`` is ``"cupy"`` or ``x.partition`` is not
            ``Partition.SCATTER``.

        """
        if x.engine == "cupy":
            raise ValueError(
                f"x should be a numpy array with engine=numpy, "
                f"Got {x.engine} instead..."
            )
        if x.partition != Partition.SCATTER:
            raise ValueError(
                f"x should have partition={Partition.SCATTER}, "
                f"Got {x.partition} instead..."
            )

    @reshaped
    def _matvec(self, x: DistributedArray) -> DistributedArray:
        """Forward: apply the distributed N-dimensional FFT to ``x``.

        Raises
        ------
        ValueError
            If ``x`` uses the cupy engine or is not scattered.

        """
        self._check_input(x)
        if self.ifftshift_before.any():
            x = ifftshift_nd(x, axes=self.axes[self.ifftshift_before])
        if not self.clinear:
            # Drop any imaginary part of the model before transforming
            x[:] = np.real(x.local_array)

        # Buffers in the PFFT input (pre-transform) and output
        # (post-transform) layouts
        x_dist_pfft = newDistArray(self.fft, forward_output=False)
        y_dist_pfft = newDistArray(self.fft, forward_output=True)

        # Redistribute input to match the input PFFT axis
        x = x.redistribute(axis=self._pfft_in_axis)
        x_dist_pfft[:] = x.local_array

        # Perform the parallel forward FFT (unnormalized)
        self.fft.forward(x_dist_pfft, y_dist_pfft, normalize=False)

        y = DistributedArray(
            global_shape=self.dimsd,
            dtype=self.dtype,
            axis=self._pfft_out_axis,
            base_comm=x.base_comm,
            base_comm_nccl=x.base_comm_nccl,
            engine=x.engine,
        )
        y[:] = y_dist_pfft

        if self.real:
            self._scale_real_fft(y, inverse=False)
        if self.norm is _FFTNorms.ONE_OVER_N:
            y[:] *= self._scale
        y[:] = y.local_array.astype(self.cdtype)
        if self.fftshift_after.any():
            y = fftshift_nd(y, axes=self.axes[self.fftshift_after])
        return y

    @reshaped
    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        """Adjoint: apply the distributed N-dimensional inverse FFT to ``x``.

        Raises
        ------
        ValueError
            If ``x`` uses the cupy engine or is not scattered.

        """
        self._check_input(x)
        if self.fftshift_after.any():
            # Undo the fftshift applied at the end of the forward
            x = ifftshift_nd(x, axes=self.axes[self.fftshift_after])
        if self.real:
            self._scale_real_fft(x, inverse=True)

        # Allocate distributed arrays for input and output
        y_dist_pfft = newDistArray(self.fft, forward_output=False)
        x_dist_pfft = newDistArray(self.fft, forward_output=True)

        # Redistribute input to match the PFFT axis
        x = x.redistribute(axis=self._pfft_out_axis)
        x_dist_pfft[:] = x.local_array

        # Perform the parallel backward FFT (PFFT divides by N_F internally)
        self.fft.backward(x_dist_pfft, y_dist_pfft, normalize=True)

        y = DistributedArray(
            global_shape=self.dims,
            dtype=self.dtype,
            axis=self._pfft_in_axis,
            base_comm=x.base_comm,
            base_comm_nccl=x.base_comm_nccl,
            engine=x.engine,
        )
        y[:] = y_dist_pfft

        if self.norm is _FFTNorms.NONE:
            # Cancel PFFT's internal 1/N_F so the adjoint is unscaled
            y[:] *= self._scale
        if not self.clinear:
            y[:] = np.real(y.local_array)
        y[:] = y.local_array.astype(self.rdtype)
        if self.ifftshift_before.any():
            # Undo the ifftshift applied at the start of the forward
            y = fftshift_nd(y, axes=self.axes[self.ifftshift_before])
        return y

    def _scale_real_fft(self, x: DistributedArray, inverse: bool = False) -> None:
        """Apply scaling for real-valued FFTs.

        Scales the non-DC positive frequency components along the final FFT
        axis by ``sqrt(2)`` in forward mode and ``1/sqrt(2)`` in inverse mode.
        When the final FFT axis is distributed across MPI ranks, only the local
        portion overlapping with the global positive-frequency range is scaled.

        Parameters
        ----------
        x : DistributedArray
            Distributed FFT array to scale in-place.
        inverse : bool, optional
            Apply inverse scaling when ``True``. Default is ``False``.

        """
        scale = 1 / np.sqrt(2) if inverse else np.sqrt(2)
        last_axis = self.axes[-1]
        # Global frequency slice to scale is [1:k]
        k = 1 + (self.nffts[-1] - 1) // 2

        if x.axis == last_axis:
            # The axis is split across ranks: locate this rank's global
            # [local_start, local_stop) window from the per-rank local shapes.
            rank = self.base_comm.rank
            sizes = [loc_shape[last_axis] for loc_shape in x.local_shapes]
            local_start = sum(sizes[:rank])
            local_stop = local_start + sizes[rank]
            freq_start, freq_stop = max(1, local_start), min(k, local_stop)
            # Local overlap with the global frequency slice [1:k]
            if freq_stop > freq_start:
                local_slice = [slice(None)] * x.ndim
                local_slice[last_axis] = slice(
                    freq_start - local_start, freq_stop - local_start
                )
                x[tuple(local_slice)] *= scale
        else:
            # Axis is local on this rank, so direct slicing
            freq_slice = [slice(None)] * x.ndim
            freq_slice[last_axis] = slice(1, k)
            x[tuple(freq_slice)] *= scale

    def __truediv__(self, y: DistributedArray) -> DistributedArray:
        """Return the adjoint applied to ``y``, divided by ``_scale``.

        Parameters
        ----------
        y : DistributedArray
            Data vector passed to ``_rmatvec``.

        Returns
        -------
        DistributedArray
            Result of ``_rmatvec(y)`` with its local array divided by
            ``self._scale``.

        """
        y_div = self._rmatvec(y)
        y_div[:] = y_div.local_array / self._scale
        return y_div