from __future__ import annotations
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any, cast
import numpy as np
from astropy import units as u
from scipy import fft as sp_fft
from gwexpy.interop._registry import ConverterRegistry
from .base import FieldBase
if TYPE_CHECKING:
from gwexpy.types.typing import IndexLike
__all__ = ["ScalarField"]
[docs]
class ScalarField(FieldBase):
"""4D Field with domain states and FFT operations.
This class extends :class:`Array4D` to represent physical fields that
can exist in different domains (time/frequency for axis 0, real/k-space
for spatial axes 1-3).
**Key feature**: All indexing operations return a ScalarField, maintaining
4D structure. Integer indices result in axes with length 1 rather than
being dropped.
Parameters
----------
data : array-like
4-dimensional input data.
unit : `~astropy.units.Unit`, optional
Physical unit of the data.
axis0 : `~astropy.units.Quantity` or array-like, optional
Index values for axis 0 (time or frequency).
axis1 : `~astropy.units.Quantity` or array-like, optional
Index values for axis 1 (x or kx).
axis2 : `~astropy.units.Quantity` or array-like, optional
Index values for axis 2 (y or ky).
axis3 : `~astropy.units.Quantity` or array-like, optional
Index values for axis 3 (z or kz).
axis_names : iterable of str, optional
Names for each axis (length 4). Defaults based on domain.
axis0_domain : {'time', 'frequency'}, optional
Domain of axis 0. Default is 'time'.
space_domain : {'real', 'k'} or dict, optional
Domain of spatial axes. If str, applies to all spatial axes.
If dict, maps axis names to domains. Default is 'real'.
**kwargs
Additional keyword arguments passed to :class:`Array4D`.
Attributes
----------
axis0_domain : str
Current domain of axis 0 ('time' or 'frequency').
space_domains : dict
Mapping of spatial axis names to their domains ('real' or 'k').
Examples
--------
>>> import numpy as np
>>> from astropy import units as u
>>> from gwexpy.fields import ScalarField
>>> data = np.random.randn(100, 32, 32, 32)
>>> times = np.arange(100) * 0.01 * u.s
>>> x = np.arange(32) * 1.0 * u.m
>>> field = ScalarField(data, axis0=times, axis1=x, axis2=x, axis3=x,
... axis_names=['t', 'x', 'y', 'z'])
"""
_axis0_index: IndexLike
_axis1_index: IndexLike
_axis2_index: IndexLike
_axis3_index: IndexLike
def __getitem__(self, item: Any) -> ScalarField:
"""Get item, always returning ScalarField (4D maintained).
Integer indices are converted to length-1 slices to maintain
4D structure.
"""
forced_item = self._force_4d_item(item)
return self._getitem_scalarfield(forced_item)
def _force_4d_item(self, item: Any) -> tuple[Any, ...]:
"""Convert int indices to slice(i, i+1) to maintain 4D."""
if not isinstance(item, tuple):
item = (item,)
# Handle Ellipsis
if Ellipsis in item:
if item.count(Ellipsis) > 1:
raise IndexError("Only one ellipsis allowed")
ellipsis_idx = item.index(Ellipsis)
num_specified = len(item) - 1
fill = 4 - num_specified
if fill < 0:
raise IndexError("Too many indices for 4D array")
item = (
item[:ellipsis_idx] + (slice(None),) * fill + item[ellipsis_idx + 1 :]
)
# Pad to length 4
if len(item) < 4:
item = item + (slice(None),) * (4 - len(item))
if len(item) > 4:
raise IndexError("Too many indices for 4D array")
# Convert int to slice(i, i+1)
result = []
for i, idx in enumerate(item):
if self._is_int_index(idx):
# Normalize negative indices
size = self.shape[i]
if idx < 0:
idx = size + idx
if idx < 0 or idx >= size:
raise IndexError(
f"Index {idx} out of bounds for axis {i} with size {size}"
)
result.append(slice(idx, idx + 1))
else:
result.append(idx)
return tuple(result)
def _getitem_scalarfield(self, item: tuple[Any, ...]) -> ScalarField:
    """Slice the underlying array and rebuild a ScalarField around it.

    ``item`` is expected to be the output of ``_force_4d_item`` (a tuple
    of four slices). Whenever that expectation is not met -- ``item`` is
    not a 4-tuple, one entry is not a slice, or the sliced array is not
    4-dimensional -- the raw result is returned via ``self._to_plain``.
    """
    from gwpy.types.array import Array as GwpyArray

    # Raw slicing goes through gwpy's Array implementation.
    raw = GwpyArray.__getitem__(self, item)
    if not (isinstance(item, tuple) and len(item) == 4):
        return self._to_plain(raw)

    names = (
        self._axis0_name,
        self._axis1_name,
        self._axis2_name,
        self._axis3_name,
    )
    coords = (
        self._axis0_index,
        self._axis1_index,
        self._axis2_index,
        self._axis3_index,
    )

    # Apply each slice to its coordinate array; bail out on anything else.
    sliced_coords = []
    for selector, axis_coords in zip(item, coords):
        if not isinstance(selector, slice):
            return self._to_plain(raw)
        sliced_coords.append(axis_coords[selector])

    if getattr(raw, "ndim", None) != 4:
        return self._to_plain(raw)

    value, unit = self._value_unit(raw)
    meta = self._metadata_kwargs(raw)

    # Spatial axes keep their recorded domain; unknown names are 'real'.
    new_space_domains = {}
    for name in names[1:]:
        new_space_domains[name] = (
            self._space_domains[name] if name in self._space_domains else "real"
        )

    result = ScalarField(
        value,
        unit=unit,
        axis_names=list(names),
        axis0=sliced_coords[0],
        axis1=sliced_coords[1],
        axis2=sliced_coords[2],
        axis3=sliced_coords[3],
        axis0_domain=self._axis0_domain,
        space_domain=new_space_domains,
        copy=False,
        **meta,
    )
    self._propagate_gwex_attrs(result)
    return result
def _isel_tuple(self, item_tuple: tuple[Any, ...]) -> ScalarField:
"""Internal isel using ScalarField getitem logic."""
forced_item = self._force_4d_item(item_tuple)
return self._getitem_scalarfield(forced_item)
# =========================================================================
# Time FFT (axis=0, GWpy TimeSeries.fft compatible)
# =========================================================================
def _validate_axis_for_fft(
self, axis_index: IndexLike, axis_name: str, domain_name: str
) -> None:
"""Validate that an axis is suitable for FFT.
Parameters
----------
axis_index : Quantity
The axis coordinate array.
axis_name : str
Name of the axis for error messages.
domain_name : str
Domain name ('time', 'frequency', etc.) for error messages.
Raises
------
ValueError
If axis length < 2 or axis is not regularly spaced.
"""
if len(axis_index) < 2:
raise ValueError(
f"FFT requires {domain_name} axis length >= 2, "
f"got length {len(axis_index)} for axis '{axis_name}'"
)
# Check regularity using AxisDescriptor
from ..types.axis import AxisDescriptor
ax_desc = AxisDescriptor(axis_name, axis_index)
if not ax_desc.regular:
raise ValueError(
f"FFT requires regularly spaced {domain_name} axis, "
f"but axis '{axis_name}' is irregular"
)
[docs]
def fft_time(self, nfft: int | None = None) -> ScalarField:
    """Transform axis 0 from the time domain to the frequency domain.

    The transform is ``rfft / nfft``; every bin except DC (and except the
    last bin when ``nfft`` is even) is then multiplied by 2 to form a
    one-sided spectrum.

    Parameters
    ----------
    nfft : int, optional
        FFT length. Defaults to the length of axis 0.

    Returns
    -------
    ScalarField
        Field with ``axis0_domain='frequency'``. The first time
        coordinate is stored on the result as ``_axis0_offset``.

    Raises
    ------
    ValueError
        If ``axis0_domain`` is not 'time', or the time axis fails
        ``_validate_axis_for_fft`` (too short / irregular).
    TypeError
        If the data are complex-valued.

    See Also
    --------
    gwpy.timeseries.TimeSeries.fft : The reference implementation.
    """
    if self._axis0_domain != "time":
        raise ValueError(
            f"fft_time requires axis0_domain='time', got '{self._axis0_domain}'"
        )
    self._validate_axis_for_fft(self._axis0_index, self._axis0_name, "time")

    # rfft assumes a real-valued signal.
    if np.iscomplexobj(self.value):
        raise TypeError(
            "fft_time requires real-valued input. "
            "For complex data, use a full FFT approach."
        )

    if nfft is None:
        nfft = self.shape[0]

    # Remember where the time axis started so ifft_time can restore it.
    t0 = self._axis0_index[0]

    import scipy.fft as sp_fft

    spectrum = sp_fft.rfft(self.value, n=nfft, axis=0) / nfft

    # One-sided correction: double bins 1.. ; for even nfft the final bin
    # is excluded, for odd nfft every bin from 1 onwards is doubled.
    stop = -1 if nfft % 2 == 0 else None
    spectrum[1:stop, ...] *= 2.0

    # Frequency axis from the sample spacing of the time axis.
    step = self._axis0_index[1] - self._axis0_index[0]
    step_value = getattr(step, "value", step)
    step_unit = getattr(step, "unit", u.dimensionless_unscaled)
    freqs = np.fft.rfftfreq(nfft, d=step_value) * (1 / step_unit)

    names = [
        self._FREQ_AXIS_NAME,
        self._axis1_name,
        self._axis2_name,
        self._axis3_name,
    ]
    result = ScalarField(
        spectrum,
        unit=self.unit,
        axis0=freqs,
        axis1=self._axis1_index,
        axis2=self._axis2_index,
        axis3=self._axis3_index,
        axis_names=names,
        axis0_domain="frequency",
        space_domain=self._space_domains,
    )
    result._axis0_offset = t0
    result._validate_domain_units()
    self._propagate_gwex_attrs(result)
    return result
[docs]
def ifft_time(self, nout: int | None = None) -> ScalarField:
    """Transform axis 0 from the frequency domain back to the time domain.

    Reverses the scaling applied by ``fft_time``: the doubled bins are
    halved, the spectrum is multiplied by ``nout`` and passed to
    ``irfft``.

    Parameters
    ----------
    nout : int, optional
        Number of output time samples. Defaults to ``(n_freq - 1) * 2``.

    Returns
    -------
    ScalarField
        Field with ``axis0_domain='time'``. If the instance carries an
        ``_axis0_offset`` attribute, its value is added to the
        reconstructed time axis.

    Raises
    ------
    ValueError
        If ``axis0_domain`` is not 'frequency', or the frequency axis
        fails ``_validate_axis_for_fft`` (too short / irregular).

    See Also
    --------
    gwpy.frequencyseries.FrequencySeries.ifft : Reference implementation.
    """
    if self._axis0_domain != "frequency":
        raise ValueError(
            f"ifft_time requires axis0_domain='frequency', "
            f"got '{self._axis0_domain}'"
        )
    self._validate_axis_for_fft(self._axis0_index, self._axis0_name, "frequency")

    if nout is None:
        nout = (self.shape[0] - 1) * 2

    # Work on a copy: halve the bins that fft_time doubled. For even nout
    # the last bin was left alone; for odd nout all bins from 1 on were
    # doubled.
    spectrum = self.value.copy()
    stop = -1 if nout % 2 == 0 else None
    spectrum[1:stop, ...] /= 2.0
    dift = sp_fft.irfft(spectrum * nout, n=nout, axis=0)

    # Time step from the frequency resolution: dt = 1 / (nout * df).
    df = self._axis0_index[1] - self._axis0_index[0]
    df_value = getattr(df, "value", df)
    df_unit = getattr(df, "unit", u.dimensionless_unscaled)
    dt_value = 1.0 / (nout * df_value)
    dt_unit = 1 / df_unit

    # Shift by the stored origin when fft_time recorded one.
    samples = np.arange(nout) * dt_value
    t0_offset = getattr(self, "_axis0_offset", None)
    if t0_offset is not None:
        samples = samples + t0_offset.value
    times = samples * dt_unit

    names = [
        self._TIME_AXIS_NAME,
        self._axis1_name,
        self._axis2_name,
        self._axis3_name,
    ]
    result = ScalarField(
        dift,
        unit=self.unit,
        axis0=times,
        axis1=self._axis1_index,
        axis2=self._axis2_index,
        axis3=self._axis3_index,
        axis_names=names,
        axis0_domain="time",
        space_domain=self._space_domains,
    )
    result._validate_domain_units()
    self._propagate_gwex_attrs(result)
    return result
# =========================================================================
# Spatial FFT (axes 1-3, two-sided signed FFT)
# =========================================================================
[docs]
def fft_space(
    self,
    axes: Iterable[str] | None = None,
    n: Sequence[int] | None = None,
    overwrite: bool = False,
) -> ScalarField:
    """Compute FFT along spatial axes.

    Uses the two-sided FFT (``scipy.fft.fftn``) and labels the
    transformed axes with angular wavenumber ``k = 2π·fftfreq``.

    Parameters
    ----------
    axes : iterable of str, optional
        Axis names to transform (e.g., ['x', 'y']). Any iterable is
        accepted, including one-shot iterators and generators. If None,
        all spatial axes currently in the 'real' domain are transformed.
    n : tuple of int, optional
        FFT lengths for each axis, in the same order as ``axes``.
    overwrite : bool, optional
        If True, run the FFT with ``overwrite_x=True`` on a temporary
        copy of the data. Default is False.

    Returns
    -------
    ScalarField
        Field with the selected axes in the 'k' domain; each transformed
        axis ``name`` is renamed ``k<name>``.

    Raises
    ------
    ValueError
        If no axes are selected, an axis resolves to axis 0, an axis is
        not in the 'real' domain, has fewer than two samples, is not
        uniformly spaced, is not strictly monotonic, or has no defined
        spacing.

    Notes
    -----
    **Angular wavenumber convention**: ``k = 2π * fftfreq(n, d=|dx|)``,
    i.e. ``k = 2π / λ`` with units of ``1 / dx_unit``. To obtain the
    cycle wavenumber use ``k_cycle = k_angular / (2π)``.

    **Descending axes (dx < 0)**: the k-axis values are negated, so the
    sign of k follows the physical direction of the coordinate rather
    than the storage order. This differs from a plain FFT, which ignores
    axis direction.

    References
    ----------
    .. [1] Press et al., Numerical Recipes (3rd ed., 2007), §12.3.2
    .. [2] NumPy fftfreq documentation
    .. [3] GWpy FrequencySeries (Duncan Macleod et al., SoftwareX 13, 2021)
    .. [4] Jackson, Classical Electrodynamics (3rd ed., 1998), §4.2:
       Fourier transform sign conventions
    """
    if axes is None:
        # Default: all real-domain spatial axes.
        axes = [
            name
            for name in [self._axis1_name, self._axis2_name, self._axis3_name]
            if self._space_domains.get(name) == "real"
        ]
    else:
        # ``axes`` is walked twice below (validation, then metadata).
        # Materialize it so one-shot iterables are not exhausted after
        # the first pass, and so the emptiness check is meaningful.
        axes = list(axes)
    if not axes:
        raise ValueError("No axes specified for fft_space")

    # Validate axes and collect their integer positions.
    target_axes_int = []
    for ax_name in axes:
        ax_int = self._get_axis_index(ax_name)
        if ax_int == 0:
            raise ValueError(
                "Cannot use fft_space on axis 0 (time/frequency axis). "
                "Use fft_time instead."
            )
        domain = self._space_domains.get(ax_name)
        if domain != "real":
            raise ValueError(
                f"Axis '{ax_name}' is not in 'real' domain (current: {domain}). "
                f"Cannot apply fft_space."
            )
        ax_desc = self.axis(ax_name)
        if ax_desc.size < 2:
            raise ValueError(
                f"FFT requires axis length >= 2, "
                f"got length {ax_desc.size} for axis '{ax_name}'"
            )
        if not ax_desc.regular:
            raise ValueError(
                f"Axis '{ax_name}' is not uniformly spaced. Cannot apply FFT."
            )
        diffs = np.diff(ax_desc.index.value)
        if not (np.all(diffs > 0) or np.all(diffs < 0)):
            raise ValueError(
                f"Axis '{ax_name}' is not strictly monotonic. "
                f"Spatial axes must be strictly ascending or descending."
            )
        target_axes_int.append(ax_int)

    s = tuple(n) if n is not None else None

    import scipy.fft as sp_fft

    if overwrite:
        # Explicit copy so that overwrite_x may reuse the buffer without
        # touching this field's data.
        work_data = self.value.copy()
        dft = sp_fft.fftn(work_data, s=s, axes=target_axes_int, overwrite_x=True)
    else:
        dft = sp_fft.fftn(self.value, s=s, axes=target_axes_int)

    # Build new axis metadata.
    new_indices = [
        self._axis0_index,
        self._axis1_index.copy(),
        self._axis2_index.copy(),
        self._axis3_index.copy(),
    ]
    new_names = list(self.axis_names)
    new_space_domains = dict(self._space_domains)

    for ax_name, ax_int in zip(axes, target_axes_int):
        delta = self.axis(ax_name).delta
        if delta is None:
            raise ValueError(
                f"Axis '{ax_name}' does not have a defined spacing (delta)"
            )
        dx_value = getattr(delta, "value", delta)
        dx_unit = getattr(delta, "unit", u.dimensionless_unscaled)

        npts = dft.shape[ax_int]
        # fftfreq is given the magnitude of the spacing; the direction of
        # the original axis is re-applied through the sign flip below.
        k_values = 2 * np.pi * np.fft.fftfreq(npts, d=abs(dx_value))
        if dx_value < 0:
            k_values = -k_values
        new_indices[ax_int] = k_values * (1 / dx_unit)

        # Rename x -> kx and move the axis to the 'k' domain.
        new_name = f"k{ax_name}"
        new_names[ax_int] = new_name
        del new_space_domains[ax_name]
        new_space_domains[new_name] = "k"

    result = ScalarField(
        dft,
        unit=self.unit,
        axis0=new_indices[0],
        axis1=new_indices[1],
        axis2=new_indices[2],
        axis3=new_indices[3],
        axis_names=new_names,
        axis0_domain=self._axis0_domain,
        space_domain=new_space_domains,
    )
    result._validate_domain_units()
    self._propagate_gwex_attrs(result)
    return result
[docs]
def ifft_space(
    self,
    axes: Iterable[str] | None = None,
    n: Sequence[int] | None = None,
    overwrite: bool = False,
) -> ScalarField:
    """Compute inverse FFT along k-space axes.

    Parameters
    ----------
    axes : iterable of str, optional
        Axis names to transform (e.g., ['kx', 'ky']). Any iterable is
        accepted, including one-shot iterators and generators. If None,
        all spatial axes currently in the 'k' domain are transformed.
    n : tuple of int, optional
        Output lengths for each axis, in the same order as ``axes``.
    overwrite : bool, optional
        If True, run the IFFT with ``overwrite_x=True`` on a temporary
        copy of the data.

    Returns
    -------
    ScalarField
        Field with the selected axes in the 'real' domain. A leading
        ``k`` is stripped from each transformed axis name. The rebuilt
        coordinate axis has spacing ``2π / (n * |dk|)``; the original
        spatial origin is not recovered.

    Raises
    ------
    ValueError
        If no axes are selected, an axis resolves to axis 0, an axis is
        not in the 'k' domain, or a k-axis has fewer than two samples.
    """
    if axes is None:
        # Default: all k-domain spatial axes.
        axes = [
            name
            for name in [self._axis1_name, self._axis2_name, self._axis3_name]
            if self._space_domains.get(name) == "k"
        ]
    else:
        # ``axes`` is walked twice below (validation, then metadata).
        # Materialize it so one-shot iterables are not exhausted after
        # the first pass, and so the emptiness check is meaningful.
        axes = list(axes)
    if not axes:
        raise ValueError("No axes specified for ifft_space")

    # Validate axes and collect their integer positions.
    target_axes_int = []
    for ax_name in axes:
        ax_int = self._get_axis_index(ax_name)
        if ax_int == 0:
            raise ValueError(
                "Cannot use ifft_space on axis 0. Use ifft_time instead."
            )
        domain = self._space_domains.get(ax_name)
        if domain != "k":
            raise ValueError(
                f"Axis '{ax_name}' is not in 'k' domain (current: {domain}). "
                f"Cannot apply ifft_space."
            )
        target_axes_int.append(ax_int)

    s = tuple(n) if n is not None else None

    import scipy.fft as sp_fft

    if overwrite:
        # Explicit copy so that overwrite_x never touches this field's data.
        work_data = self.value.copy()
        dift = sp_fft.ifftn(work_data, s=s, axes=target_axes_int, overwrite_x=True)
    else:
        dift = sp_fft.ifftn(self.value, s=s, axes=target_axes_int)

    # Build new axis metadata.
    new_indices = [
        self._axis0_index,
        self._axis1_index.copy(),
        self._axis2_index.copy(),
        self._axis3_index.copy(),
    ]
    new_names = list(self.axis_names)
    new_space_domains = dict(self._space_domains)

    for ax_name, ax_int in zip(axes, target_axes_int):
        # kx -> x; names without the prefix are kept as they are.
        real_name = ax_name[1:] if ax_name.startswith("k") else ax_name

        npts = dift.shape[ax_int]
        k_axis = self.axis(ax_name).index
        if len(k_axis) < 2:
            raise ValueError(
                f"ifft_space requires axis length >= 2, "
                f"got length {len(k_axis)} for axis '{ax_name}'"
            )
        dk_raw = k_axis[1] - k_axis[0]
        dk_value = getattr(dk_raw, "value", dk_raw)
        dk_unit = getattr(dk_raw, "unit", u.dimensionless_unscaled)

        # k = 2π * fftfreq(n, d=dx)  =>  dx = 2π / (n * |dk|)
        dx_value = 2 * np.pi / (npts * abs(dk_value))
        dx_unit = 1 / dk_unit

        if dk_value < 0:
            # NOTE(review): the intent stated for this branch is a
            # descending x-axis, but a reversed arange multiplied by
            # -dx runs from -(npts - 1) * dx up to 0, i.e. ascending.
            # Behavior is kept as-is; confirm which form is intended.
            x_values = np.arange(npts - 1, -1, -1) * (-dx_value) * dx_unit
        else:
            x_values = np.arange(npts) * dx_value * dx_unit

        new_indices[ax_int] = x_values
        new_names[ax_int] = real_name
        del new_space_domains[ax_name]
        new_space_domains[real_name] = "real"

    result = ScalarField(
        dift,
        unit=self.unit,
        axis0=new_indices[0],
        axis1=new_indices[1],
        axis2=new_indices[2],
        axis3=new_indices[3],
        axis_names=new_names,
        axis0_domain=self._axis0_domain,
        space_domain=new_space_domains,
    )
    result._validate_domain_units()
    self._propagate_gwex_attrs(result)
    return result
[docs]
def wavelength(self, axis: str | int) -> u.Quantity:
    """Return the wavelength for every sample of a wavenumber axis.

    Parameters
    ----------
    axis : str or int
        Name or integer position of a k-domain axis.

    Returns
    -------
    `~astropy.units.Quantity`
        :math:`\\lambda = 2\\pi / |k|` in the inverse of the axis unit.
        Samples with k = 0 give ``inf`` (the divide warning is silenced).

    Raises
    ------
    ValueError
        If the axis is not in the 'k' domain.
    """
    ax_name = self.axis_names[self._get_axis_index(axis)]
    domain = self._space_domains.get(ax_name)
    if domain != "k":
        raise ValueError(
            f"Axis '{ax_name}' is not in 'k' domain (current: {domain})"
        )

    k_index = self.axis(ax_name).index
    k_unit = getattr(k_index, "unit", u.dimensionless_unscaled)
    # Division by k = 0 is expected and intentionally yields inf.
    with np.errstate(divide="ignore"):
        magnitudes = np.abs(getattr(k_index, "value", k_index))
        wavelength_values = 2 * np.pi / magnitudes
    return wavelength_values * (1 / k_unit)
# =========================================================================
# pyroomacoustics interop
# =========================================================================
[docs]
@classmethod
def from_pyroomacoustics_field(
    cls,
    room: Any,
    *,
    grid_shape: tuple[int, ...],
    source: int = 0,
    mode: str = "rir",
    unit: Any | None = None,
) -> ScalarField:
    """Build a field from a pyroomacoustics room with gridded microphones.

    Thin wrapper that forwards every argument, together with ``cls``, to
    ``gwexpy.interop.from_pyroomacoustics_field``.

    Parameters
    ----------
    room : pyroomacoustics.Room
        Room with microphones on a regular spatial grid.
    grid_shape : tuple of int
        Spatial grid shape ``(nx, ny, nz)`` or ``(nx, ny)``.
    source : int, default 0
        Source index (for ``mode='rir'``).
    mode : {'rir', 'signals'}
        ``'rir'`` for impulse responses, ``'signals'`` for mic signals.
    unit : str or astropy.units.Unit, optional
        Unit to assign to the data.

    Returns
    -------
    ScalarField
    """
    from gwexpy.interop import from_pyroomacoustics_field as _convert

    forwarded = {
        "grid_shape": grid_shape,
        "source": source,
        "mode": mode,
        "unit": unit,
    }
    return _convert(cls, room, **forwarded)
# =========================================================================
# Simulation
# =========================================================================
[docs]
@classmethod
def simulate(cls, method: str, *args: Any, **kwargs: Any) -> ScalarField:
    """Generate a simulated ScalarField.

    Looks up ``method`` as an attribute of ``gwexpy.noise.field`` and
    calls it with the remaining arguments.

    Parameters
    ----------
    method : str
        Name of the generator in ``gwexpy.noise.field``
        (e.g., 'gaussian', 'plane_wave').
    *args, **kwargs
        Arguments passed to the generator.

    Returns
    -------
    ScalarField
        Whatever the generator returns.

    Raises
    ------
    ValueError
        If ``gwexpy.noise.field`` has no attribute called ``method``.

    Examples
    --------
    >>> from gwexpy.fields import ScalarField
    >>> field = ScalarField.simulate('gaussian', shape=(100, 10, 10, 10))
    """
    from gwexpy.noise import field

    if hasattr(field, method):
        generator = getattr(field, method)
        return generator(*args, **kwargs)

    raise ValueError(
        f"Unknown simulation method '{method}'. "
        f"Available methods in gwexpy.noise.field: "
        f"{[m for m in dir(field) if not m.startswith('_')]}"
    )
# =========================================================================
# Extraction API (Phase 0.3)
# =========================================================================
[docs]
def slice_map2d(self, plane="xy", at=None):
    """Extract a 2D slice (map) from the 4D field.

    Parameters
    ----------
    plane : str, optional
        The plane to keep: 'xy', 'xz', 'yz', 'tx', 'ty', 'tz', or a
        string built from the characters t, f, x, y, z, k (k-space
        variants such as 'kxky'). Default is 'xy'.
    at : dict, optional
        Maps the name of each axis outside the plane to the coordinate
        at which it is fixed (nearest sample is used). Axes outside the
        plane that already have length 1 need no entry.

    Returns
    -------
    ScalarField
        Field in which every axis outside the plane has length 1.

    Raises
    ------
    ValueError
        If ``plane`` contains characters other than t, f, x, y, z, k.
    ValueError
        If an axis outside the plane has length > 1 and no value for it
        is given in ``at``.

    Examples
    --------
    >>> # Extract xy plane at a specific time and z
    >>> field_2d = field.slice_map2d('xy', at={'t': 0.5 * u.s, 'z': 0.0 * u.m})
    >>> field_2d.plot_map2d()
    """
    from gwexpy.plot._coord import nearest_index, slice_from_index

    valid_planes = {"xy", "xz", "yz", "tx", "ty", "tz"}
    plane_lower = plane.lower()
    # Anything outside the canonical set must at least be spelled with the
    # characters used by real-space and k-space axis names.
    if plane_lower not in valid_planes and not set(plane_lower).issubset(
        {"t", "f", "x", "y", "z", "k"}
    ):
        raise ValueError(
            f"Invalid plane '{plane}'. Must be one of {valid_planes} "
            f"or their k-space equivalents."
        )

    axis_names = list(self.axis_names)
    plane_axes = []

    # Pass 1: each non-'k' character selects the first axis called either
    # "<char>" or "k<char>".
    for char in plane_lower:
        if char == "k":
            continue
        match = next(
            (nm for nm in axis_names if nm == char or nm == f"k{char}"), None
        )
        if match is not None and match not in plane_axes:
            plane_axes.append(match)

    # Pass 2: an explicit "k<char>" pair selects the axis of that name.
    for pos, char in enumerate(plane_lower[:-1]):
        if char != "k":
            continue
        combined = f"k{plane_lower[pos + 1]}"
        for nm in axis_names:
            if nm == combined and nm not in plane_axes:
                plane_axes.append(nm)

    all_axes = [
        (0, self._axis0_name, self._axis0_index),
        (1, self._axis1_name, self._axis1_index),
        (2, self._axis2_name, self._axis2_index),
        (3, self._axis3_name, self._axis3_index),
    ]
    # Plane membership is also decided with any leading 'k' stripped.
    kept_bases = [p.lstrip("k") for p in plane_axes]

    slices = [slice(None)] * 4
    for ax_int, ax_name, ax_index in all_axes:
        in_plane = ax_name in plane_axes or ax_name.lstrip("k") in kept_bases
        if in_plane:
            continue

        # This axis is outside the plane and must collapse to length 1.
        if at is not None and ax_name in at:
            idx = nearest_index(ax_index, at[ax_name])
            slices[ax_int] = slice_from_index(idx)
        elif len(ax_index) == 1:
            slices[ax_int] = slice(0, 1)
        else:
            raise ValueError(
                f"Axis '{ax_name}' is not in plane '{plane}' and has "
                f"length {len(ax_index)} > 1. Specify its value in 'at'."
            )

    return self[tuple(slices)]
# =========================================================================
# Visualization Methods (Phase 1)
# =========================================================================
[docs]
def plot_map2d(
    self,
    plane="xy",
    at=None,
    mode="real",
    method="pcolormesh",
    ax=None,
    add_colorbar=True,
    vmin=None,
    vmax=None,
    title=None,
    cmap=None,
    **kwargs,
):
    """Plot a 2D map (heatmap) of the field.

    The field is first reduced with ``slice_map2d(plane, at)``. The two
    axes of the reduced field with length > 1 become the plot axes: the
    lower-numbered one is drawn vertically, the other horizontally.

    Parameters
    ----------
    plane : str, optional
        The plane to plot: 'xy', 'xz', 'yz', 'tx', 'ty', 'tz'.
        Default is 'xy'.
    at : dict, optional
        Fixed values for axes not in the plane. If None, axes with
        length=1 are used automatically.
    mode : str, optional
        Component passed to ``select_value``:
        'real', 'imag', 'abs', 'angle', 'power'. Default is 'real'.
    method : str, optional
        Plot method: 'pcolormesh' or 'imshow'. Default is 'pcolormesh'.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, a new figure is created.
    add_colorbar : bool, optional
        Whether to add a colorbar. Default is True.
    vmin, vmax : float, optional
        Color scale limits.
    title : str, optional
        Plot title.
    cmap : str or Colormap, optional
        Colormap to use.
    **kwargs
        Additional arguments passed to the plot method.

    Returns
    -------
    tuple
        (fig, ax): The matplotlib figure and axes objects.

    Raises
    ------
    ValueError
        If fewer than two plot axes can be found, or ``method`` is not
        'pcolormesh' or 'imshow'.

    Examples
    --------
    >>> fig, ax = field.plot_map2d('xy', at={'t': 0.5 * u.s, 'z': 0.0 * u.m})
    """
    import matplotlib.pyplot as plt

    from gwexpy.plot._coord import select_value

    sliced = self.slice_map2d(plane=plane, at=at)
    shape = sliced.shape

    # Plot axes are the ones that still vary; if fewer than two vary,
    # fall back to the first two axes of non-zero length.
    plane_axis_ints = [i for i, size in enumerate(shape) if size > 1]
    if len(plane_axis_ints) < 2:
        plane_axis_ints = [i for i, size in enumerate(shape) if size >= 1][:2]
    if len(plane_axis_ints) < 2:
        raise ValueError(
            f"Cannot create 2D plot: slice has shape {shape}. "
            f"Need at least 2 dimensions with size > 1."
        )
    ax1_int, ax2_int = plane_axis_ints[0], plane_axis_ints[1]

    axes_info = [
        (sliced._axis0_name, sliced._axis0_index),
        (sliced._axis1_name, sliced._axis1_index),
        (sliced._axis2_name, sliced._axis2_index),
        (sliced._axis3_name, sliced._axis3_index),
    ]
    ax1_name, ax1_index = axes_info[ax1_int]
    ax2_name, ax2_index = axes_info[ax2_int]

    # Index 0 on the two collapsed axes leaves a [ax1, ax2] array.
    selector = tuple(
        slice(None) if i in (ax1_int, ax2_int) else 0 for i in range(4)
    )
    data_2d = sliced.value[selector]
    if data_2d.ndim > 2:
        data_2d = data_2d.squeeze()

    data_2d = select_value(data_2d, mode=mode)
    if hasattr(data_2d, "value"):
        data_2d = data_2d.value

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # Second plane axis runs horizontally, first one vertically.
    x_coords = ax2_index.value
    y_coords = ax1_index.value

    if method == "imshow":
        extent = [x_coords[0], x_coords[-1], y_coords[0], y_coords[-1]]
        im = ax.imshow(
            data_2d,
            extent=extent,
            origin="lower",
            aspect="auto",
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
            **kwargs,
        )
    elif method == "pcolormesh":
        im = ax.pcolormesh(
            x_coords, y_coords, data_2d, vmin=vmin, vmax=vmax, cmap=cmap, **kwargs
        )
    else:
        raise ValueError(
            f"Unknown method '{method}'. Use 'pcolormesh' or 'imshow'."
        )

    ax.set_xlabel(f"{ax2_name} [{ax2_index.unit}]")
    ax.set_ylabel(f"{ax1_name} [{ax1_index.unit}]")
    if title:
        ax.set_title(title)

    if add_colorbar:
        cbar = fig.colorbar(im, ax=ax)
        if self.unit is not None:
            cbar.set_label(f"{mode} [{self.unit}]")

    return fig, ax
[docs]
def plot_timeseries_points(
    self,
    points,
    labels=None,
    interp="nearest",
    ax=None,
    legend=True,
    **kwargs,
):
    """Plot time series extracted at specified spatial points.

    The series are obtained from ``self.extract_points(points,
    interp=interp)`` and drawn one line per point.

    Parameters
    ----------
    points : list of tuple
        List of (x, y, z) coordinates.
    labels : list of str, optional
        Labels for each time series. If None (or empty), each series'
        ``name`` is used.
    interp : str, optional
        Interpolation method. Default is 'nearest'.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, a new figure is created.
    legend : bool, optional
        Whether to show legend. Default is True.
    **kwargs
        Additional arguments passed to plot.

    Returns
    -------
    tuple
        (fig, ax): The matplotlib figure and axes objects.

    Examples
    --------
    >>> points = [(1.0 * u.m, 2.0 * u.m, 3.0 * u.m)]
    >>> fig, ax = field.plot_timeseries_points(points)
    """
    import matplotlib.pyplot as plt

    series = self.extract_points(points, interp=interp)

    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots()

    for position, ts in enumerate(series):
        if labels:
            line_label = labels[position]
        else:
            line_label = ts.name
        ax.plot(ts.times.value, ts.value, label=line_label, **kwargs)

    # Axis labels: axis-0 name with its unit, data unit on the y axis.
    x_unit = getattr(self._axis0_index, "unit", u.dimensionless_unscaled)
    ax.set_xlabel(f"{self._axis0_name} [{x_unit}]")
    if self.unit is not None:
        ax.set_ylabel(f"[{self.unit}]")

    if legend:
        ax.legend()

    return fig, ax
[docs]
def plot_profile(
    self,
    axis,
    at,
    mode="real",
    ax=None,
    label=None,
    **kwargs,
):
    """Plot a 1D profile along a specified axis.

    The profile comes from ``self.extract_profile(axis, at)`` and is
    reduced to one component with ``select_value``.

    Parameters
    ----------
    axis : str
        Axis name to plot along.
    at : dict
        Dictionary specifying fixed values for other axes.
    mode : str, optional
        Component to extract: 'real', 'imag', 'abs', 'angle', 'power'.
        Default is 'real'.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, a new figure is created.
    label : str, optional
        Line label; a legend is drawn only when a label is given.
    **kwargs
        Additional arguments passed to plot.

    Returns
    -------
    tuple
        (fig, ax): The matplotlib figure and axes objects.

    Examples
    --------
    >>> fig, ax = field.plot_profile(
    ...     'x', at={'t': 0.5 * u.s, 'y': 0.0 * u.m, 'z': 0.0 * u.m}
    ... )
    """
    import matplotlib.pyplot as plt

    from gwexpy.plot._coord import select_value

    axis_index, values = self.extract_profile(axis, at)
    values = select_value(values, mode=mode)

    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots()

    # Strip units for plotting; keep them for the axis labels.
    y_data = getattr(values, "value", values)
    y_unit = getattr(values, "unit", None)
    x_data = getattr(axis_index, "value", axis_index)
    x_unit = getattr(axis_index, "unit", u.dimensionless_unscaled)

    ax.plot(x_data, y_data, label=label, **kwargs)

    ax.set_xlabel(f"{axis} [{x_unit}]")
    if y_unit is not None:
        ax.set_ylabel(f"{mode} [{y_unit}]")

    if label:
        ax.legend()

    return fig, ax
# =========================================================================
# Comparison & Summary Methods (Phase 2)
# =========================================================================
[docs]
def diff(self, other, mode="diff"):
    """Compute difference or ratio between two ScalarField objects.

    When ``other`` carries a unit different from ``self.unit``, its values
    are converted to ``self.unit`` before the comparison so that, e.g.,
    metres are never subtracted from millimetres as raw numbers.

    Parameters
    ----------
    other : ScalarField
        The field to compare against.
    mode : str, optional
        Comparison mode:

        - 'diff': Difference (self - other)
        - 'ratio': Ratio (self / other)
        - 'percent': Percentage difference ((self - other) / other * 100)

        Default is 'diff'.

    Returns
    -------
    ScalarField
        Result field on the axes of ``self``. For 'diff' the unit is
        ``self.unit``; for 'ratio' it is dimensionless; for 'percent' it
        is ``astropy.units.percent``.

    Raises
    ------
    ValueError
        If shapes differ, if ``mode`` is not recognized, or if the unit of
        ``other`` cannot be converted to the unit of ``self`` (astropy's
        unit-conversion error is a ``ValueError`` subclass).

    Examples
    --------
    >>> diff_field = field1.diff(field2)
    >>> ratio_field = field1.diff(field2, mode='ratio')
    """
    if self.shape != other.shape:
        raise ValueError(f"Shape mismatch: {self.shape} vs {other.shape}")

    valid_modes = ("diff", "ratio", "percent")
    if mode not in valid_modes:
        raise ValueError(f"Invalid mode '{mode}'. Must be one of {valid_modes}.")

    lhs = self.value
    rhs = other.value

    # Bring `other` onto the unit of `self`; operating on bare values of
    # mismatched units would silently give a wrongly scaled result.
    other_unit = getattr(other, "unit", None)
    if (
        self.unit is not None
        and other_unit is not None
        and other_unit != self.unit
    ):
        rhs = u.Quantity(rhs, other_unit).to_value(self.unit)

    if mode == "diff":
        result_data = lhs - rhs
        result_unit = self.unit
    else:
        # Division by zero is allowed to yield inf/nan without warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            if mode == "ratio":
                result_data = lhs / rhs
                result_unit = u.dimensionless_unscaled
            else:  # mode == "percent"
                result_data = (lhs - rhs) / rhs * 100
                result_unit = u.percent

    result = ScalarField(
        result_data,
        unit=result_unit,
        axis0=self._axis0_index,
        axis1=self._axis1_index,
        axis2=self._axis2_index,
        axis3=self._axis3_index,
        axis_names=list(self.axis_names),
        axis0_domain=self._axis0_domain,
        space_domain=self._space_domains,
    )
    self._propagate_gwex_attrs(result)
    return result
[docs]
def zscore(self, baseline_t=None):
    """Normalize the field to z-scores relative to a baseline window.

    Each spatial point is transformed as ``(data - mean) / std`` where the
    mean and standard deviation are taken along axis 0 over the baseline
    samples. Non-finite results (for example where the baseline standard
    deviation is zero) are replaced by ``0.0``.

    Parameters
    ----------
    baseline_t : tuple of Quantity, optional
        ``(t_start, t_end)`` delimiting the baseline. Both ends are mapped
        to sample indices with ``nearest_index`` and the range is
        inclusive. When omitted the whole time axis is the baseline.

    Returns
    -------
    ScalarField
        Dimensionless field with the same axes as the input.

    Raises
    ------
    ValueError
        If ``axis0_domain`` is not 'time'.

    Examples
    --------
    >>> from astropy import units as u
    >>> zscore_field = field.zscore(baseline_t=(0 * u.s, 1 * u.s))
    """
    from astropy import units as u
    from gwexpy.plot._coord import nearest_index

    if self._axis0_domain != "time":
        raise ValueError(
            f"zscore requires axis0_domain='time', got '{self._axis0_domain}'"
        )

    samples = self.value
    if baseline_t is None:
        reference = samples
    else:
        t_start, t_end = baseline_t
        first = nearest_index(self._axis0_index, t_start)
        last = nearest_index(self._axis0_index, t_end)
        # Accept the bounds in either order.
        if first > last:
            first, last = last, first
        reference = samples[first : last + 1, ...]

    # Per-point baseline statistics, kept broadcastable against the data.
    mu = np.mean(reference, axis=0, keepdims=True)
    sigma = np.std(reference, axis=0, keepdims=True)

    with np.errstate(divide="ignore", invalid="ignore"):
        normalized = (samples - mu) / sigma
    normalized = np.nan_to_num(normalized, nan=0.0, posinf=0.0, neginf=0.0)

    field_kwargs = {
        "unit": u.dimensionless_unscaled,
        "axis0": self._axis0_index,
        "axis1": self._axis1_index,
        "axis2": self._axis2_index,
        "axis3": self._axis3_index,
        "axis_names": list(self.axis_names),
        "axis0_domain": self._axis0_domain,
        "space_domain": self._space_domains,
    }
    result = ScalarField(normalized, **field_kwargs)
    self._propagate_gwex_attrs(result)
    return result
[docs]
def time_stat_map(self, stat="mean", t_range=None, plane="xy", at=None):
    """Compute a time-aggregated 2D map.

    Parameters
    ----------
    stat : str, optional
        Statistic to compute: 'mean', 'std', 'rms', 'max', 'min'.
        Default is 'mean'. 'rms' is ``sqrt(mean(|data|**2))`` so that it
        is real-valued for complex fields as well.
    t_range : tuple of Quantity, optional
        Time range (t_start, t_end) to aggregate over. Both ends are
        mapped to sample indices with ``nearest_index`` and the range is
        inclusive. If None, uses the entire time axis.
    plane : str, optional
        Spatial plane to visualize: 'xy', 'xz', 'yz'.
        Default is 'xy'. Only used when ``at`` is given.
    at : dict, optional
        Fixed values for axes not in the plane. When given, the result is
        further reduced with ``slice_map2d(plane=plane, at=at)``.

    Returns
    -------
    ScalarField
        Result field with time axis reduced to length=1. Its single axis-0
        coordinate is the midpoint of ``t_range`` (or of the first and
        last time samples when ``t_range`` is None).

    Raises
    ------
    ValueError
        If stat is not recognized.

    Examples
    --------
    >>> from astropy import units as u
    >>> mean_map = field.time_stat_map('mean', t_range=(0 * u.s, 1 * u.s))
    >>> mean_map.plot_map2d('xy')
    """
    valid_stats = ("mean", "std", "rms", "max", "min")
    if stat not in valid_stats:
        raise ValueError(f"Invalid stat '{stat}'. Must be one of {valid_stats}.")

    reducers = {
        "mean": np.mean,
        "std": np.std,
        # |x|**2 rather than x**2: identical for real data, and the only
        # meaningful definition of RMS for complex data.
        "rms": lambda data, **kw: np.sqrt(np.mean(np.abs(data) ** 2, **kw)),
        "max": np.max,
        "min": np.min,
    }

    # Select the time window to aggregate over.
    if t_range is not None:
        from gwexpy.plot._coord import nearest_index

        t_start, t_end = t_range
        i_start = nearest_index(self._axis0_index, t_start)
        i_end = nearest_index(self._axis0_index, t_end)
        if i_start > i_end:
            i_start, i_end = i_end, i_start
        subset = self[i_start : i_end + 1, :, :, :]
        mean_time = (t_start + t_end) / 2
    else:
        subset = self
        mean_time = (self._axis0_index[0] + self._axis0_index[-1]) / 2

    # keepdims preserves the 4D layout with a length-1 time axis.
    result_data = reducers[stat](subset.value, axis=0, keepdims=True)

    result = ScalarField(
        result_data,
        unit=self.unit,
        axis0=np.array([mean_time.value]) * mean_time.unit,
        axis1=subset._axis1_index,
        axis2=subset._axis2_index,
        axis3=subset._axis3_index,
        axis_names=list(self.axis_names),
        axis0_domain=self._axis0_domain,
        space_domain=self._space_domains,
    )
    self._propagate_gwex_attrs(result)

    # Optionally reduce further to the requested spatial plane.
    if at is not None:
        result = result.slice_map2d(plane=plane, at=at)

    return result
[docs]
def time_space_map(self, axis="x", at=None, mode="real", reduce=None):
    """Extract a 2D time-space map (t vs one spatial axis).

    Parameters
    ----------
    axis : str, optional
        Spatial axis name ('x', 'y', 'z' or k-variants).
        Default is 'x'.
    at : dict, optional
        Fixed values for the other two spatial axes, keyed by axis name.
        An axis of length 1 may be omitted.
    mode : str, optional
        Component to extract: 'real', 'imag', 'abs', 'angle', 'power'.
        Default is 'real'.
    reduce : None
        Reserved for future averaging support. Currently ignored.

    Returns
    -------
    tuple
        (t_axis, space_axis, data_2d): the axis-0 index, the index of the
        selected spatial axis, and a unit-less array of shape
        ``(len(t_axis), len(space_axis))``. The array stays 2D even when
        either of those axes has length 1.

    Raises
    ------
    ValueError
        If ``axis`` resolves to axis 0.
    ValueError
        If a non-selected spatial axis has length > 1 and is missing
        from ``at``.

    Examples
    --------
    >>> from astropy import units as u
    >>> t, x, data = field.time_space_map('x', at={'y': 0 * u.m, 'z': 0 * u.m})
    """
    axis_int = self._get_axis_index(axis)
    if axis_int == 0:
        raise ValueError("Cannot use time axis as spatial axis for time_space_map")

    from gwexpy.plot._coord import nearest_index, select_value, slice_from_index

    axis_names = (
        self._axis0_name,
        self._axis1_name,
        self._axis2_name,
        self._axis3_name,
    )
    axis_indices = (
        self._axis0_index,
        self._axis1_index,
        self._axis2_index,
        self._axis3_index,
    )

    if at is None:
        at = {}

    # Axis 0 and the selected axis are kept whole; the other two spatial
    # axes are each pinned to a single sample.
    slices = [slice(None)] * 4
    fixed_axes = []
    for ax_int in (1, 2, 3):
        if ax_int == axis_int:
            continue
        fixed_axes.append(ax_int)
        ax_name = axis_names[ax_int]
        ax_index = axis_indices[ax_int]
        if ax_name in at:
            idx = nearest_index(ax_index, at[ax_name])
            slices[ax_int] = slice_from_index(idx)
        elif len(ax_index) == 1:
            slices[ax_int] = slice(0, 1)
        else:
            raise ValueError(
                f"Axis '{ax_name}' has length {len(ax_index)} > 1. "
                f"Specify its value in 'at'."
            )

    sliced = self[tuple(slices)]

    # Drop only the two pinned axes. A bare squeeze() would also remove a
    # length-1 time or scan axis and hand back a 1D/0D array.
    data_2d = np.squeeze(sliced.value, axis=tuple(fixed_axes))

    data_2d = select_value(data_2d, mode=mode)
    if hasattr(data_2d, "value"):
        data_2d = data_2d.value

    t_axis = self._axis0_index
    space_axis = axis_indices[axis_int]

    return t_axis, space_axis, data_2d
[docs]
def plot_time_space_map(
    self,
    axis="x",
    at=None,
    mode="real",
    method="pcolormesh",
    ax=None,
    add_colorbar=True,
    vmin=None,
    vmax=None,
    title=None,
    cmap=None,
    **kwargs,
):
    """Plot a 2D time-space map (t vs one spatial axis).

    The data come from :meth:`time_space_map`; the spatial axis is drawn
    horizontally and axis 0 vertically.

    Parameters
    ----------
    axis : str, optional
        Spatial axis name. Default is 'x'.
    at : dict, optional
        Fixed values for other spatial axes.
    mode : str, optional
        Component to extract. Default is 'real'.
    method : {'pcolormesh', 'imshow'}, optional
        Plot method. Default is 'pcolormesh'.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. A new figure is created when omitted.
    add_colorbar : bool, optional
        Whether to add colorbar. Default is True.
    vmin, vmax : float, optional
        Color scale limits.
    title : str, optional
        Plot title.
    cmap : str or Colormap, optional
        Colormap to use.
    **kwargs
        Additional arguments forwarded to the selected plot call.

    Returns
    -------
    tuple
        (fig, ax): The matplotlib figure and axes objects.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported plot methods.

    Examples
    --------
    >>> fig, ax = field.plot_time_space_map('x', at={'y': 0*u.m, 'z': 0*u.m})
    """
    # Reject a bad method up front so no figure is created (and left
    # open) for a call that cannot succeed.
    if method not in ("pcolormesh", "imshow"):
        raise ValueError(f"Unknown method '{method}'.")

    import matplotlib.pyplot as plt

    t_axis, space_axis, data_2d = self.time_space_map(axis, at=at, mode=mode)

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # Tolerate unit-less index arrays, as the other plotting methods do.
    x_values = getattr(space_axis, "value", space_axis)
    t_values = getattr(t_axis, "value", t_axis)
    x_unit = getattr(space_axis, "unit", u.dimensionless_unscaled)
    t_unit = getattr(t_axis, "unit", u.dimensionless_unscaled)

    if method == "pcolormesh":
        im = ax.pcolormesh(
            x_values,
            t_values,
            data_2d,
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
            **kwargs,
        )
    else:  # method == "imshow"
        extent = [x_values[0], x_values[-1], t_values[0], t_values[-1]]
        im = ax.imshow(
            data_2d,
            extent=extent,
            origin="lower",
            aspect="auto",
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
            **kwargs,
        )

    ax.set_xlabel(f"{axis} [{x_unit}]")
    ax.set_ylabel(f"{self._axis0_name} [{t_unit}]")
    if title:
        ax.set_title(title)

    if add_colorbar:
        cbar = fig.colorbar(im, ax=ax)
        if self.unit is not None:
            cbar.set_label(f"{mode} [{self.unit}]")

    return fig, ax
# =========================================================================
# Signal Processing Methods (Phase 3)
# =========================================================================
[docs]
def compute_psd(self, point_or_region, **kwargs):
    """Estimate the power spectral density at selected spatial location(s).

    Delegates to :func:`~gwexpy.fields.signal.compute_psd` (Welch's
    method) with this field as the first argument.

    Parameters
    ----------
    point_or_region : tuple, list of tuples, or dict
        Where to evaluate the PSD:

        - a single ``(x, y, z)`` tuple of Quantities,
        - a list of such tuples for several points,
        - a dict such as ``{'x': slice or value, 'y': ..., 'z': ...}``
          describing a region.
    **kwargs
        Passed through unchanged, e.g. ``nperseg``, ``noverlap``,
        ``window``, ``detrend``, ``scaling``, ``average``.

    Returns
    -------
    FrequencySeries or FrequencySeriesList
        PSD estimate(s).

    See Also
    --------
    gwexpy.fields.signal.compute_psd : Full documentation.
    """
    from .signal import compute_psd as _compute_psd

    estimate = _compute_psd(self, point_or_region, **kwargs)
    return estimate
[docs]
def freq_space_map(self, axis, at=None, **kwargs):
    """Build a frequency-versus-position map along one spatial axis.

    Delegates to :func:`~gwexpy.fields.signal.freq_space_map` with this
    field as the first argument.

    Parameters
    ----------
    axis : str
        Spatial axis to scan along ('x', 'y', or 'z').
    at : dict, optional
        Fixed values for the other two spatial axes.
    **kwargs
        Passed through unchanged to the underlying function.

    Returns
    -------
    ScalarField
        2D frequency-space map.

    See Also
    --------
    gwexpy.fields.signal.freq_space_map : Full documentation.
    """
    from .signal import freq_space_map as _freq_space_map

    fs_map = _freq_space_map(self, axis, at=at, **kwargs)
    return fs_map
[docs]
def resample(self, rate, **kwargs) -> ScalarField:
    """Resample the field along the time axis (axis 0).

    The data are flattened to ``(time, points)`` and resampled with
    :func:`scipy.signal.resample`. The sample spacing is taken from the
    first two entries of the axis-0 index, and the new time axis starts
    at the original first sample.

    Parameters
    ----------
    rate : float, Quantity
        The new sampling rate. A Quantity is converted to Hz; a bare
        number is interpreted as samples per unit of the time axis.
    **kwargs
        Accepted for call compatibility with
        :meth:`gwpy.timeseries.TimeSeries.resample`; currently not
        forwarded to the underlying resampler.

    Returns
    -------
    ScalarField
        Resampled field.

    Raises
    ------
    ValueError
        If ``axis0_domain`` is not 'time', if the time axis has fewer
        than two samples (no spacing can be derived), if ``rate`` is not
        a positive finite number, or if the requested rate would leave
        no output samples.
    """
    if self._axis0_domain != "time":
        raise ValueError("resample requires axis0_domain='time'")

    orig_shape = self.shape
    n_samples = orig_shape[0]
    if n_samples < 2:
        raise ValueError(
            "resample requires at least 2 samples along axis 0 "
            f"to determine the sample spacing, got {n_samples}"
        )

    rate_is_quantity = hasattr(rate, "to")
    new_fs = float(rate.to("Hz").value) if rate_is_quantity else float(rate)
    if not (np.isfinite(new_fs) and new_fs > 0):
        raise ValueError(f"Sampling rate must be positive and finite, got {rate!r}")

    dt = self._axis0_index[1] - self._axis0_index[0]
    # A Quantity rate is expressed in Hz, so times are in seconds; a bare
    # number is taken relative to the unit of the existing time axis.
    time_unit = u.s if rate_is_quantity else dt.unit

    duration = (n_samples * dt).to(time_unit).value
    new_nt = int(round(duration * new_fs))
    if new_nt < 1:
        raise ValueError(
            f"Sampling rate {rate!r} is too low: it yields no output samples"
        )

    import scipy.signal

    # Collapse the spatial axes so the resampler sees a 2D array.
    data_2d = self.value.reshape(n_samples, -1)
    new_data_2d = scipy.signal.resample(data_2d, new_nt, axis=0)
    new_times = np.arange(new_nt) * (1.0 / new_fs) * time_unit + self._axis0_index[0]

    new_shape = [new_nt] + list(orig_shape[1:])
    result = ScalarField(
        new_data_2d.reshape(new_shape),
        unit=self.unit,
        axis0=new_times,
        axis1=self._axis1_index,
        axis2=self._axis2_index,
        axis3=self._axis3_index,
        axis_names=self.axis_names,
        axis0_domain="time",
        space_domain=self._space_domains,
    )
    self._propagate_gwex_attrs(result)
    return result
[docs]
def filter(self, *args, **kwargs) -> ScalarField:
    """Apply a filter along the time axis (axis 0).

    The data are flattened to ``(time, points)``, the filter
    specification is turned into a digital filter with
    ``gwpy.signal.filter_design`` and applied along axis 0. The sample
    rate is derived from the first two entries of the axis-0 index.

    Parameters
    ----------
    *args
        Filter specification, in the forms accepted by
        :meth:`gwpy.timeseries.TimeSeries.filter`. A single positional
        argument is used as-is; several are passed on as a tuple.
    analog : bool, optional
        Whether the specification describes an analog filter.
        Default False.
    unit : str, optional
        Unit of the filter specification; only used when
        ``filter_design.prepare_digital_filter`` is available.
        Default 'rad/s'.
    normalize_gain : bool, optional
        Only used when ``filter_design.prepare_digital_filter`` is
        available. Default False.
    filtfilt : bool, optional
        Apply the filter forward and backward (zero phase). Default True.
    **kwargs
        Any other keyword arguments are currently ignored.

    Returns
    -------
    ScalarField
        Filtered field on the same axes.

    Raises
    ------
    ValueError
        If ``axis0_domain`` is not 'time', or if the time axis has fewer
        than two samples (no sample rate can be derived).
    TypeError
        If no filter specification is given.
    """
    if self._axis0_domain != "time":
        raise ValueError("filter requires axis0_domain='time'")
    if not args:
        raise TypeError(
            "filter() requires a filter specification as positional argument(s)"
        )

    orig_shape = self.shape
    if orig_shape[0] < 2:
        raise ValueError(
            "filter requires at least 2 samples along axis 0 "
            f"to determine the sample rate, got {orig_shape[0]}"
        )

    # Collapse the spatial axes so the filter sees a 2D array.
    data_2d = self.value.reshape(orig_shape[0], -1)
    fs = (1.0 / (self._axis0_index[1] - self._axis0_index[0])).to("Hz").value

    from gwpy.signal import filter_design
    from scipy import signal as scipy_signal

    analog = kwargs.pop("analog", False)
    unit = kwargs.pop("unit", "rad/s")
    normalize_gain = kwargs.pop("normalize_gain", False)
    filtfilt = kwargs.pop("filtfilt", True)  # Default to True for phase consistency

    filt = args[0] if len(args) == 1 else args

    # Feature-detect the newer GWpy helper instead of wrapping the whole
    # call in ``except AttributeError``, which would also swallow genuine
    # AttributeErrors raised while the filter is being prepared.
    prepare = getattr(filter_design, "prepare_digital_filter", None)
    if prepare is not None:
        # GWpy >= 4.0
        filt_obj = prepare(
            filt,
            analog=analog,
            sample_rate=fs,
            unit=unit,
            normalize_gain=normalize_gain,
            output="zpk",
        )
        form = "zpk"
    else:
        # Fallback for GWpy < 4.0
        form, filt_obj = filter_design.parse_filter(
            filt,
            analog=analog,
            sample_rate=fs,
        )

    # Apply filter along axis 0
    if form == "zpk":
        sos = scipy_signal.zpk2sos(*filt_obj)
        if filtfilt:
            new_data_2d = scipy_signal.sosfiltfilt(sos, data_2d, axis=0)
        else:
            new_data_2d = scipy_signal.sosfilt(sos, data_2d, axis=0)
    else:
        b, a = filt_obj
        if filtfilt:
            new_data_2d = scipy_signal.filtfilt(b, a, data_2d, axis=0)
        else:
            new_data_2d = scipy_signal.lfilter(b, a, data_2d, axis=0)

    result = ScalarField(
        new_data_2d.reshape(orig_shape),
        unit=self.unit,
        axis0=self._axis0_index,
        axis1=self._axis1_index,
        axis2=self._axis2_index,
        axis3=self._axis3_index,
        axis_names=self.axis_names,
        axis0_domain="time",
        space_domain=self._space_domains,
    )
    self._propagate_gwex_attrs(result)
    return result
[docs]
def compute_xcorr(self, point_a, point_b, **kwargs):
    """Cross-correlate the time series at two spatial points.

    Delegates to :func:`~gwexpy.fields.signal.compute_xcorr` with this
    field as the first argument.

    Parameters
    ----------
    point_a, point_b : tuple of Quantity
        Spatial coordinates (x, y, z) for the two points.
    **kwargs
        Passed through unchanged to the underlying function.

    Returns
    -------
    TimeSeries
        Cross-correlation function with lag axis.

    See Also
    --------
    gwexpy.fields.signal.compute_xcorr : Full documentation.
    """
    from .signal import compute_xcorr as _compute_xcorr

    xcorr = _compute_xcorr(self, point_a, point_b, **kwargs)
    return xcorr
[docs]
def time_delay_map(self, ref_point, plane="xy", at=None, **kwargs):
    """Map the time delay relative to a reference point over a plane.

    Delegates to :func:`~gwexpy.fields.signal.time_delay_map` with this
    field as the first argument.

    Parameters
    ----------
    ref_point : tuple of Quantity
        Reference point coordinates (x, y, z).
    plane : str
        2D plane to map: 'xy', 'xz', or 'yz'. Default 'xy'.
    at : dict, optional
        Fixed value for the axis not in the plane.
    **kwargs
        Passed through unchanged to the underlying function.

    Returns
    -------
    ScalarField
        Time delay map.

    See Also
    --------
    gwexpy.fields.signal.time_delay_map : Full documentation.
    """
    from .signal import time_delay_map as _time_delay_map

    delay_map = _time_delay_map(self, ref_point, plane=plane, at=at, **kwargs)
    return delay_map
[docs]
def coherence_map(self, ref_point, plane="xy", at=None, **kwargs):
    """Map the coherence with a reference point over a plane.

    Delegates to :func:`~gwexpy.fields.signal.coherence_map` with this
    field as the first argument.

    Parameters
    ----------
    ref_point : tuple of Quantity
        Reference point coordinates (x, y, z).
    plane : str
        2D plane to map: 'xy', 'xz', or 'yz'. Default 'xy'.
    at : dict, optional
        Fixed value for the axis not in the plane.
    **kwargs
        Passed through unchanged to the underlying function.

    Returns
    -------
    ScalarField or FieldDict
        Coherence map.

    See Also
    --------
    gwexpy.fields.signal.coherence_map : Full documentation.
    """
    from .signal import coherence_map as _coherence_map

    coh_map = _coherence_map(self, ref_point, plane=plane, at=at, **kwargs)
    return coh_map
# =========================================================================
# Spectral Density (Phase 2)
# =========================================================================
[docs]
def spectral_density(self, axis=0, **kwargs):
    """Estimate the spectral density along one axis of the field.

    Works on the time axis (0) as well as on the spatial axes (1-3) by
    delegating to :func:`~gwexpy.fields.signal.spectral_density` with this
    field as the first argument.

    Parameters
    ----------
    axis : int or str
        Axis to transform. Default 0 (time axis).
    **kwargs
        Passed through unchanged to
        :func:`~gwexpy.fields.signal.spectral_density`; see that function
        for the full parameter list.

    Returns
    -------
    ScalarField
        Spectral density field with the chosen axis in the spectral
        domain.

    Examples
    --------
    >>> # Time PSD
    >>> psd_field = field.spectral_density(axis=0)
    >>> psd_field.axis0_domain  # 'frequency'
    >>> # Spatial wavenumber spectrum
    >>> kx_spec = field.spectral_density(axis='x')

    See Also
    --------
    gwexpy.fields.signal.spectral_density : Full documentation.
    psd : Convenience alias for time-axis PSD.
    """
    from .signal import spectral_density as _spectral_density

    density = _spectral_density(self, axis=axis, **kwargs)
    return density
[docs]
def psd(self, **kwargs):
    """Compute the power spectral density along the time axis.

    Without ``point_or_region`` this is ``spectral_density(axis=0, ...)``
    over the whole field. With ``point_or_region`` the call is routed to
    :func:`~gwexpy.fields.signal.compute_psd` for the requested
    location(s) instead.

    Parameters
    ----------
    **kwargs
        Options for :func:`~gwexpy.fields.signal.spectral_density`, or for
        :func:`~gwexpy.fields.signal.compute_psd` when ``point_or_region``
        is present. Common options:

        - point_or_region : tuple or list, optional
            Spatial point(s) or region to evaluate instead of the full
            field. Yields FrequencySeries(List).
        - method : {'welch', 'fft'}, default 'welch'
        - fftlength : float, optional
            Segment length in seconds (time-based specification)
        - nfft : int, optional
            Number of samples per segment (sample-based specification)
        - overlap : float, optional
            Overlap in seconds
        - noverlap : int, optional
            Number of overlapping samples
        - window : str, window function
        - scaling : {'density', 'spectrum'}

    Returns
    -------
    ScalarField
        PSD field with axis0_domain='frequency' (full-field case).

    Examples
    --------
    >>> from astropy import units as u
    >>> # Time-based specification
    >>> psd_field = field.psd(fftlength=1.0, overlap=0.5)
    >>> # Or sample-based specification
    >>> psd_field = field.psd(nfft=512, noverlap=256)
    >>> psd_field.shape  # (n_freq, nx, ny, nz)

    See Also
    --------
    spectral_density : Generalized spectral density for any axis.
    """
    # Full-field case: plain time-axis spectral density.
    if "point_or_region" not in kwargs:
        return self.spectral_density(axis=0, **kwargs)

    from .signal import compute_psd as _compute_psd

    target = kwargs.pop("point_or_region")
    return _compute_psd(self, target, **kwargs)