from __future__ import annotations
from typing import Any, cast
import numpy as np
from astropy import units as u
from gwexpy.types.metadata import MetaDataDict, MetaDataMatrix
from gwexpy.types.mixin import PhaseMethodsMixin
from gwexpy.types.seriesmatrix import SeriesMatrix
from .collections import SpectrogramDict, SpectrogramList
from .matrix_analysis import SpectrogramMatrixAnalysisMixin
from .matrix_core import SpectrogramMatrixCoreMixin
from .spectrogram import Spectrogram
[docs]
class SpectrogramMatrix( # type: ignore[misc]
PhaseMethodsMixin,
SpectrogramMatrixCoreMixin,
SpectrogramMatrixAnalysisMixin,
SeriesMatrix,
):
"""
Evaluation Matrix for Spectrograms (Time-Frequency maps).
This class represents a collection of Spectrograms, structured as either:
- 3D: (Batch, Time, Frequency)
- 4D: (Row, Col, Time, Frequency)
It inherits from SeriesMatrix, providing powerful indexing, metadata management,
and analysis capabilities (slicing, interpolation, statistics).
Serialization
-------------
Pickle round-trips are supported via a custom ``__reduce__``/``__setstate__``
that appends ``__dict__`` to the ndarray state. This preserves axis metadata
such as ``times``/``frequencies``, ``rows``/``cols``, and ``meta`` as long as
they live in ``__dict__``. Attributes stored elsewhere or pointing to external
resources still require higher-level I/O (e.g., HDF5) for full fidelity.
"""
series_class = Spectrogram
dict_class = SpectrogramDict
list_class = SpectrogramList
def __new__(
cls,
data,
times=None,
frequencies=None,
unit=None,
name=None,
rows=None,
cols=None,
meta=None,
**kwargs,
):
# Handle alias
if times is None:
times = kwargs.get("xindex")
# SeriesMatrix expects 'xindex' and 'xunit' etc.
# We assume 'data' might be (N, M, Time, Freq) or (N, Time, Freq).
# We pass xindex=times.
# We first let SeriesMatrix handle normalization of N, M and MetaDataMatrix.
# However, SeriesMatrix.__new__ behavior for 4D/nD data depends on validation.
# SeriesMatrix validation usually assumes 3D.
# We may need to bypass or adjust SeriesMatrix.__new__ validation for 4D data if it's too strict.
# For now, we try to perform basic setup and call super().__new__ via np.ndarray mechanism
# but SeriesMatrix does a lot of heavy lifting in __new__.
# Strategy: adapt arguments to SeriesMatrix signature
# times -> xindex
# Note: SeriesMatrix input normalization might flatten extra dims if not careful.
# Check gwexpy/types/series_matrix_validation.py: _normalize_input handles 3D.
# For 4D specific handling, we might need to manually prep or rely on SeriesMatrix letting it pass?
# Actually SeriesMatrixValidationMixin _normalize_input mainly handles 1D, 2D, 3D.
# If data is 4D, SeriesMatrix _normalize_input might fail or treat it oddly.
# Let's verify _normalize_input logic (Step 1167).
# It has blocks for Scalar, Series, Array, 1D/2D, 3D. It does NOT explicitly handle 4D.
# So we might need to override behavior or pre-process data to be SeriesMatrix-compatible (stored as object array?)
# NO, we want numeric array.
# If data is 4D (N, M, T, F), SeriesMatrix assumes (Row, Col, Sample).
# If we want to use SeriesMatrix infrastructure, we must respect the 3-axis structure `(Row, Col, X)`?
# Integrating 4D directly into SeriesMatrix (nd=4) might break many assumptions in `series_matrix_core` (e.g. shape3D return).
# ALTERNATIVE: Use Object Array of Spectrograms? No, expensive.
# ALTERNATIVE: Treat Freq axis as implicit?
# If we invoke SeriesMatrix, it calls `_normalize_input`.
# If we just call `np.array(data).view(cls)`, we bypass SeriesMatrix.__new__ logic entirely?
# But we want mixins.
# Since we inherit SeriesMatrix, calling SeriesMatrix(data...) creates a new object using SeriesMatrix.__new__.
# Let's implement a custom __new__ that handles the 4D init, sets properties, and returns the view,
# mimicking SeriesMatrix.__new__ but tailored for 4D.
# ... Wait, if we inherit SeriesMatrix, `super()` refers to SeriesMatrix.
# If we don't call `super().__new__`, we skip its logic. That's fine if we replicate what we need.
obj = np.asarray(data).view(cls)
# Set Spectrogram-specific props
obj.times = times # sets xindex via CoreMixin
obj.frequencies = frequencies
# Set metadata manually or via helpers if available.
# Only do basic setup here to replicate old behavior + SeriesMatrix props
obj.name = name
obj.unit = unit # logic for unit array vs scalar unit needed?
# Setup MetaDataMatrix using rows/cols logic from previous implementation
def _entries_len(entries):
return len(entries) if entries is not None else None
if obj.ndim == 3: # (Batch, Time, Freq)
N = obj.shape[0]
# ... (same logic as before for rows/cols) ...
# Simplify for brevity or reuse logic?
row_len = _entries_len(rows)
col_len = _entries_len(cols)
use_grid = row_len and col_len and row_len * col_len == N
if use_grid:
obj.rows = MetaDataDict(rows, expected_size=row_len, key_prefix="row")
obj.cols = MetaDataDict(cols, expected_size=col_len, key_prefix="col")
obj.meta = MetaDataMatrix(meta, shape=(row_len, col_len))
else:
obj.rows = MetaDataDict(rows, expected_size=N, key_prefix="batch")
obj.cols = MetaDataDict(None, expected_size=1, key_prefix="col")
obj.meta = MetaDataMatrix(meta, shape=(N, 1))
elif obj.ndim == 4: # (Row, Col, Time, Freq)
nrow, ncol = obj.shape[:2]
obj.rows = MetaDataDict(rows, expected_size=nrow, key_prefix="row")
obj.cols = MetaDataDict(cols, expected_size=ncol, key_prefix="col")
obj.meta = MetaDataMatrix(meta, shape=(nrow, ncol))
else:
# Fallback
obj.rows = None # type: ignore[assignment]
obj.cols = None # type: ignore[assignment]
obj.meta = None # type: ignore[assignment]
# Apply unit to metadata if needed (only if not explicitly set in meta)
if unit is not None and obj.meta is not None:
for m in obj.meta.reshape(-1):
# MetaData defaults to dimensionless_unscaled, so check for that too
if m.unit is None or m.unit == u.dimensionless_unscaled:
m.unit = unit
# If no global unit was provided, infer it from metadata if consistent
if obj.unit is None and obj.meta is not None:
meta_units = {m.unit for m in obj.meta.reshape(-1) if m is not None}
if len(meta_units) == 1:
obj.unit = next(iter(meta_units))
obj.epoch = kwargs.get("epoch", 0.0)
obj._value = obj.view(np.ndarray)
return obj
def __array_finalize__(self, obj: Any) -> None:
if obj is None:
return
super().__array_finalize__(obj)
self.frequencies = getattr(obj, "frequencies", None)
# Propagate custom attributes (similar to TimeSeriesCore)
for key in getattr(obj, "__dict__", {}):
if key.startswith("_gwex_"):
setattr(self, key, getattr(obj, key))
if not hasattr(self, "_value"):
self._value = self.view(np.ndarray)
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
"""
Override SeriesMatrix.__array_ufunc__ to correctly handle SpectrogramMatrix structure
(Batch, Time, Freq) or (Row, Col, Time, Freq).
Per-element units are preserved in MetaDataMatrix:
- Scalar operations: apply ufunc to each element's unit individually
- Binary matrix operations: check per-element unit compatibility and raise
UnitConversionError if any pair is incompatible
"""
from gwexpy.types.metadata import MetaData, MetaDataMatrix
if method != "__call__":
# Defer to ndarray (e.g. at, reduce) - might lose metadata but SeriesMatrix does too
args = [
inp.view(np.ndarray) if isinstance(inp, SpectrogramMatrix) else inp
for inp in inputs
]
return super(SeriesMatrix, self).__array_ufunc__(
ufunc, method, *args, **kwargs
)
# Identify ufunc category for unit handling
_ADD_SUB_UFUNCS = {np.add, np.subtract}
_COMPARISON_UFUNCS = {
np.less,
np.less_equal,
np.equal,
np.not_equal,
np.greater,
np.greater_equal,
}
_MUL_DIV_UFUNCS = {np.multiply, np.divide, np.floor_divide, np.true_divide}
# 1. Unpack inputs
args = []
sgm_inputs = [] # SpectrogramMatrix instances
scalar_inputs = [] # Scalars/units for unit arithmetic
for inp in inputs:
if isinstance(inp, SpectrogramMatrix):
args.append(inp.view(np.ndarray))
sgm_inputs.append(inp)
elif isinstance(inp, (u.Quantity, np.ndarray, float, int, complex)):
val = getattr(inp, "value", inp)
args.append(np.asarray(val))
scalar_inputs.append(inp)
elif isinstance(inp, u.UnitBase):
args.append(1.0) # Unit acts as multiplier
scalar_inputs.append(inp)
else:
return NotImplemented
if not sgm_inputs:
return NotImplemented
main = sgm_inputs[0]
# 2. Compute Data
try:
result_data = ufunc(*args, **kwargs)
except (TypeError, ValueError, u.UnitConversionError):
return NotImplemented
# 3. Handle per-element unit propagation
# Determine if this is a scalar op (1 matrix) or binary matrix op (2+ matrices)
is_scalar_op = len(sgm_inputs) == 1
is_binary_matrix_op = len(sgm_inputs) >= 2
new_meta = None
if main.meta is not None:
meta_shape = main.meta.shape
new_meta_arr = np.empty(meta_shape, dtype=object)
if is_scalar_op:
# Scalar operation: apply ufunc to each element's unit
# Get scalar unit(s)
scalar_unit = u.dimensionless_unscaled
for sc in scalar_inputs:
if isinstance(sc, u.UnitBase):
scalar_unit = sc
elif isinstance(sc, u.Quantity):
scalar_unit = sc.unit
# else: dimensionless
for idx in np.ndindex(meta_shape):
old_meta = cast(MetaData, main.meta[idx])
old_unit = (
old_meta.unit if old_meta.unit else u.dimensionless_unscaled
)
try:
# Apply ufunc to units
q_result = (
ufunc(u.Quantity(1, old_unit), u.Quantity(1, scalar_unit))
if len(inputs) == 2
else ufunc(u.Quantity(1, old_unit))
)
new_unit = (
q_result.unit if hasattr(q_result, "unit") else old_unit
)
except (TypeError, ValueError, u.UnitConversionError):
new_unit = old_unit
new_meta_arr[idx] = MetaData(
name=old_meta.name,
channel=old_meta.channel,
unit=new_unit,
)
elif is_binary_matrix_op:
# Binary matrix operation: check per-element unit compatibility
other_sgm = sgm_inputs[1] if len(sgm_inputs) > 1 else None
if other_sgm is not None and other_sgm.meta is not None:
# Check shape compatibility
if main.meta.shape != other_sgm.meta.shape:
raise ValueError(
f"Metadata shape mismatch: {main.meta.shape} vs {other_sgm.meta.shape}"
)
for idx in np.ndindex(meta_shape):
m1 = cast(MetaData, main.meta[idx])
m2 = cast(MetaData, other_sgm.meta[idx])
u1 = m1.unit if m1.unit else u.dimensionless_unscaled
u2 = m2.unit if m2.unit else u.dimensionless_unscaled
# Check strict unit equality for add/sub/comparison
# Following SeriesMatrix check_add_sub_compatibility:
# u0 != uk raises UnitConversionError (even for equivalent units like m vs cm)
if ufunc in _ADD_SUB_UFUNCS or ufunc in _COMPARISON_UFUNCS:
if u1 != u2:
raise u.UnitConversionError(
f"Unit mismatch at element {idx}: {u1} vs {u2}"
)
new_unit = u1 # Preserve first unit for add/sub
if ufunc in _COMPARISON_UFUNCS:
new_unit = u.dimensionless_unscaled
elif ufunc in _MUL_DIV_UFUNCS:
if ufunc == np.multiply:
new_unit = u1 * u2
else:
new_unit = u1 / u2
else:
# Default: try to compute
try:
q_result = ufunc(u.Quantity(1, u1), u.Quantity(1, u2))
new_unit = (
q_result.unit if hasattr(q_result, "unit") else u1
)
except (TypeError, ValueError, u.UnitConversionError) as e:
if isinstance(e, u.UnitConversionError):
raise
new_unit = u1
new_meta_arr[idx] = MetaData(
name=m1.name,
channel=m1.channel,
unit=new_unit,
)
else:
# Other matrix has no meta; keep main's meta
new_meta_arr = main.meta.copy()
new_meta = MetaDataMatrix(new_meta_arr)
def _infer_unit(meta):
if meta is None:
return None
meta_units = {m.unit for m in meta.reshape(-1) if m is not None}
if len(meta_units) == 1:
return next(iter(meta_units))
return None
# Reconstruct SpectrogramMatrix
if result_data.shape == main.shape:
obj = self.__class__(
result_data,
times=main.times,
frequencies=main.frequencies,
rows=main.rows,
cols=main.cols,
meta=new_meta,
name=main.name,
unit=_infer_unit(new_meta),
)
return obj
return result_data
def __mul__(self, other):
"""Multiply by scalar, unit, or matrix."""
# Explicitly handle u.UnitBase to avoid astropy ufunc precedence issues
if isinstance(other, u.UnitBase):
return np.multiply(self, u.Quantity(1, other))
return np.multiply(self, other)
def __rmul__(self, other):
"""Right multiply by scalar, unit, or matrix."""
if isinstance(other, u.UnitBase):
return np.multiply(u.Quantity(1, other), self)
return np.multiply(other, self)
def __truediv__(self, other):
"""Divide by scalar, unit, or matrix."""
if isinstance(other, u.UnitBase):
return np.divide(self, u.Quantity(1, other))
return np.divide(self, other)
def __rtruediv__(self, other):
"""Right divide by scalar, unit, or matrix."""
if isinstance(other, u.UnitBase):
return np.divide(u.Quantity(1, other), self)
return np.divide(other, self)
def __add__(self, other):
"""Add scalar/quantity or matrix."""
return np.add(self, other)
def __radd__(self, other):
"""Right add."""
return np.add(other, self)
def __sub__(self, other):
"""Subtract scalar/quantity or matrix."""
return np.subtract(self, other)
def __rsub__(self, other):
"""Right subtract."""
return np.subtract(other, self)
[docs]
def row_keys(self):
return tuple(self.rows.keys()) if self.rows else tuple()
[docs]
def col_keys(self):
return tuple(self.cols.keys()) if self.cols else tuple()
[docs]
def is_compatible(self, other: Any) -> bool:
"""
Check compatibility with another SpectrogramMatrix/object.
Overrides SeriesMatrix.is_compatible to avoid loop range issues due to
mismatch between data shape (Time axis) and metadata shape (Batch/Col).
"""
# 1. Type check
if not isinstance(other, type(self)):
# Fallback or strict check? SeriesMatrix falls back to array shape check.
if hasattr(other, "shape") and np.shape(self) != np.shape(other):
raise ValueError(
f"shape does not match: {self.shape} vs {np.shape(other)}"
)
return True # assume compatible if shapes match and not SpectrogramMatrix
# 2. Shape check
if self.shape != other.shape:
raise ValueError(
f"matrix shape does not match: {self.shape} vs {other.shape}"
)
# 3. Times/Xindex Check
# Check units
t_unit_self = getattr(self.times, "unit", None)
t_unit_other = getattr(other.times, "unit", None)
if (
t_unit_self != t_unit_other
): # Simple equality check sufficient for same implementation
# Try convert? SeriesMatrix logic is strict about unit object equality or equivalence
if t_unit_self is not None and t_unit_other is not None:
if not u.Unit(t_unit_self).is_equivalent(u.Unit(t_unit_other)):
raise ValueError(
f"times unit does not match: {t_unit_self} vs {t_unit_other}"
)
# Check dx/content (for contiguous check, usually handled by caller, but is_compatible checks xindex content equality?)
# SeriesMatrix.is_compatible checks xindex equality if dx matches or fallback.
# But we only need unit compatibility usually for ops?
# is_contiguous calls is_compatible.
# Let's keep it simple: check units match. Content matching is handled by append logic (overlap check etc).
# 4. Meta/Channel Unit consistency
if self.meta is None and other.meta is None:
return True
if self.meta is None or other.meta is None:
raise ValueError("Metadata mismatch: one has metadata, the other does not")
if self.meta.shape != other.meta.shape:
# Should match if shapes match (unless metadata structure differs profoundly)
# But let's proceed to loop over valid meta range
raise ValueError(
f"metadata shape mismatch: {self.meta.shape} vs {other.meta.shape}"
)
for i in range(self.meta.shape[0]):
for j in range(self.meta.shape[1]):
u1 = self.meta[i, j].unit
u2 = other.meta[i, j].unit
if u1 != u2:
# Allow None vs None
if u1 is None and u2 is None:
continue
if u1 is None or u2 is None:
raise ValueError(
f"Unit mismatch at meta ({i},{j}): {u1} vs {u2}"
)
if not u1.is_equivalent(u2):
raise ValueError(
f"Unit mismatch at meta ({i},{j}): {u1} vs {u2}"
)
return True
[docs]
def row_index(self, key):
if not self.rows:
raise KeyError(f"Invalid row key: {key}")
try:
return list(self.row_keys()).index(key)
except ValueError:
raise KeyError(f"Invalid row key: {key}")
[docs]
def col_index(self, key):
if not self.cols:
raise KeyError(f"Invalid column key: {key}")
try:
return list(self.col_keys()).index(key)
except ValueError:
raise KeyError(f"Invalid column key: {key}")
def __getitem__(self, key):
from gwexpy.types.seriesmatrix_validation import _slice_metadata_dict
# Handle label-based indexing
if isinstance(key, str):
try:
key = self.row_index(key)
except KeyError:
raise
if (
isinstance(key, (list, np.ndarray))
and len(key) > 0
and isinstance(key[0], str)
):
key = [self.row_index(k) for k in key]
# Handle tuple keys (Row, Col) or (Row, Col, Time, Freq)
if isinstance(key, tuple):
new_key = list(key)
# Row index (0)
if len(new_key) > 0:
if isinstance(new_key[0], str):
new_key[0] = self.row_index(new_key[0])
elif (
isinstance(new_key[0], (list, np.ndarray))
and len(new_key[0]) > 0
and isinstance(new_key[0][0], str)
):
new_key[0] = [self.row_index(k) for k in new_key[0]]
# Col index (1) - only if we have at least 2 dims relevant to metadata (4D case or 3D with abuse?)
# SpectrogramMatrix 4D: (Row, Col, Time, Freq). 3D: (Batch, Time, Freq).
# For 3D, col index is not applicable in the same way, but let's assume standard behavior.
if len(new_key) > 1:
# Check if second element is string
if isinstance(new_key[1], str):
try:
new_key[1] = self.col_index(new_key[1])
except (KeyError, IndexError):
# If columns are not defined or key not found, it might be a time-slice?
# But for 4D matrix, dim 1 IS Col.
if self.ndim == 4:
raise
pass
elif (
isinstance(new_key[1], (list, np.ndarray))
and len(new_key[1]) > 0
and isinstance(new_key[1][0], str)
):
if self.ndim == 4:
new_key[1] = [self.col_index(k) for k in new_key[1]]
key = tuple(new_key)
# Access raw data
raw_data = self.view(np.ndarray)[key]
# Check for scalar element extraction (returning Spectrogram)
is_single_element = False
r_idx, c_idx = 0, 0
if self.ndim == 3: # (Batch, Time, Freq)
if isinstance(key, (int, np.integer)):
is_single_element = True
r_idx = int(key)
c_idx = 0
elif self.ndim == 4: # (Row, Col, Time, Freq)
if isinstance(key, tuple) and len(key) >= 2:
r, c = key[0], key[1]
if isinstance(r, (int, np.integer)) and isinstance(
c, (int, np.integer)
):
is_single_element = True
r_idx, c_idx = int(r), int(c)
if is_single_element:
# Return Spectrogram
meta = self.meta[r_idx, c_idx] if self.meta is not None else None
unit = meta.unit if meta else self.unit
name = meta.name if meta and meta.name else self.name
channel = meta.channel if meta else None
# raw_data should be (Time, Freq)
if raw_data.ndim != 2:
# Should not happen if indices are correct for 3D/4D
raise ValueError(
f"Extracted data has wrong dimension for Spectrogram: {raw_data.ndim} (expected 2)"
)
return self.series_class(
raw_data,
times=self.times,
frequencies=self.frequencies,
unit=unit,
name=name,
channel=channel,
epoch=getattr(self, "epoch", None),
)
# Return Sub-Matrix
# We assume raw_data is ndarray. View as SpectrogramMatrix.
ret = np.asarray(raw_data).view(type(self))
ret._value = ret.view(np.ndarray)
# Propagate global props
ret.times = getattr(self, "times", None)
ret.frequencies = getattr(self, "frequencies", None)
ret.unit = getattr(self, "unit", None)
ret.epoch = getattr(self, "epoch", None)
# Propagate/Slice Metadata (Rows, Cols, Meta)
# This is complex for general slicing. Simplification:
# If ndim preserved, try to slice rows/cols.
# If ndim reduced (e.g. 4D -> 3D), adjust.
# Basic case: Batch slicing on 3D or Row slicing on 4D
main_key = key[0] if isinstance(key, tuple) else key
# 3D: (Batch, T, F) -> Slice batch
if self.ndim == 3 and ret.ndim == 3:
if self.rows:
ret.rows = _slice_metadata_dict(self.rows, main_key, "row")
if self.meta is not None:
# meta is (N, 1)
try:
sliced = self.meta[main_key]
if isinstance(sliced, np.ndarray):
ret.meta = sliced.view(MetaDataMatrix) # type: ignore[assignment]
else:
ret.meta = sliced
except (IndexError, TypeError):
pass
# 4D: (Row, Col, T, F) -> Slice row, maybe col
elif self.ndim == 4 and ret.ndim == 4:
r_key = key[0] if isinstance(key, tuple) else key
c_key = key[1] if isinstance(key, tuple) and len(key) > 1 else slice(None)
if self.rows:
ret.rows = _slice_metadata_dict(self.rows, r_key, "row")
if self.cols:
ret.cols = _slice_metadata_dict(self.cols, c_key, "col")
if self.meta is not None:
try:
# meta is (Row, Col)
# If key is simple tuple (slice, slice)
if isinstance(key, tuple) and len(key) <= 2:
sliced = self.meta[key]
if isinstance(sliced, np.ndarray):
ret.meta = sliced.view(MetaDataMatrix) # type: ignore[assignment]
else:
ret.meta = sliced
else:
# complex slicing?
pass
except (IndexError, TypeError, KeyError):
pass
# 4D -> 3D reduction (e.g. slice out one Row or one Col)
elif self.ndim == 4 and ret.ndim == 3:
# Case A: row is scalar, col is slice -> result Batch is Col
# Case B: row is slice, col is scalar -> result Batch is Row
r_idx = key[0] if isinstance(key, tuple) else key
col_idx: Any = (
key[1] if isinstance(key, tuple) and len(key) > 1 else slice(None)
)
is_row_scalar = isinstance(r_idx, (int, np.integer))
is_col_scalar = isinstance(col_idx, (int, np.integer))
if is_row_scalar:
# Batch = Col
if self.cols:
ret.rows = _slice_metadata_dict(self.cols, col_idx, "row")
elif is_col_scalar:
# Batch = Row
if self.rows:
ret.rows = _slice_metadata_dict(self.rows, r_idx, "row")
if self.meta is not None:
try:
sliced = self.meta[r_idx, c_idx]
if isinstance(sliced, np.ndarray):
# Ensure meta for 3D is (N, 1)
ret.meta = sliced.reshape(-1, 1).view(MetaDataMatrix)
else:
# Scalar meta?
pass
except (IndexError, TypeError, KeyError):
pass
# Propagate custom attributes
for k, v in getattr(self, "__dict__", {}).items():
if k.startswith("_gwex_"):
ret.__dict__[k] = v
return ret
[docs]
def to_series_2Dlist(self):
"""Convert matrix to a 2D nested list of Spectrogram objects."""
r_keys = self.row_keys()
c_keys = self.col_keys()
if self.ndim == 3:
return [[self[i] for _ in range(1)] for i in range(len(r_keys))]
return [[self[i, j] for j in range(len(c_keys))] for i in range(len(r_keys))]
[docs]
def to_series_1Dlist(self):
"""Convert matrix to a flat 1D list of Spectrogram objects."""
r_keys = self.row_keys()
c_keys = self.col_keys()
results = []
if self.ndim == 3:
for i in range(len(r_keys)):
results.append(self[i])
elif self.ndim == 4:
for i in range(len(r_keys)):
for j in range(len(c_keys)):
results.append(self[i, j])
else:
raise ValueError(f"Unsupported SpectrogramMatrix dimension: {self.ndim}")
return results
[docs]
def to_list(self):
"""Convert to SpectrogramList."""
from .collections import SpectrogramList
return SpectrogramList(self.to_series_1Dlist())
[docs]
def to_dict(self):
"""Convert to SpectrogramDict."""
from .collections import SpectrogramDict
r_keys = self.row_keys()
c_keys = self.col_keys()
results = SpectrogramDict()
if self.ndim == 3:
for i, rk in enumerate(r_keys):
results[rk] = self[i]
elif self.ndim == 4:
for i, rk in enumerate(r_keys):
for j, ck in enumerate(c_keys):
if len(c_keys) == 1:
results[rk] = self[i, j]
else:
results[(rk, ck)] = self[i, j]
return results
def _all_element_units_equivalent(self):
"""Check whether all element units are mutually equivalent."""
if self.meta is None:
return True, self.unit
ref_unit = self.meta[0, 0].unit
for m in self.meta.reshape(-1):
if m.unit is None:
continue
if not m.unit.is_equivalent(ref_unit):
return False, ref_unit
return True, ref_unit
@property
def shape3D(self):
# Override Base logic to return relevant 3D view (Batch, Time, Freq) for display?
# Or (Row, Col, Time) if we treat Freq as hidden dim?
# For uniformity with SeriesMatrix which is (Row, Col, Sample),
# if we are 4D (Row, Col, Time, Freq), we might want to return (Row, Col, Time) as 'main' shape with _x_axis_index logic?
# But core checks shape[-1].
return self.shape
[docs]
def plot_summary(self, **kwargs):
"""
Plot Matrix as side-by-side Spectrograms and percentile summaries.
"""
from gwexpy.plot.plot import plot_summary
return plot_summary(self, **kwargs)
def __reduce__(self):
"""
Customize pickle serialization to ensure metadata preservation.
Returns standard numpy reduce tuple but appends __dict__ only if not automatically handled.
"""
picked = list(super().__reduce__())
# picked is [func, args, state]
# state is (version, shape, dtype, isFortran, rawdata)
state = picked[2]
# Append our __dict__ to state tuple to ensure it's saved
full_state = state + (self.__dict__,)
picked[2] = full_state
return tuple(picked)
def __setstate__(self, state):
"""
Restore state from pickle.
"""
# The last element contains our __dict__
my_dict = state[-1]
# The rest is for numpy
np_state = state[:-1]
super().__setstate__(np_state)
self.__dict__.update(my_dict)