Source code for gwexpy.types.axis_api

from __future__ import annotations

import operator
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, cast

from .axis import AxisDescriptor

if TYPE_CHECKING:
    from .mixin._protocols import AxisApiHost

__all__ = ["AxisApiMixin"]


class AxisApiMixin(ABC):
    """Mixin providing a name-aware axis API (lookup, selection, renaming, transposition).

    Concrete subclasses must implement :attr:`axes`, :meth:`_isel_tuple` and
    :meth:`_swapaxes_int`. They may override :meth:`_set_axis_name`, which is a
    no-op here, so that :meth:`rename_axes` and :meth:`transpose` can update
    axis metadata.
    """

    @property
    @abstractmethod
    def axes(self) -> tuple[AxisDescriptor, ...]:
        """Tuple of AxisDescriptor objects for each dimension.

        Returns
        -------
        tuple of AxisDescriptor
            Each descriptor contains the axis name and index values.
        """
        pass

    @property
    def axis_names(self) -> tuple[str, ...]:
        """Names of all axes as a tuple of strings.

        Returns
        -------
        tuple of str
            The name of each axis in order.
        """
        return tuple(ax.name for ax in self.axes)

    def axis(self, key: int | str) -> AxisDescriptor:
        """Get an axis descriptor by index or name.

        Parameters
        ----------
        key : int or str
            Axis index (0-based, negative values count from the end; any
            object supporting ``__index__`` is accepted) or axis name.

        Returns
        -------
        AxisDescriptor
            The requested axis descriptor.

        Raises
        ------
        KeyError
            If axis name not found.
        IndexError
            If the integer index is out of range.
        TypeError
            If key is neither a str nor an integer.
        """
        if isinstance(key, str):
            for ax in self.axes:
                if ax.name == key:
                    return ax
            raise KeyError(f"Axis '{key}' not found, available: {self.axis_names}")
        try:
            index = operator.index(key)
        except TypeError:
            raise TypeError(f"Axis key must be int or str, got {type(key)}") from None
        return self.axes[index]

    def _get_axis_index(self, key: int | str) -> int:
        """Normalize an axis name or integer index to a non-negative axis index.

        Parameters
        ----------
        key : int or str
            Axis name, or integer index (negative values count from the end;
            any object supporting ``__index__`` is accepted).

        Returns
        -------
        int
            Index in ``range(len(self.axes))``.

        Raises
        ------
        KeyError
            If the axis name is not found.
        IndexError
            If the integer index is out of range.
        TypeError
            If key is neither a str nor an integer.
        """
        if isinstance(key, str):
            names = self.axis_names
            if key in names:
                return names.index(key)
            raise KeyError(f"Axis '{key}' not found in {names}")
        try:
            index = operator.index(key)
        except TypeError:
            raise TypeError(f"Axis key must be int or str, got {type(key)}") from None
        ndim = len(self.axes)
        normalized = index + ndim if index < 0 else index
        if not 0 <= normalized < ndim:
            # Report the index the caller passed, not the wrapped value.
            raise IndexError(f"Axis index {index} out of range for {ndim} axes")
        return normalized

    def rename_axes(
        self: AxisApiHost, mapping: dict[str, str], *, inplace: bool = False
    ) -> Any:
        """Rename axes using a mapping of old names to new names.

        Parameters
        ----------
        mapping : dict
            Mapping from old axis names to new names.
        inplace : bool, optional
            If True, modify in place. Otherwise return a copy.

        Returns
        -------
        self or copy
            ``self`` when ``inplace`` is True, otherwise the renamed copy.

        Raises
        ------
        ValueError
            If a key of ``mapping`` is not an existing axis name, or the
            renaming would produce duplicate axis names.
        """
        if not inplace:
            new_obj = self.copy()
            new_obj.rename_axes(mapping, inplace=True)
            return new_obj

        old_names = self.axis_names
        new_names_list = list(old_names)
        for old, new in mapping.items():
            if old not in old_names:
                raise ValueError(f"Axis '{old}' not found in {old_names}")
            new_names_list[old_names.index(old)] = new

        # Validate the complete result before touching any metadata.
        if len(set(new_names_list)) != len(new_names_list):
            raise ValueError(f"Duplicate axis names resulted: {new_names_list}")

        for i, (old_name, new_name) in enumerate(zip(old_names, new_names_list)):
            if new_name != old_name:
                self._set_axis_name(i, new_name)
        return self

    def _set_axis_name(self, index: int, name: str):
        """Hook for subclasses to store a new name for axis ``index``.

        The default implementation does nothing.
        """
        pass

    def isel(self, indexers=None, **kwargs):
        """Select by integer indices along specified axes.

        Parameters
        ----------
        indexers : dict, optional
            Mapping of axis name/index to integer index or slice.
        **kwargs
            Additional indexers as keyword arguments (these take precedence
            over entries of ``indexers`` with the same key).

        Returns
        -------
        subset
            Sliced array, as returned by ``_isel_tuple``.
        """
        if indexers is None:
            indexers = {}
        indexers = {**indexers, **kwargs}

        # Axes that are not mentioned are kept whole.
        slices = [slice(None)] * len(self.axes)
        for key, sel in indexers.items():
            slices[self._get_axis_index(key)] = sel
        return self._isel_tuple(tuple(slices))

    def sel(self, indexers=None, *, method="nearest", **kwargs):
        """Select by coordinate values along specified axes.

        Parameters
        ----------
        indexers : dict, optional
            Mapping of axis name to coordinate value or slice.
        method : str, optional
            Selection method. Only 'nearest' (the default) is supported.
        **kwargs
            Additional indexers as keyword arguments.

        Returns
        -------
        subset
            Sliced array at nearest coordinate values.

        Raises
        ------
        ValueError
            If ``method`` is not 'nearest'.
        """
        if method != "nearest":
            raise ValueError(
                f"Unsupported selection method {method!r}; only 'nearest' is supported"
            )
        if indexers is None:
            indexers = {}
        indexers = {**indexers, **kwargs}

        isel_indexers = {}
        for key, val in indexers.items():
            ax_idx = self._get_axis_index(key)
            ax = self.axes[ax_idx]
            # Translate coordinate values into integer positions on that axis.
            if isinstance(val, slice):
                isel_indexers[ax_idx] = ax.iloc_slice(val)
            else:
                isel_indexers[ax_idx] = ax.iloc_nearest(val)
        return self.isel(isel_indexers)

    def swapaxes(self: AxisApiHost, axis1: int | str, axis2: int | str) -> Any:
        """Interchange two axes, given by index or name.

        Parameters
        ----------
        axis1, axis2 : int or str
            Axes to swap.

        Returns
        -------
        swapped
            ``self.copy()`` if both keys refer to the same axis, otherwise
            the result of ``_swapaxes_int``.
        """
        idx1 = self._get_axis_index(axis1)
        idx2 = self._get_axis_index(axis2)
        if idx1 == idx2:
            return self.copy()
        return self._swapaxes_int(idx1, idx2)

    def transpose(self, *axes):
        """Permute the dimensions of an array.

        Parameters
        ----------
        *axes : int or str, a single tuple/list of them, or None
            New axis order given as axis indices and/or names. With no
            arguments, or a single ``None`` (which is what ``numpy.transpose``
            passes by default), the axis order is reversed.

        Returns
        -------
        transposed
            Result of ``_transpose_int`` for the normalized permutation.

        Raises
        ------
        ValueError
            If the number of axes does not match the dimensionality, or an
            axis is repeated.
        """
        ndim = len(self.axes)
        if not axes or (len(axes) == 1 and axes[0] is None):
            axes = tuple(range(ndim))[::-1]
        elif len(axes) == 1 and isinstance(axes[0], (tuple, list)):
            axes = tuple(axes[0])
        if len(axes) != ndim:
            raise ValueError("axes don't match array")

        perm = tuple(self._get_axis_index(ax) for ax in axes)
        if len(set(perm)) != ndim:
            raise ValueError(f"repeated axis in transpose: {axes}")
        return self._transpose_int(perm)

    @property
    def T(self):
        """Array with its axes reversed (same as ``self.transpose()``)."""
        return self.transpose()

    @abstractmethod
    def _isel_tuple(self, item_tuple):
        """Index the array with a tuple holding one indexer per axis."""
        pass

    @abstractmethod
    def _swapaxes_int(self, a: int, b: int):
        """Swap axes ``a`` and ``b``, given as non-negative integer indices."""
        pass

    def _transpose_int(self, axes: tuple[int, ...]):
        """Transpose via the next class in the MRO, then re-apply axis names.

        Parameters
        ----------
        axes : tuple of int
            Permutation of non-negative axis indices. Output axis ``i`` is
            taken from input axis ``axes[i]``.

        Returns
        -------
        transposed
            Result of ``super().transpose(*axes)``. If that object has a
            ``_set_axis_name`` method, its axis names are updated to follow
            the permutation.
        """
        # As a mixin, super() resolves to the next class in the subclass MRO
        # (presumably the array base class), which moves the actual data.
        base = cast(Any, super())
        new_obj = base.transpose(*axes)

        if hasattr(new_obj, "_set_axis_name"):
            old_names = self.axis_names
            for new_idx, origin_idx in enumerate(axes):
                new_obj._set_axis_name(new_idx, old_names[origin_idx])
        return new_obj