"""Reusable validator functions and composable pipelines for value normalization.
Provides functions for composable inbound/outbound value normalization
(e.g., converting between pandas DataFrames and dictionaries).
"""
from collections.abc import Callable, Iterable
import math
import os
import pathlib
from typing import Any
from dotenv import load_dotenv
import numpy as np
import pandas as pd
import polars as pl
import pybamm
from pybamm.expression_tree.operations.serialise import convert_symbol_to_json
from scipy.integrate import cumulative_trapezoid
from scipy.stats import t as t_dist
from .errors import IonworksError
# --- DataFrame Backend Configuration ---------------------------------------- #
# Load .env file before reading environment variables so that an
# IONWORKS_DATAFRAME_BACKEND value set in a local .env is picked up below.
load_dotenv()
# Type alias for tabular data in either supported backend (pandas or polars);
# used throughout the validators in this module.
DataFrame = pd.DataFrame | pl.DataFrame
def _get_default_backend() -> str:
"""Get default backend from environment variable or fall back to 'polars'."""
env_val = os.getenv("IONWORKS_DATAFRAME_BACKEND", "polars").lower()
if env_val not in ("polars", "pandas"):
return "polars"
return env_val
# Module-level configuration for DataFrame return type.
# Initialized from IONWORKS_DATAFRAME_BACKEND env var, defaults to "polars".
# Mutable at runtime via set_dataframe_backend(); read via get_dataframe_backend().
_dataframe_backend: str = _get_default_backend()
[docs]
def set_dataframe_backend(backend: str) -> None:
    """Override the default DataFrame backend used for data fetching.

    Takes precedence over the IONWORKS_DATAFRAME_BACKEND environment
    variable for the rest of the process lifetime.

    Parameters
    ----------
    backend : str
        DataFrame backend to use: "polars" or "pandas".

    Raises
    ------
    ValueError
        If backend is not "polars" or "pandas".
    """
    global _dataframe_backend
    # Validate before mutating module state so an invalid call leaves the
    # current backend untouched.
    if backend not in ("polars", "pandas"):
        raise ValueError(f"backend must be 'polars' or 'pandas', got '{backend}'")
    _dataframe_backend = backend
[docs]
def get_dataframe_backend() -> str:
    """Return the currently configured DataFrame backend.

    Returns
    -------
    str
        Either ``"polars"`` or ``"pandas"``.
    """
    return _dataframe_backend
# --- Measurement Data Validators -------------------------------------------- #
[docs]
class MeasurementValidationError(IonworksError):
    """Raised when measurement data fails one or more validation checks."""

    def __init__(self, message: str, errors: list[str] | None = None) -> None:
        super().__init__(message)
        # Keep the individual error messages for programmatic inspection;
        # falsy input (None or empty list) normalizes to a fresh empty list.
        self.errors = errors or []
def _get_column(df: DataFrame, col: str) -> np.ndarray:
    """Return column *col* of *df* as a numpy array.

    Works with both supported backends: polars uses ``get_column``,
    pandas uses item access.

    Parameters
    ----------
    df : DataFrame
        pandas or polars DataFrame.
    col : str
        Column name.

    Returns
    -------
    np.ndarray
        Column values as numpy array.
    """
    if isinstance(df, pl.DataFrame):
        series = df.get_column(col)
    else:
        series = df[col]
    return series.to_numpy()
def _has_column(df: DataFrame, col: str) -> bool:
    """Return ``True`` when *col* is one of *df*'s column names."""
    return any(name == col for name in df.columns)
def _get_step_group_indices(step_data: np.ndarray) -> np.ndarray:
"""Compute step group indices for each row (0-indexed, based on contiguous groups).
Parameters
----------
step_data : np.ndarray
Array of step numbers/identifiers.
Returns
-------
np.ndarray
Array where each element is the step group index (0, 1, 2, ...) for
that row.
"""
changes = np.concatenate([[True], np.diff(step_data) != 0])
return np.cumsum(changes) - 1
[docs]
def positive_current_is_charge(
    t: np.ndarray,
    current: np.ndarray,
    voltage: np.ndarray,
) -> tuple[bool, float]:
    """Determine whether positive current corresponds to charging.

    Fits ``voltage = intercept + slope * capacity`` using weighted least
    squares (weights = dt). If the slope is non-negative (voltage rises
    or stays flat as cumulative charge increases) then positive current is
    charging; otherwise positive current is discharging.

    Parameters
    ----------
    t : np.ndarray
        Time values [s].
    current : np.ndarray
        Current values [A].
    voltage : np.ndarray
        Voltage values [V].

    Returns
    -------
    is_charge : bool
        ``True`` if positive current is charging (slope >= 0).
        ``False`` if positive current is discharging (slope < 0).
        Returns ``False`` (assume discharge) when there is insufficient data.
    p_value : float
        Two-sided p-value for the slope being nonzero. Lower means more
        confident. Returns 1.0 when there is insufficient data for a
        t-test.
    """
    t = np.asarray(t, dtype=float)
    # Guard: no samples, or zero elapsed time, means the convention cannot
    # be inferred. (The size check also prevents t.max() raising on empty.)
    if t.size == 0 or t.max() == t.min():
        return False, 1.0
    current = np.asarray(current, dtype=float)
    voltage = np.asarray(voltage, dtype=float)
    # Cumulative charge passed [A.s] is the regression x-axis.
    x = cumulative_trapezoid(y=current, x=t, initial=0)
    y = voltage
    dof = len(t) - 2
    if dof <= 0:
        # Two points or fewer: no residual degrees of freedom, fall back to
        # the raw secant slope with no confidence (p = 1).
        if len(t) < 2 or x[1] == x[0]:
            return False, 1.0
        slope = (y[1] - y[0]) / (x[1] - x[0])
        return bool(slope >= 0), 1.0
    # Weighted least squares with dt weights so unevenly sampled data is
    # not dominated by densely sampled regions.
    w = np.diff(t, prepend=t[0])
    W = np.sum(w)
    x_bar = np.dot(w, x) / W
    y_bar = np.dot(w, y) / W
    dx = x - x_bar
    Sxx = np.dot(w, dx**2)
    if Sxx <= 0:
        # No spread in cumulative charge -> slope is undefined.
        return False, 1.0
    slope = np.dot(w, dx * (y - y_bar)) / Sxx
    intercept = y_bar - slope * x_bar
    residuals = y - (intercept + slope * x)
    s2 = max(np.dot(w, residuals**2) / dof, 0.0)
    slope_se = np.sqrt(s2 / Sxx)
    if slope_se <= 0:
        # Perfect fit: a flat line carries no information (p = 1); any
        # nonzero slope is exact (p = 0).
        if slope == 0.0:
            return True, 1.0
        return bool(slope >= 0), 0.0
    t_stat = slope / slope_se
    p_value = float(2.0 * t_dist.sf(np.abs(t_stat), dof))
    return bool(slope >= 0), p_value
[docs]
def validate_positive_current_is_discharge(  # noqa: PLR0913
    df: DataFrame,
    current_col: str = "Current [A]",
    voltage_col: str = "Voltage [V]",
    time_col: str = "Time [s]",
    step_col: str | None = None,
    rest_tol: float = 1e-3,
) -> list[str]:
    """
    Validate that positive current corresponds to discharge.

    Discharge should cause voltage to decrease. This function analyzes the
    relationship between current direction and voltage change to verify the
    sign convention is correct.

    Uses weighted least squares (V vs cumulative Q) per step, then a
    confidence-weighted vote across steps to decide the overall convention.
    This matches the algorithm in ``ionworksdata.transform``.

    Parameters
    ----------
    df : DataFrame
        Time series data with current and voltage columns (pandas or polars).
    current_col : str
        Name of the current column.
    voltage_col : str
        Name of the voltage column.
    time_col : str
        Name of the time column.
    step_col : str, optional
        Name of the step column. If provided, analyzes per-step. Otherwise,
        infers steps from current sign changes.
    rest_tol : float
        Tolerance for considering current as zero (rest).

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    # Skip silently when required columns are missing; other validators
    # are responsible for reporting missing columns.
    if not _has_column(df, current_col) or not _has_column(df, voltage_col):
        return []
    if not _has_column(df, time_col):
        return []
    current = _get_column(df, current_col)
    voltage = _get_column(df, voltage_col)
    time = _get_column(df, time_col)
    if len(current) == 0:
        return []
    # Determine step groups
    if step_col and _has_column(df, step_col):
        step_data = _get_column(df, step_col)
    else:
        # Infer steps from current sign changes
        max_abs = np.max(np.abs(current))
        if max_abs == 0:
            # All-zero current: nothing to classify.
            return []
        normalized = current / max_abs
        # Sign of the normalized current, with values within rest_tol of zero
        # mapped to 0, so each contiguous run of one sign becomes one "step".
        step_data = np.sign(normalized * (np.abs(normalized) > rest_tol))
    step_groups = _get_step_group_indices(step_data)
    num_steps = step_groups[-1] + 1
    # Mean current per step (for rest filtering)
    step_current_sum = np.bincount(step_groups, weights=current, minlength=num_steps)
    step_counts = np.bincount(step_groups, minlength=num_steps).astype(float)
    # Guard against division by zero for (theoretically) empty groups.
    step_counts[step_counts == 0] = 1
    mean_current = step_current_sum / step_counts
    # Identify non-rest steps
    non_rest_steps = set(np.where(np.abs(mean_current) >= rest_tol)[0])
    if not non_rest_steps:
        return []
    # Classify each non-rest step using WLS V-vs-Q
    charge_weight = 0.0
    discharge_weight = 0.0
    charge_count = 0
    discharge_count = 0
    for step_id in non_rest_steps:
        mask = step_groups == step_id
        is_charge, p_value = positive_current_is_charge(
            time[mask], current[mask], voltage[mask]
        )
        # Weight each step's vote by the statistical confidence in its slope.
        confidence = 1.0 - p_value
        if is_charge:
            charge_weight += confidence
            charge_count += 1
        else:
            discharge_weight += confidence
            discharge_count += 1
    if charge_weight + discharge_weight > 0:
        should_flag = charge_weight > discharge_weight
    else:
        # All votes had zero confidence; fall back to a simple majority.
        should_flag = charge_count > discharge_count
    if should_flag:
        return [
            "Current sign convention error: positive current appears to "
            "be charge, not discharge. Voltage increases when current is "
            "positive, but for discharge, voltage should decrease. Use "
            "ionworksdata.transform.set_positive_current_for_discharge"
            "(data) to fix this."
        ]
    return []
[docs]
def validate_cumulative_values_reset_per_step(
    df: DataFrame,
    step_col: str = "Step count",
    cumulative_cols: list[str] | None = None,
    tolerance: float = 1e-6,
) -> list[str]:
    """Validate cumulative values reset to ~0 at each step and only increase.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).
    step_col : str
        Name of the column containing step numbers.
    cumulative_cols : list[str], optional
        List of cumulative column names to validate. If None, checks for common
        capacity and energy columns.
    tolerance : float
        Tolerance for considering a value as "zero" at step start.

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    errors = []
    # Without a step column there is no notion of "per step"; other
    # validators report the missing column.
    if not _has_column(df, step_col):
        return []
    if cumulative_cols is None:
        # Default set of cumulative columns commonly produced by cyclers.
        cumulative_cols = [
            "Discharge capacity [A.h]",
            "Charge capacity [A.h]",
            "Discharge energy [W.h]",
            "Charge energy [W.h]",
        ]
    cols_to_check = [col for col in cumulative_cols if _has_column(df, col)]
    if not cols_to_check:
        return []
    fix_hint = (
        "Use ionworksdata.transform.set_capacity(data) and/or "
        "ionworksdata.transform.set_energy(data) to fix this."
    )
    step_data = _get_column(df, step_col)
    if len(step_data) == 0:
        return []
    step_groups = _get_step_group_indices(step_data)
    # Find step boundaries (first index of each step)
    step_boundaries = np.where(np.diff(step_groups, prepend=-1) != 0)[0]
    for col in cols_to_check:
        values = _get_column(df, col)
        # Check 1: Values at step starts should be ~0
        start_values = values[step_boundaries]
        non_zero_mask = np.abs(start_values) > tolerance
        # Positions within start_values align with step group indices
        # 0..n-1, so step_idx below is also the step group number.
        non_zero_steps = np.where(non_zero_mask)[0]
        for step_idx in non_zero_steps:
            errors.append(
                f"Column '{col}' does not reset at start of "
                f"step {step_idx}: expected ~0, got "
                f"{start_values[step_idx]:.6f}. Cumulative values "
                f"should reset to 0 at the start of each step. "
                f"{fix_hint}"
            )
        # Check 2: Values should be monotonically non-decreasing within each step
        # Compute diff and check where it's negative within same step.
        # Prepending the first element makes both diffs zero at index 0.
        value_diffs = np.diff(values, prepend=values[0])
        step_diffs = np.diff(step_groups, prepend=step_groups[0])
        # Mask: same step (diff == 0) and value decreased
        within_step = step_diffs == 0
        decreased = value_diffs < -tolerance
        # Find first decrease per step
        problem_indices = np.where(within_step & decreased)[0]
        if len(problem_indices) > 0:
            # Group by step and report first decrease per step
            problem_steps = step_groups[problem_indices]
            unique_problem_steps = np.unique(problem_steps)
            for step_idx in unique_problem_steps:
                # Find first index in this step with decrease
                step_problem_indices = problem_indices[problem_steps == step_idx]
                first_idx = step_problem_indices[0]
                # first_idx > 0 is guaranteed: index 0 has a zero diff.
                errors.append(
                    f"Column '{col}' decreases within step "
                    f"{step_idx} at index {first_idx}: value went "
                    f"from {values[first_idx - 1]:.6f} to "
                    f"{values[first_idx]:.6f}. Cumulative values "
                    f"should only increase within a step. "
                    f"{fix_hint}"
                )
    return errors
[docs]
def validate_minimum_points_per_step(
    df: DataFrame,
    step_col: str = "Step count",
    min_points: int = 2,
) -> list[str]:
    """
    Check that every step contains at least ``min_points`` rows.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).
    step_col : str
        Name of the column containing step numbers.
    min_points : int
        Minimum number of points required per step.

    Returns
    -------
    list[str]
        One message per undersized step; empty if validation passes.
    """
    if not _has_column(df, step_col):
        return []
    steps = _get_column(df, step_col)
    if len(steps) == 0:
        return []
    groups = _get_step_group_indices(steps)
    # Row count per contiguous step group.
    counts = np.bincount(groups, minlength=groups[-1] + 1)
    return [
        f"Step {idx} has only {counts[idx]} data point(s), "
        f"but at least {min_points} are required."
        for idx in np.where(counts < min_points)[0]
    ]
[docs]
def validate_time_starts_at_zero(
    df: DataFrame,
    tolerance: float = 1e-6,
) -> list[str]:
    """Check that the 'Time [s]' column starts at (approximately) zero.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).
    tolerance : float
        Tolerance for considering the start value as zero.

    Returns
    -------
    list[str]
        A single-message list on failure; empty if validation passes.
    """
    time_col = "Time [s]"
    if not _has_column(df, time_col):
        return []
    time_data = _get_column(df, time_col)
    if len(time_data) == 0:
        return []
    first = time_data[0]
    if abs(first) <= tolerance:
        return []
    return [
        f"Column '{time_col}' must start at 0, but starts at "
        f"{first}. Use ionworksdata.transform.reset_time"
        f"(data) to fix this. To indicate the absolute time when a step starts, "
        "use the start_time field in the measurement metadata."
    ]
[docs]
def validate_time_monotonic(
    df: DataFrame,
    time_col: str = "Time [s]",
    tolerance: float = 1e-12,
) -> list[str]:
    """Validate that the time column is monotonically non-decreasing.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).
    time_col : str
        Name of the time column.
    tolerance : float
        Numerical tolerance; time[i] must be >= time[i-1] - tolerance.

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    if not _has_column(df, time_col):
        return []
    time_data = _get_column(df, time_col)
    if len(time_data) < 2:
        return []
    diffs = np.diff(time_data)
    bad_mask = diffs < -tolerance
    if not np.any(bad_mask):
        return []
    # diffs[i] = time_data[i + 1] - time_data[i], so a violation at diff
    # position i means row i + 1 went backwards relative to row i.
    # (Previously the message printed time_data[first_idx] and
    # time_data[first_idx - 1], which were off by one and wrapped to the
    # last element when first_idx == 0.)
    first_idx = np.where(bad_mask)[0][0]
    return [
        f"Column '{time_col}' must be monotonically non-decreasing. "
        f"At index {first_idx + 1}: {time_data[first_idx + 1]:.6f}s < "
        f"previous {time_data[first_idx]:.6f}s. "
        f"Use a cumulative time series (e.g. ionworksdata.transform or "
        f"ensure per-step time is converted to global elapsed time)."
    ]
[docs]
def validate_step_count_sequential(
    df: DataFrame,
) -> list[str]:
    """Check that 'Step count' exists, starts at 0, and increments by 1.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    step_col = "Step count"
    fix_hint = "Use ionworksdata.transform.set_step_count(data) to fix this."
    if not _has_column(df, step_col):
        return [
            f"Column '{step_col}' is required but was not found in "
            f"the data. Available columns: {list(df.columns)}. "
            f"{fix_hint}"
        ]
    step_data = _get_column(df, step_col)
    if len(step_data) == 0:
        return []
    errors: list[str] = []
    if step_data[0] != 0:
        errors.append(
            f"Column '{step_col}' must start at 0, but starts at "
            f"{step_data[0]}. {fix_hint}"
        )
    # Each row-to-row transition must either stay on the same step (diff 0)
    # or advance by exactly one (diff 1).
    raw_diffs = np.diff(step_data)
    bad_indices = np.where((raw_diffs != 0) & (raw_diffs != 1))[0]
    if len(bad_indices) > 0:
        # Show up to 5 example transitions.
        examples = [
            f"index {idx}: {step_data[idx]} -> "
            f"{step_data[idx + 1]} (diff={raw_diffs[idx]})"
            for idx in bad_indices[:5]
        ]
        more = f" (and {len(bad_indices) - 5} more)" if len(bad_indices) > 5 else ""
        errors.append(
            f"Column '{step_col}' must increase by 1 at each step "
            f"transition, but found {len(bad_indices)} invalid "
            f"transition(s): " + "; ".join(examples) + f".{more} {fix_hint}"
        )
    return errors
[docs]
def validate_cycle_constant_within_step(
    df: DataFrame,
    step_col: str = "Step count",
    cycle_col: str | None = None,
) -> list[str]:
    """
    Validate that cycle number does not change within a step.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).
    step_col : str
        Name of the column containing step numbers.
    cycle_col : str, optional
        Name of the column containing cycle numbers. If None, tries common names.

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    if not _has_column(df, step_col):
        return []
    # Find cycle column
    if cycle_col is None:
        # Try common cycle column names in priority order.
        for col in ["Cycle count", "Cycle number", "Cycle from cycler"]:
            if _has_column(df, col):
                cycle_col = col
                break
    # No cycle information available -> nothing to validate.
    if cycle_col is None or not _has_column(df, cycle_col):
        return []
    step_data = _get_column(df, step_col)
    if len(step_data) == 0:
        return []
    cycle_data = _get_column(df, cycle_col)
    step_groups = _get_step_group_indices(step_data)
    # Detect cycle changes within steps:
    # A cycle change within a step occurs when:
    # - The cycle value differs from the previous row
    # - AND we're in the same step group
    # Prepending the first element makes both diffs zero at index 0.
    cycle_diffs = np.diff(cycle_data, prepend=cycle_data[0])
    step_diffs = np.diff(step_groups, prepend=step_groups[0])
    # Within-step cycle change: same step (step_diff == 0) but cycle changed
    within_step_cycle_change = (step_diffs == 0) & (cycle_diffs != 0)
    problem_indices = np.where(within_step_cycle_change)[0]
    if len(problem_indices) == 0:
        return []
    # Group by step and report
    problem_steps = step_groups[problem_indices]
    unique_problem_steps = np.unique(problem_steps)
    errors = []
    for step_idx in unique_problem_steps:
        # Find all unique cycles in this step
        step_mask = step_groups == step_idx
        unique_cycles = np.unique(cycle_data[step_mask])
        errors.append(
            f"Cycle number changes within step {step_idx}: "
            f"found cycles {unique_cycles.tolist()}. "
            f"Each step should belong to a single cycle. "
            f"Use ionworksdata.transform.set_cycle_count(data) "
            f"to fix this."
        )
    return errors
[docs]
def validate_ocp_columns(df: DataFrame) -> list[str]:
    """Check that OCP data carries the columns it needs.

    Requires a 'Voltage [V]' column plus at least one x-axis column out of
    'Capacity [A.h]', 'Stoichiometry', or 'SOC'.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    errors: list[str] = []
    if not _has_column(df, "Voltage [V]"):
        errors.append(
            "OCP data must contain a 'Voltage [V]' column. "
            f"Available columns: {list(df.columns)}"
        )
    x_axis_columns = ["Capacity [A.h]", "Stoichiometry", "SOC"]
    has_x_axis = any(_has_column(df, col) for col in x_axis_columns)
    if not has_x_axis:
        errors.append(
            "OCP data must contain at least one x-axis column: "
            f"{', '.join(x_axis_columns)}. "
            f"Available columns: {list(df.columns)}"
        )
    return errors
[docs]
def validate_time_series_row_count(
    df: DataFrame,
    max_rows: int = 1000,
) -> list[str]:
    """Check that the time series stays within the inline row limit.

    Datasets larger than ``max_rows`` should be uploaded via the standard
    upload flow and then referenced with ``"db:<measurement_id>"`` or
    ``iwdata.DataLoader.from_db(MEASUREMENT_ID)`` in pipeline configurations.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).
    max_rows : int
        Maximum allowed number of rows.

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    n_rows = len(df)
    if n_rows <= max_rows:
        return []
    return [
        f"Time series has {n_rows} rows, which exceeds the maximum of "
        f"{max_rows} rows for inline data. Upload the data as a "
        f"measurement using client.cell_measurement.create() and "
        f'reference it with "db:<measurement_id>" or '
        f"iwdata.DataLoader.from_db(MEASUREMENT_ID) instead."
    ]
[docs]
def validate_measurement_data(
    df: DataFrame,
    strict: bool = False,
    data_type: str | None = None,
) -> None:
    """Run all pre-upload validation checks on measurement time series data.

    Standard cycler data (``data_type=None``) is checked for: sign
    convention (positive current = discharge), time starting at 0,
    monotonic time, a sequential 'Step count' column, and cumulative
    capacity/energy columns that reset at each step start and only
    increase within a step. With ``strict=True`` two extra checks run:
    at least 2 points per step and a constant cycle number within each
    step.

    OCP data (``data_type="ocp"``) is only checked for the required OCP
    columns and a sequential 'Step count' column.

    Parameters
    ----------
    df : DataFrame
        Time series data to validate (pandas or polars DataFrame).
    strict : bool
        If False (default), skip the strict-only checks described above.
    data_type : str | None
        ``"ocp"`` for open-circuit potential data (relaxed validation
        that skips current, time, capacity, and energy checks); ``None``
        (default) for standard cycler data.

    Raises
    ------
    MeasurementValidationError
        If any validation checks fail. The exception contains a list of
        all errors found.
    """
    errors: list[str] = []
    step_col = "Step count"
    if data_type == "ocp":
        errors.extend(validate_ocp_columns(df))
        errors.extend(validate_step_count_sequential(df))
    else:
        errors.extend(validate_positive_current_is_discharge(df, step_col=step_col))
        errors.extend(validate_time_starts_at_zero(df))
        errors.extend(validate_time_monotonic(df))
        errors.extend(validate_step_count_sequential(df))
        # Step-based checks only make sense when the step column exists.
        if _has_column(df, step_col):
            errors.extend(validate_cumulative_values_reset_per_step(df, step_col))
            if strict:
                errors.extend(validate_minimum_points_per_step(df, step_col))
                errors.extend(validate_cycle_constant_within_step(df, step_col))
    if errors:
        raise MeasurementValidationError(
            f"Measurement data validation failed with {len(errors)} error(s):\n"
            + "\n".join(f" - {err}" for err in errors),
            errors=errors,
        )
# --- Atomic validators ------------------------------------------------------ #
[docs]
def df_to_dict_validator(v: Any) -> Any:
    """Convert a DataFrame to a plain dict of column lists for serialization.

    Non-finite floats (inf, -inf, NaN) are replaced with ``None`` for JSON
    compatibility. Non-DataFrame values pass through unchanged.
    """
    if isinstance(v, pd.DataFrame):
        # Replace inf/-inf and NaN with None for JSON compatibility
        return v.replace([np.inf, -np.inf, np.nan], None).to_dict(orient="list")
    if isinstance(v, pl.DataFrame):
        # Sanitize each polars column individually to avoid name conflicts.
        def _sanitize(values: list) -> list:
            return [
                None if (x is not None and (math.isinf(x) or math.isnan(x))) else x
                for x in values
            ]

        result = {}
        for name in v.columns:
            column = v[name]
            values = column.to_list()
            # Only float columns can carry inf/NaN.
            result[name] = _sanitize(values) if column.dtype.is_float() else values
        return result
    return v
[docs]
def dict_to_df_validator(v: Any, return_type: str | None = None) -> Any:
"""Convert dict to DataFrame for data processing.
Parameters
----------
v : Any
Value to convert. If dict, converts to DataFrame.
return_type : str | None
Type of DataFrame to return: "polars" or "pandas".
If None, uses the global setting from set_dataframe_backend().
Returns
-------
Any
DataFrame if input was dict, otherwise unchanged.
"""
if isinstance(v, dict):
backend = return_type if return_type is not None else _dataframe_backend
# Check if all values are scalars (not lists/arrays)
all_scalars = all(
not isinstance(val, list | tuple | np.ndarray) for val in v.values()
)
if backend == "pandas":
if all_scalars:
return pd.DataFrame(v, index=[0])
return pd.DataFrame(v)
if all_scalars:
return pl.DataFrame({k: [val] for k, val in v.items()})
return pl.DataFrame(v)
return v
[docs]
def parameter_validator(v: Any) -> Any:
    """Serialize ``pybamm.Symbol`` values to JSON; pass everything else through."""
    if not isinstance(v, pybamm.Symbol):
        return v
    return convert_symbol_to_json(v)
[docs]
def float_sanitizer(v: Any) -> Any:
    """Map non-JSON-compliant floats (inf, -inf, NaN) to ``None``.

    Handles both builtin ``float`` and NumPy floating scalars; all other
    values are returned unchanged.
    """
    if isinstance(v, float):
        return None if math.isinf(v) or math.isnan(v) else v
    if isinstance(v, np.floating):
        return None if np.isinf(v) or np.isnan(v) else v
    return v
[docs]
def bounds_tuple_validator(v: Any) -> Any:
    """Turn a bounds 2-tuple into a list for JSON serialization.

    Parameters
    ----------
    v : Any
        Value to validate. Only a tuple of exactly two elements is
        converted.

    Returns
    -------
    Any
        List if input was a 2-tuple, otherwise unchanged.
    """
    is_bounds_pair = isinstance(v, tuple) and len(v) == 2
    return [*v] if is_bounds_pair else v
[docs]
def file_scheme_validator(v: Any) -> Any:
    """Convert ``file:`` and ``folder:`` scheme paths to serialized dicts.

    Handles ``file:`` prefixed paths (loads CSV as dict) and ``folder:``
    prefixed paths (loads time_series.csv and steps.csv as dict).
    All other values are returned unchanged.

    Raises
    ------
    FileNotFoundError
        If the file or folder path doesn't exist.
    """
    if isinstance(v, str) and v.startswith("file:"):
        # removeprefix (not split(":")) so paths containing ':' survive,
        # e.g. Windows drive letters in "file:C:/data.csv".
        path = pathlib.Path(v.removeprefix("file:")).expanduser().resolve()
        if not path.is_file():
            raise FileNotFoundError(f"CSV file not found: {v}")
        return df_to_dict_validator(pd.read_csv(path))
    if isinstance(v, str) and v.startswith("folder:"):
        path = pathlib.Path(v.removeprefix("folder:")).expanduser().resolve()
        if not path.is_dir():
            raise FileNotFoundError(f"Folder not found: {v}")
        return {
            "time_series": df_to_dict_validator(pd.read_csv(path / "time_series.csv")),
            "steps": df_to_dict_validator(pd.read_csv(path / "steps.csv")),
        }
    return v
# --- Pipeline composition helpers ------------------------------------------ #
Validator = Callable[[Any], Any]
def _apply_pipeline(value: Any, validators: Iterable[Validator]) -> Any:
transformed = value
for validator in validators:
transformed = validator(transformed)
return transformed
def _apply_recursive(value: Any, validators: Iterable[Validator]) -> Any:
    """Apply a validator pipeline to *value*, recursing into containers.

    Dicts are rebuilt with recursively validated values; lists are
    validated element-wise. Tuples are first run through the pipeline
    (e.g. ``bounds_tuple_validator`` converts a 2-tuple to a list) and
    the result's items are then validated recursively into a list — so
    tuples always come back as lists. Scalars go straight through the
    pipeline.
    """
    if isinstance(value, dict):
        return {key: _apply_recursive(val, validators) for key, val in value.items()}
    if isinstance(value, tuple):
        # Apply validators to the tuple itself first, then recurse into the
        # items. The original code had two identical branches here (one for
        # "validator converted tuple to list", one for "still a tuple");
        # both produced the same comprehension, so a single one suffices.
        transformed = _apply_pipeline(value, validators)
        return [_apply_recursive(item, validators) for item in transformed]
    if isinstance(value, list):
        return [_apply_recursive(item, validators) for item in value]
    return _apply_pipeline(value, validators)
# --- Public pipelines ------------------------------------------------------- #
def _time_series_row_count_validator(v: Any) -> Any:
    """Reject inline DataFrames that exceed the row limit.

    Raises
    ------
    MeasurementValidationError
        If the DataFrame has more than 1000 rows.
    """
    if not isinstance(v, pd.DataFrame | pl.DataFrame):
        return v
    errors = validate_time_series_row_count(v)
    if errors:
        raise MeasurementValidationError(errors[0], errors=errors)
    return v
# Outbound pipeline (applied recursively by run_validators_outbound).
# Order matters: scalars are sanitized first, bounds tuples become lists,
# file:/folder: references are loaded from disk, oversized inline DataFrames
# are rejected, remaining DataFrames are serialized to dicts, and pybamm
# symbols are converted to JSON last.
validators_outbound: list[Validator] = [
    float_sanitizer,
    bounds_tuple_validator,
    file_scheme_validator,
    _time_series_row_count_validator,
    df_to_dict_validator,
    parameter_validator,
]
# Inbound pipeline (applied recursively by run_validators_inbound):
# dict payloads are rebuilt into DataFrames using the configured backend.
validators_inbound: list[Validator] = [
    dict_to_df_validator,
]
[docs]
def run_validators_outbound(v: Any) -> Any:
    """Apply the outbound validator pipeline to *v*, recursing into
    dicts, lists, and tuples."""
    return _apply_recursive(v, validators_outbound)
[docs]
def run_validators_inbound(v: Any) -> Any:
    """Apply the inbound validator pipeline to *v*, recursing into
    dicts, lists, and tuples."""
    return _apply_recursive(v, validators_inbound)