from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, cast
import numpy as np
from gwexpy.interop._registry import ConverterRegistry
from ._optional import require_optional
if TYPE_CHECKING: # pragma: no cover - type hints only
from gwexpy.timeseries import TimeSeries, TimeSeriesMatrix
[docs]
class TimeSeriesWindowDataset:
    """
    Simple windowed Dataset wrapper for torch training loops.

    The input is stored in ``self.data`` as a 2-D array whose last axis is
    the sample axis: a single ``TimeSeries`` becomes one row, and a
    ``TimeSeriesMatrix`` is flattened to ``(-1, n_samples)``. Item ``i`` is
    the slice ``data[:, s:s + window]`` with ``s = starts[i]``, returned as
    a torch tensor.

    Parameters
    ----------
    series
        A ``TimeSeries`` or ``TimeSeriesMatrix`` (as resolved through
        ``ConverterRegistry``). A ``TimeSeriesDict`` / ``TimeSeriesList`` is
        accepted only when ``multivariate`` is true, in which case it is
        converted with ``series.to_matrix(align=align)``.
    window
        Number of samples per window; must be positive.
    stride
        Step between consecutive window starts; must be positive.
    horizon
        Offset of the label relative to the window: the label for a window
        starting at ``s`` is read at index ``s + window + horizon - 1``
        (``horizon=0`` is the last sample of the window).
    labels
        ``None``: items are the window tensor only.
        Callable: invoked as ``labels(x_tensor, start)`` and its result is
        returned as the label.
        ``TimeSeries`` / ``TimeSeriesMatrix`` / ``numpy.ndarray``: converted
        to a 2-D array and indexed along its last axis.
        Any other non-``None`` value is stored but produces ``None`` labels.
    multivariate
        Enables the ``to_matrix`` conversion described under ``series``.
    align
        Forwarded to ``to_matrix``.
    device, dtype
        Forwarded to ``torch.as_tensor`` for every tensor produced.

    Raises
    ------
    ValueError
        If ``window`` or ``stride`` is not positive, or if the configuration
        leaves no valid window.
    TypeError
        If ``series`` is not one of the supported types.
    """

    def __init__(
        self,
        series,
        *,
        window: int,
        stride: int = 1,
        horizon: int = 0,
        labels: TimeSeries | TimeSeriesMatrix | np.ndarray | Callable | None = None,
        multivariate: bool = False,
        align: str = "intersection",
        device=None,
        dtype=None,
    ):
        # torch is an optional dependency; keep the module on the instance so
        # the slicing helpers do not have to resolve it again.
        torch = require_optional("torch")
        self.torch = torch
        self.device = device
        self.dtype = dtype
        self.window = int(window)
        self.stride = int(stride)
        self.horizon = int(horizon)
        if self.window <= 0 or self.stride <= 0:
            raise ValueError("window and stride must be positive integers.")

        # Concrete classes are looked up through the registry at call time.
        # These locals intentionally shadow the TYPE_CHECKING-only names.
        TimeSeries = cast(Any, ConverterRegistry.get_constructor("TimeSeries"))
        TimeSeriesDict = cast(Any, ConverterRegistry.get_constructor("TimeSeriesDict"))
        TimeSeriesList = cast(Any, ConverterRegistry.get_constructor("TimeSeriesList"))
        TimeSeriesMatrix = cast(
            Any, ConverterRegistry.get_constructor("TimeSeriesMatrix")
        )
        from .base import to_plain_array

        data_obj = series
        if multivariate and isinstance(series, (TimeSeriesDict, TimeSeriesList)):
            series_obj = cast(Any, series)
            data_obj = series_obj.to_matrix(align=align)

        # The matrix check comes first, matching the original dispatch order.
        if isinstance(data_obj, TimeSeriesMatrix):
            matrix_obj = cast(Any, data_obj)
            self.t0 = matrix_obj.t0
            self.dt = matrix_obj.dt
            vals = to_plain_array(matrix_obj)
            self._feature_names = getattr(matrix_obj, "channel_names", None)
            # Collapse every leading axis into rows; the last axis is kept.
            self.data = vals.reshape(-1, vals.shape[-1])
            self.unit = None
        elif isinstance(data_obj, TimeSeries):
            ts_obj = cast(Any, data_obj)
            self.t0 = ts_obj.t0
            self.dt = ts_obj.dt
            # Add a leading row axis so data is always 2-D.
            self.data = to_plain_array(ts_obj)[None, :]
            self.unit = getattr(ts_obj, "unit", None)
            self._feature_names = (
                [ts_obj.name] if getattr(ts_obj, "name", None) else None
            )
        else:
            raise TypeError(
                f"Unsupported type for TimeSeriesWindowDataset: {type(data_obj)}"
            )

        self.labels = labels
        if isinstance(labels, (TimeSeries, TimeSeriesMatrix)):
            label_vals = to_plain_array(cast(Any, labels))
            self.label_array = label_vals.reshape(-1, label_vals.shape[-1])
        elif isinstance(labels, np.ndarray):
            arr = labels
            self.label_array = (
                arr.reshape(-1, arr.shape[-1]) if arr.ndim > 1 else arr[None, :]
            )
        else:
            # None, a callable, or an unsupported value: no array to index.
            self.label_array = cast(Any, None)

        # Only a non-negative horizon consumes samples after the window.
        # Clamping at zero keeps a negative horizon from extending the start
        # range beyond n - window, which would yield truncated windows.
        n_samples = self.data.shape[-1]
        max_start = n_samples - self.window - max(self.horizon, 0) + 1
        if max_start <= 0:
            raise ValueError("window/horizon configuration yields no samples.")
        self.starts = list(range(0, max_start, self.stride))

    def __len__(self):
        """Return the number of windows in the dataset."""
        return len(self.starts)

    def _slice_x(self, start: int):
        """Return ``data[:, start:start + window]`` as a torch tensor."""
        end = start + self.window
        x_np = self.data[:, start:end]
        return self.torch.as_tensor(x_np, device=self.device, dtype=self.dtype)

    def _slice_label(self, start: int, x_tensor):
        """
        Return the label for the window beginning at ``start``.

        Returns ``None`` when no labels were given or when the labels are of
        a type that could not be converted to an array. A callable label
        source receives ``(x_tensor, start)``.

        Raises
        ------
        IndexError
            If the computed label index falls outside the label array.
        """
        if self.labels is None:
            return None
        if callable(self.labels):
            return self.labels(x_tensor, start)
        if self.label_array is None:
            return None
        idx = start + self.window + self.horizon - 1
        # A negative index would silently wrap to the end of the array.
        if idx < 0:
            raise IndexError("Label index precedes start of label array.")
        if idx >= self.label_array.shape[-1]:
            raise IndexError("Label index exceeds label array length.")
        y_np = self.label_array[:, idx]
        return self.torch.as_tensor(y_np, device=self.device, dtype=self.dtype)

    def __getitem__(self, idx: int):
        """
        Return the ``idx``-th sample.

        The result is ``(x, y)`` when labels were supplied, otherwise ``x``.
        """
        start = self.starts[idx]
        x_tensor = self._slice_x(start)
        y = self._slice_label(start, x_tensor)
        return (x_tensor, y) if self.labels is not None else x_tensor
[docs]
def to_torch_dataset(
    obj,
    *,
    window: int,
    stride: int = 1,
    horizon: int = 0,
    labels: TimeSeries | TimeSeriesMatrix | np.ndarray | Callable | None = None,
    multivariate: bool = False,
    align: str = "intersection",
    device=None,
    dtype=None,
):
    """
    Build a :class:`TimeSeriesWindowDataset` from ``obj``.

    Every keyword argument is forwarded unchanged to the dataset
    constructor, with ``obj`` passed as its positional ``series`` argument.
    """
    # Gather the windowing options once, then hand them over in one call.
    options = {
        "window": window,
        "stride": stride,
        "horizon": horizon,
        "labels": labels,
        "multivariate": multivariate,
        "align": align,
        "device": device,
        "dtype": dtype,
    }
    return TimeSeriesWindowDataset(obj, **options)
[docs]
def to_torch_dataloader(
    dataset,
    *,
    batch_size: int = 1,
    shuffle: bool = False,
    num_workers: int = 0,
    **kwargs,
):
    """
    Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    ``batch_size``, ``shuffle`` and ``num_workers`` are passed by keyword;
    any additional keyword arguments are forwarded to the DataLoader as-is.
    """
    # torch is optional, so it is resolved at call time.
    loader_cls = require_optional("torch").utils.data.DataLoader
    return loader_cls(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        **kwargs,
    )