from __future__ import annotations
from typing import TYPE_CHECKING, Any
from gwpy.types import Series
if TYPE_CHECKING:
from .core import FitResult, Fitter, fit_series
from .gls import GLS, GeneralizedLeastSquares
from .highlevel import fit_bootstrap_spectrum
__all__ = [
"fit_series",
"FitResult",
"Fitter",
"GeneralizedLeastSquares",
"GLS",
"fit_bootstrap_spectrum",
"enable_series_fit",
"enable_fitting_monkeypatch",
]
def _lazy_series_fit(self: Series, *args: Any, **kwargs: Any) -> Any:
from gwexpy.fitting import fit_series
return fit_series(self, *args, **kwargs)
[docs]
def enable_series_fit() -> None:
"""
Opt-in monkeypatch for gwpy.types.Series.fit.
Note: standard gwexpy classes (TimeSeries, FrequencySeries) already have the .fit() method
via inheritance. This function is generally not needed unless you are using base gwpy objects
directly.
"""
if not hasattr(Series, "fit"):
Series.fit = _lazy_series_fit
# Backward compatibility alias (used in README.md)
enable_fitting_monkeypatch = enable_series_fit
def __getattr__(name: str) -> Any:
if name in ("fit_series", "FitResult", "Fitter"):
try:
from .core import FitResult, Fitter, fit_series
except ImportError as exc: # pragma: no cover
raise ImportError(
"gwexpy.fitting requires optional dependencies (e.g. iminuit) and a working numba setup."
) from exc
if name == "fit_series":
return fit_series
if name == "FitResult":
return FitResult
return Fitter
if name in ("GeneralizedLeastSquares", "GLS"):
try:
from .gls import GLS, GeneralizedLeastSquares
except ImportError as exc: # pragma: no cover
raise ImportError(
"gwexpy.fitting requires optional dependencies (e.g. iminuit)."
) from exc
return GeneralizedLeastSquares if name == "GeneralizedLeastSquares" else GLS
if name == "fit_bootstrap_spectrum":
try:
from .highlevel import fit_bootstrap_spectrum
except ImportError as exc: # pragma: no cover
raise ImportError(
"gwexpy.fitting.highlevel requires optional dependencies."
) from exc
return fit_bootstrap_spectrum
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__() -> list[str]:
return sorted(__all__)