# Source code for ionworks.validators

"""Reusable validator functions and composable pipelines for value normalization.

Provides functions for composable inbound/outbound value normalization
(e.g., converting between pandas DataFrames and dictionaries).
"""

from collections.abc import Callable, Iterable
import math
import os
import pathlib
from typing import Any

from dotenv import load_dotenv
import numpy as np
import pandas as pd
import polars as pl
import pybamm
from pybamm.expression_tree.operations.serialise import convert_symbol_to_json
from scipy.integrate import cumulative_trapezoid
from scipy.stats import t as t_dist

from .errors import IonworksError

# --- DataFrame Backend Configuration ---------------------------------------- #

# Load .env file before reading environment variables
# (so IONWORKS_DATAFRAME_BACKEND can be supplied via a local .env file).
load_dotenv()

# Type alias for DataFrame (pandas or polars)
DataFrame = pd.DataFrame | pl.DataFrame


def _get_default_backend() -> str:
    """Resolve the initial DataFrame backend from the environment.

    Reads ``IONWORKS_DATAFRAME_BACKEND`` (case-insensitive) and falls back
    to ``"polars"`` when the variable is unset or holds an unsupported value.
    """
    candidate = os.getenv("IONWORKS_DATAFRAME_BACKEND", "polars").lower()
    # Any value other than the two supported backends silently maps to polars.
    return candidate if candidate in ("polars", "pandas") else "polars"


# Module-level configuration for DataFrame return type
# Initialized from IONWORKS_DATAFRAME_BACKEND env var, defaults to "polars"
# Mutated only via set_dataframe_backend(); read via get_dataframe_backend()
# and dict_to_df_validator().
_dataframe_backend: str = _get_default_backend()


def set_dataframe_backend(backend: str) -> None:
    """Set the default DataFrame backend for data fetching.

    This overrides the IONWORKS_DATAFRAME_BACKEND environment variable.

    Parameters
    ----------
    backend : str
        DataFrame backend to use: "polars" or "pandas".

    Raises
    ------
    ValueError
        If backend is not "polars" or "pandas".
    """
    global _dataframe_backend
    if backend in ("polars", "pandas"):
        _dataframe_backend = backend
    else:
        raise ValueError(f"backend must be 'polars' or 'pandas', got '{backend}'")
def get_dataframe_backend() -> str:
    """Get the current DataFrame backend setting.

    Returns
    -------
    str
        Current backend: "polars" or "pandas".
    """
    # Read the module-level setting managed by set_dataframe_backend().
    backend = _dataframe_backend
    return backend
# --- Measurement Data Validators -------------------------------------------- #
class MeasurementValidationError(IonworksError):
    """Exception raised when measurement data validation fails."""

    def __init__(self, message: str, errors: list[str] | None = None) -> None:
        """Store the aggregate message and the individual failure messages.

        Parameters
        ----------
        message : str
            Human-readable summary of all validation failures.
        errors : list[str], optional
            Individual validation error messages; defaults to an empty list.
        """
        super().__init__(message)
        # Keep the per-check messages so callers can inspect each failure.
        self.errors = errors or []
def _get_column(df: DataFrame, col: str) -> np.ndarray:
    """
    Extract a column as a numpy array from either pandas or polars DataFrame.

    Parameters
    ----------
    df : DataFrame
        pandas or polars DataFrame.
    col : str
        Column name.

    Returns
    -------
    np.ndarray
        Column values as numpy array.
    """
    # polars and pandas expose different column accessors.
    if isinstance(df, pl.DataFrame):
        series = df.get_column(col)
    else:
        series = df[col]
    return series.to_numpy()


def _has_column(df: DataFrame, col: str) -> bool:
    """Check if a column exists in the DataFrame."""
    return col in df.columns


def _get_step_group_indices(step_data: np.ndarray) -> np.ndarray:
    """Compute step group indices for each row (0-indexed, contiguous groups).

    Parameters
    ----------
    step_data : np.ndarray
        Array of step numbers/identifiers.

    Returns
    -------
    np.ndarray
        Array where each element is the step group index (0, 1, 2, ...) for
        that row.
    """
    # A new group starts wherever the step value differs from the previous
    # row; the first row always opens group 0.
    boundaries = np.concatenate([[True], np.diff(step_data) != 0])
    # Cumulative count of boundaries seen so far, shifted to be 0-indexed.
    return np.cumsum(boundaries) - 1
def positive_current_is_charge(
    t: np.ndarray,
    current: np.ndarray,
    voltage: np.ndarray,
) -> tuple[bool, float]:
    """Determine whether positive current corresponds to charging.

    Fits ``voltage = intercept + slope * capacity`` using weighted least
    squares (weights = dt). If the slope is non-negative (voltage rises or
    stays flat as cumulative charge increases) then positive current is
    charging; otherwise positive current is discharging.

    Parameters
    ----------
    t : np.ndarray
        Time values [s].
    current : np.ndarray
        Current values [A].
    voltage : np.ndarray
        Voltage values [V].

    Returns
    -------
    is_charge : bool
        ``True`` if positive current is charging (slope >= 0). ``False`` if
        positive current is discharging (slope < 0). Returns ``False``
        (assume discharge) when there is insufficient data.
    p_value : float
        Two-sided p-value for the slope being nonzero. Lower means more
        confident. Returns 1.0 when there is insufficient data for a t-test.
    """
    t = np.asarray(t, dtype=float)
    # Degenerate time axis (single sample or all-identical timestamps):
    # no regression possible, assume discharge with no confidence.
    if t.max() == t.min():
        return False, 1.0
    current = np.asarray(current, dtype=float)
    voltage = np.asarray(voltage, dtype=float)
    # x is the cumulative charge throughput (integral of current over time).
    x = cumulative_trapezoid(y=current, x=t, initial=0)
    y = voltage
    # Degrees of freedom for a two-parameter (slope + intercept) fit.
    dof = len(t) - 2
    if dof <= 0:
        # Too few points for a t-test; fall back to the two-point slope sign.
        if len(t) < 2 or x[1] == x[0]:
            return False, 1.0
        slope = (y[1] - y[0]) / (x[1] - x[0])
        return bool(slope >= 0), 1.0
    # Weight each sample by its preceding time interval (first weight is 0).
    w = np.diff(t, prepend=t[0])
    W = np.sum(w)
    # Weighted means of charge and voltage.
    x_bar = np.dot(w, x) / W
    y_bar = np.dot(w, y) / W
    dx = x - x_bar
    # Weighted sum of squares of the regressor.
    Sxx = np.dot(w, dx**2)
    if Sxx <= 0:
        # No variation in cumulative charge: slope undefined.
        return False, 1.0
    # Weighted least-squares slope and intercept.
    slope = np.dot(w, dx * (y - y_bar)) / Sxx
    intercept = y_bar - slope * x_bar
    residuals = y - (intercept + slope * x)
    # Residual variance estimate, clipped at zero against round-off.
    s2 = max(np.dot(w, residuals**2) / dof, 0.0)
    slope_se = np.sqrt(s2 / Sxx)
    if slope_se <= 0:
        # Perfect fit: the slope sign is exact, so p-value is 0 unless the
        # slope itself is exactly zero (flat voltage).
        if slope == 0.0:
            return True, 1.0
        return bool(slope >= 0), 0.0
    # Two-sided t-test on the slope.
    t_stat = slope / slope_se
    p_value = float(2.0 * t_dist.sf(np.abs(t_stat), dof))
    return bool(slope >= 0), p_value
def validate_positive_current_is_discharge(  # noqa: PLR0913
    df: DataFrame,
    current_col: str = "Current [A]",
    voltage_col: str = "Voltage [V]",
    time_col: str = "Time [s]",
    step_col: str | None = None,
    rest_tol: float = 1e-3,
) -> list[str]:
    """
    Validate that positive current corresponds to discharge.

    Discharge should cause voltage to decrease. This function analyzes the
    relationship between current direction and voltage change to verify the
    sign convention is correct.

    Uses weighted least squares (V vs cumulative Q) per step, then a
    confidence-weighted vote across steps to decide the overall convention.
    This matches the algorithm in ``ionworksdata.transform``.

    Parameters
    ----------
    df : DataFrame
        Time series data with current and voltage columns (pandas or polars).
    current_col : str
        Name of the current column.
    voltage_col : str
        Name of the voltage column.
    time_col : str
        Name of the time column.
    step_col : str, optional
        Name of the step column. If provided, analyzes per-step. Otherwise,
        infers steps from current sign changes.
    rest_tol : float
        Tolerance for considering current as zero (rest).

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    # Silently skip the check when any required column is missing.
    if not _has_column(df, current_col) or not _has_column(df, voltage_col):
        return []
    if not _has_column(df, time_col):
        return []
    current = _get_column(df, current_col)
    voltage = _get_column(df, voltage_col)
    time = _get_column(df, time_col)
    if len(current) == 0:
        return []
    # Determine step groups
    if step_col and _has_column(df, step_col):
        step_data = _get_column(df, step_col)
    else:
        # Infer steps from current sign changes
        max_abs = np.max(np.abs(current))
        if max_abs == 0:
            # All-zero current: nothing to classify.
            return []
        normalized = current / max_abs
        # Sign of current, with near-zero (rest) samples mapped to 0.
        step_data = np.sign(normalized * (np.abs(normalized) > rest_tol))
    step_groups = _get_step_group_indices(step_data)
    num_steps = step_groups[-1] + 1
    # Mean current per step (for rest filtering)
    step_current_sum = np.bincount(step_groups, weights=current, minlength=num_steps)
    step_counts = np.bincount(step_groups, minlength=num_steps).astype(float)
    # Avoid division by zero for (theoretically absent) empty groups.
    step_counts[step_counts == 0] = 1
    mean_current = step_current_sum / step_counts
    # Identify non-rest steps
    non_rest_steps = set(np.where(np.abs(mean_current) >= rest_tol)[0])
    if not non_rest_steps:
        return []
    # Classify each non-rest step using WLS V-vs-Q
    charge_weight = 0.0
    discharge_weight = 0.0
    charge_count = 0
    discharge_count = 0
    for step_id in non_rest_steps:
        mask = step_groups == step_id
        is_charge, p_value = positive_current_is_charge(
            time[mask], current[mask], voltage[mask]
        )
        # Low p-value means a confident slope estimate for this step.
        confidence = 1.0 - p_value
        if is_charge:
            charge_weight += confidence
            charge_count += 1
        else:
            discharge_weight += confidence
            discharge_count += 1
    # Confidence-weighted vote; fall back to a plain majority when every
    # step had zero confidence.
    if charge_weight + discharge_weight > 0:
        should_flag = charge_weight > discharge_weight
    else:
        should_flag = charge_count > discharge_count
    if should_flag:
        return [
            "Current sign convention error: positive current appears to "
            "be charge, not discharge. Voltage increases when current is "
            "positive, but for discharge, voltage should decrease. Use "
            "ionworksdata.transform.set_positive_current_for_discharge"
            "(data) to fix this."
        ]
    return []
def validate_cumulative_values_reset_per_step(
    df: DataFrame,
    step_col: str = "Step count",
    cumulative_cols: list[str] | None = None,
    tolerance: float = 1e-6,
) -> list[str]:
    """Validate cumulative values reset to ~0 at each step and only increase.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).
    step_col : str
        Name of the column containing step numbers.
    cumulative_cols : list[str], optional
        List of cumulative column names to validate. If None, checks for
        common capacity and energy columns.
    tolerance : float
        Tolerance for considering a value as "zero" at step start.

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    errors = []
    if not _has_column(df, step_col):
        return []
    if cumulative_cols is None:
        cumulative_cols = [
            "Discharge capacity [A.h]",
            "Charge capacity [A.h]",
            "Discharge energy [W.h]",
            "Charge energy [W.h]",
        ]
    # Only validate the cumulative columns actually present in the data.
    cols_to_check = [col for col in cumulative_cols if _has_column(df, col)]
    if not cols_to_check:
        return []
    fix_hint = (
        "Use ionworksdata.transform.set_capacity(data) and/or "
        "ionworksdata.transform.set_energy(data) to fix this."
    )
    step_data = _get_column(df, step_col)
    if len(step_data) == 0:
        return []
    step_groups = _get_step_group_indices(step_data)
    # Find step boundaries (first index of each step); prepend=-1 makes the
    # very first row count as a boundary.
    step_boundaries = np.where(np.diff(step_groups, prepend=-1) != 0)[0]
    for col in cols_to_check:
        values = _get_column(df, col)
        # Check 1: Values at step starts should be ~0
        start_values = values[step_boundaries]
        non_zero_mask = np.abs(start_values) > tolerance
        non_zero_steps = np.where(non_zero_mask)[0]
        for step_idx in non_zero_steps:
            errors.append(
                f"Column '{col}' does not reset at start of "
                f"step {step_idx}: expected ~0, got "
                f"{start_values[step_idx]:.6f}. Cumulative values "
                f"should reset to 0 at the start of each step. "
                f"{fix_hint}"
            )
        # Check 2: Values should be monotonically non-decreasing within each step
        # Compute diff and check where it's negative within same step.
        # prepend makes the first element's diff 0, so it can never be flagged.
        value_diffs = np.diff(values, prepend=values[0])
        step_diffs = np.diff(step_groups, prepend=step_groups[0])
        # Mask: same step (diff == 0) and value decreased
        within_step = step_diffs == 0
        decreased = value_diffs < -tolerance
        # Find first decrease per step
        problem_indices = np.where(within_step & decreased)[0]
        if len(problem_indices) > 0:
            # Group by step and report first decrease per step
            problem_steps = step_groups[problem_indices]
            unique_problem_steps = np.unique(problem_steps)
            for step_idx in unique_problem_steps:
                # Find first index in this step with decrease
                step_problem_indices = problem_indices[problem_steps == step_idx]
                first_idx = step_problem_indices[0]
                errors.append(
                    f"Column '{col}' decreases within step "
                    f"{step_idx} at index {first_idx}: value went "
                    f"from {values[first_idx - 1]:.6f} to "
                    f"{values[first_idx]:.6f}. Cumulative values "
                    f"should only increase within a step. "
                    f"{fix_hint}"
                )
    return errors
def validate_minimum_points_per_step(
    df: DataFrame,
    step_col: str = "Step count",
    min_points: int = 2,
) -> list[str]:
    """
    Validate that each step has at least a minimum number of data points.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).
    step_col : str
        Name of the column containing step numbers.
    min_points : int
        Minimum number of points required per step.

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    if not _has_column(df, step_col):
        return []
    step_data = _get_column(df, step_col)
    if len(step_data) == 0:
        return []
    step_groups = _get_step_group_indices(step_data)
    num_steps = step_groups[-1] + 1
    # Row count per contiguous step group, computed in one vectorized pass.
    counts = np.bincount(step_groups, minlength=num_steps)
    # One message per step that falls short of the minimum.
    return [
        f"Step {step_idx} has only {counts[step_idx]} data point(s), "
        f"but at least {min_points} are required."
        for step_idx in np.where(counts < min_points)[0]
    ]
def validate_time_starts_at_zero(
    df: DataFrame,
    tolerance: float = 1e-6,
) -> list[str]:
    """Validate that 'Time [s]' starts at 0.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).
    tolerance : float
        Tolerance for considering the start value as zero.

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    time_col = "Time [s]"
    if not _has_column(df, time_col):
        return []
    time_data = _get_column(df, time_col)
    # Empty data or a first sample within tolerance of zero both pass.
    if len(time_data) == 0 or abs(time_data[0]) <= tolerance:
        return []
    return [
        f"Column '{time_col}' must start at 0, but starts at "
        f"{time_data[0]}. Use ionworksdata.transform.reset_time"
        f"(data) to fix this. To indicate the absolute time when a step starts, "
        "use the start_time field in the measurement metadata."
    ]
def validate_time_monotonic(
    df: DataFrame,
    time_col: str = "Time [s]",
    tolerance: float = 1e-12,
) -> list[str]:
    """Validate that the time column is monotonically non-decreasing.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).
    time_col : str
        Name of the time column.
    tolerance : float
        Numerical tolerance; time[i] must be >= time[i-1] - tolerance.

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    if not _has_column(df, time_col):
        return []
    time_data = _get_column(df, time_col)
    if len(time_data) < 2:
        return []
    diffs = np.diff(time_data)
    bad_mask = diffs < -tolerance
    if not np.any(bad_mask):
        return []
    bad_indices = np.where(bad_mask)[0]
    first_idx = bad_indices[0]
    # diffs[i] = time[i + 1] - time[i], so a violation at diff index i means
    # row i + 1 went backwards relative to row i. (Previously the message
    # reported time_data[first_idx] and time_data[first_idx - 1], which were
    # the wrong rows and wrapped to the last element when first_idx == 0.)
    return [
        f"Column '{time_col}' must be monotonically non-decreasing. "
        f"At index {first_idx + 1}: {time_data[first_idx + 1]:.6f}s < "
        f"previous {time_data[first_idx]:.6f}s. "
        f"Use a cumulative time series (e.g. ionworksdata.transform or "
        f"ensure per-step time is converted to global elapsed time)."
    ]
def validate_step_count_sequential(
    df: DataFrame,
) -> list[str]:
    """Validate that 'Step count' exists, starts at 0, and increases by 1.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    step_col = "Step count"
    fix_hint = "Use ionworksdata.transform.set_step_count(data) to fix this."
    if not _has_column(df, step_col):
        return [
            f"Column '{step_col}' is required but was not found in "
            f"the data. Available columns: {list(df.columns)}. "
            f"{fix_hint}"
        ]
    step_data = _get_column(df, step_col)
    if len(step_data) == 0:
        return []
    errors: list[str] = []
    if step_data[0] != 0:
        errors.append(
            f"Column '{step_col}' must start at 0, but starts at "
            f"{step_data[0]}. {fix_hint}"
        )
    raw_diffs = np.diff(step_data)
    # Valid transitions either stay on the same step (0) or advance by 1.
    bad_mask = (raw_diffs != 0) & (raw_diffs != 1)
    if np.any(bad_mask):
        bad_indices = np.where(bad_mask)[0]
        # Show at most five offending transitions in the message.
        examples = [
            f"index {idx}: {step_data[idx]} -> "
            f"{step_data[idx + 1]} (diff={raw_diffs[idx]})"
            for idx in bad_indices[:5]
        ]
        more = f" (and {len(bad_indices) - 5} more)" if len(bad_indices) > 5 else ""
        errors.append(
            f"Column '{step_col}' must increase by 1 at each step "
            f"transition, but found {len(bad_indices)} invalid "
            f"transition(s): " + "; ".join(examples) + f".{more} {fix_hint}"
        )
    return errors
def validate_cycle_constant_within_step(
    df: DataFrame,
    step_col: str = "Step count",
    cycle_col: str | None = None,
) -> list[str]:
    """
    Validate that cycle number does not change within a step.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).
    step_col : str
        Name of the column containing step numbers.
    cycle_col : str, optional
        Name of the column containing cycle numbers. If None, tries common
        names.

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    if not _has_column(df, step_col):
        return []
    # Find cycle column
    if cycle_col is None:
        for col in ["Cycle count", "Cycle number", "Cycle from cycler"]:
            if _has_column(df, col):
                cycle_col = col
                break
    # Skip the check entirely if no cycle column is available.
    if cycle_col is None or not _has_column(df, cycle_col):
        return []
    step_data = _get_column(df, step_col)
    if len(step_data) == 0:
        return []
    cycle_data = _get_column(df, cycle_col)
    step_groups = _get_step_group_indices(step_data)
    # Detect cycle changes within steps:
    # A cycle change within a step occurs when:
    # - The cycle value differs from the previous row
    # - AND we're in the same step group
    # prepend makes the first element's diff 0, so row 0 is never flagged.
    cycle_diffs = np.diff(cycle_data, prepend=cycle_data[0])
    step_diffs = np.diff(step_groups, prepend=step_groups[0])
    # Within-step cycle change: same step (step_diff == 0) but cycle changed
    within_step_cycle_change = (step_diffs == 0) & (cycle_diffs != 0)
    problem_indices = np.where(within_step_cycle_change)[0]
    if len(problem_indices) == 0:
        return []
    # Group by step and report
    problem_steps = step_groups[problem_indices]
    unique_problem_steps = np.unique(problem_steps)
    errors = []
    for step_idx in unique_problem_steps:
        # Find all unique cycles in this step
        step_mask = step_groups == step_idx
        unique_cycles = np.unique(cycle_data[step_mask])
        errors.append(
            f"Cycle number changes within step {step_idx}: "
            f"found cycles {unique_cycles.tolist()}. "
            f"Each step should belong to a single cycle. "
            f"Use ionworksdata.transform.set_cycle_count(data) "
            f"to fix this."
        )
    return errors
def validate_ocp_columns(df: DataFrame) -> list[str]:
    """Validate that OCP data has required columns.

    Checks that the DataFrame contains:

    1. A 'Voltage [V]' column
    2. At least one x-axis column: 'Capacity [A.h]', 'Stoichiometry', or 'SOC'

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    problems: list[str] = []
    if not _has_column(df, "Voltage [V]"):
        problems.append(
            "OCP data must contain a 'Voltage [V]' column. "
            f"Available columns: {list(df.columns)}"
        )
    x_axis_columns = ["Capacity [A.h]", "Stoichiometry", "SOC"]
    # Any one of the recognized x-axis columns satisfies the requirement.
    has_x_axis = any(_has_column(df, col) for col in x_axis_columns)
    if not has_x_axis:
        problems.append(
            "OCP data must contain at least one x-axis column: "
            f"{', '.join(x_axis_columns)}. "
            f"Available columns: {list(df.columns)}"
        )
    return problems
def validate_time_series_row_count(
    df: DataFrame,
    max_rows: int = 1000,
) -> list[str]:
    """Validate that the time series does not exceed the maximum row count.

    Datasets larger than ``max_rows`` should be uploaded via the standard
    upload flow and then referenced with ``"db:<measurement_id>"`` or
    ``iwdata.DataLoader.from_db(MEASUREMENT_ID)`` in pipeline configurations.

    Parameters
    ----------
    df : DataFrame
        Time series data (pandas or polars).
    max_rows : int
        Maximum allowed number of rows.

    Returns
    -------
    list[str]
        List of validation error messages. Empty if validation passes.
    """
    n_rows = len(df)
    # Guard clause: within the limit, nothing to report.
    if n_rows <= max_rows:
        return []
    return [
        f"Time series has {n_rows} rows, which exceeds the maximum of "
        f"{max_rows} rows for inline data. Upload the data as a "
        f"measurement using client.cell_measurement.create() and "
        f'reference it with "db:<measurement_id>" or '
        f"iwdata.DataLoader.from_db(MEASUREMENT_ID) instead."
    ]
def validate_measurement_data(
    df: DataFrame,
    strict: bool = False,
    data_type: str | None = None,
) -> None:
    """Validate measurement time series data before upload.

    For standard cycler data (``data_type=None``), performs:

    1. Positive current should correspond to discharge (voltage decreases)
    2. Time starts at 0
    3. Time is monotonically non-decreasing
    4. 'Step count' column exists, starts at 0, and increases by 1
    5. Cumulative values (capacity, energy) should reset at each step start
       and only increase within steps
    6. Each step has at least 2 data points (strict mode only)
    7. Cycle number does not change within a step (strict mode only)

    For OCP data (``data_type="ocp"``), only validates:

    1. 'Voltage [V]' column exists
    2. 'Step count' column exists and is sequential

    Parameters
    ----------
    df : DataFrame
        Time series data to validate (pandas or polars DataFrame).
    strict : bool
        If False (default), skip strict checks. If True, run additional
        checks: minimum 2 points per step and cycle number constant within
        each step.
    data_type : str | None
        The type of data being validated. Use ``"ocp"`` for open-circuit
        potential data, which relaxes validation to skip current, time,
        capacity, and energy checks. Default is ``None`` (standard cycler
        data).

    Raises
    ------
    MeasurementValidationError
        If any validation checks fail. The exception contains a list of all
        errors found.
    """
    # Collect messages from every check before raising, so the caller sees
    # all problems at once rather than one per run.
    all_errors = []
    step_col = "Step count"
    if data_type == "ocp":
        # OCP data: only column-presence and step-sequence checks apply.
        ocp_col_errors = validate_ocp_columns(df)
        all_errors.extend(ocp_col_errors)
        step_seq_errors = validate_step_count_sequential(df)
        all_errors.extend(step_seq_errors)
    else:
        # Check 1: Positive current should be discharge
        current_errors = validate_positive_current_is_discharge(df, step_col=step_col)
        all_errors.extend(current_errors)
        # Check 2: Time starts at 0
        time_errors = validate_time_starts_at_zero(df)
        all_errors.extend(time_errors)
        # Check 3: Time is monotonic
        monotonic_errors = validate_time_monotonic(df)
        all_errors.extend(monotonic_errors)
        # Check 4: 'Step count' column exists, starts at 0,
        # and increments by 1
        step_seq_errors = validate_step_count_sequential(df)
        all_errors.extend(step_seq_errors)
        if _has_column(df, step_col):
            # Check 5: Cumulative values should reset at each step
            cumulative_errors = validate_cumulative_values_reset_per_step(df, step_col)
            all_errors.extend(cumulative_errors)
            if strict:
                # Check 6: At least 2 points per step
                points_errors = validate_minimum_points_per_step(df, step_col)
                all_errors.extend(points_errors)
                # Check 7: Cycle constant within step
                cycle_errors = validate_cycle_constant_within_step(df, step_col)
                all_errors.extend(cycle_errors)
    if all_errors:
        raise MeasurementValidationError(
            f"Measurement data validation failed with {len(all_errors)} error(s):\n"
            + "\n".join(f" - {err}" for err in all_errors),
            errors=all_errors,
        )
# --- Atomic validators ------------------------------------------------------ #
def df_to_dict_validator(v: Any) -> Any:
    """Convert DataFrame to dict with orient='list' for serialization.

    Non-finite floats (inf, -inf, NaN) are mapped to ``None`` so the result
    is JSON-compatible. Non-DataFrame values are returned unchanged.
    """
    if isinstance(v, pd.DataFrame):
        # NOTE: pd.DataFrame.replace([...], None) is a footgun: with
        # value=None and a list of targets, pandas interprets it as
        # method='pad' (forward fill, deprecated) instead of "replace with
        # None". Sanitize explicitly so NaN/inf always become None.
        raw = v.to_dict(orient="list")
        return {
            col: [
                None
                if isinstance(item, float) and not math.isfinite(item)
                else item
                for item in values
            ]
            for col, values in raw.items()
        }
    if isinstance(v, pl.DataFrame):
        # Replace inf/-inf and NaN with None for JSON compatibility.
        # Process each column individually to avoid name conflicts.
        result = {}
        for col_name in v.columns:
            col = v[col_name]
            if col.dtype.is_float():
                # Replace inf/-inf and NaN with None
                sanitized = [
                    None
                    if (x is not None and (math.isinf(x) or math.isnan(x)))
                    else x
                    for x in col.to_list()
                ]
                result[col_name] = sanitized
            else:
                result[col_name] = col.to_list()
        return result
    return v
[docs] def dict_to_df_validator(v: Any, return_type: str | None = None) -> Any: """Convert dict to DataFrame for data processing. Parameters ---------- v : Any Value to convert. If dict, converts to DataFrame. return_type : str | None Type of DataFrame to return: "polars" or "pandas". If None, uses the global setting from set_dataframe_backend(). Returns ------- Any DataFrame if input was dict, otherwise unchanged. """ if isinstance(v, dict): backend = return_type if return_type is not None else _dataframe_backend # Check if all values are scalars (not lists/arrays) all_scalars = all( not isinstance(val, list | tuple | np.ndarray) for val in v.values() ) if backend == "pandas": if all_scalars: return pd.DataFrame(v, index=[0]) return pd.DataFrame(v) if all_scalars: return pl.DataFrame({k: [val] for k, val in v.items()}) return pl.DataFrame(v) return v
def parameter_validator(v: Any) -> Any:
    """Convert pybamm.Symbol values to JSON-serializable form."""
    # Only expression-tree symbols need serialization; everything else
    # passes through untouched.
    if not isinstance(v, pybamm.Symbol):
        return v
    return convert_symbol_to_json(v)
def float_sanitizer(v: Any) -> Any:
    """Sanitize float values to JSON-compatible forms.

    Converts inf, -inf, and NaN to None since these are not JSON-compliant.
    """
    # isfinite is False exactly for inf, -inf, and NaN.
    if isinstance(v, float) and not math.isfinite(v):
        return None
    if isinstance(v, np.floating) and not np.isfinite(v):
        return None
    return v
def bounds_tuple_validator(v: Any) -> Any:
    """Convert bounds 2-tuple to list for JSON serialization.

    Parameters
    ----------
    v : Any
        Value to validate. If it's a tuple with 2 elements, converts to list.

    Returns
    -------
    Any
        List if input was a 2-tuple, otherwise unchanged.
    """
    # Only (lower, upper) pairs are converted; other tuples pass through.
    is_pair = isinstance(v, tuple) and len(v) == 2
    return list(v) if is_pair else v
def file_scheme_validator(v: Any) -> Any:
    """Convert file:// and folder:// scheme paths to serialized dicts.

    Handles ``file:`` prefixed paths (loads CSV as dict) and ``folder:``
    prefixed paths (loads time_series.csv and steps.csv as dict). All other
    values are returned unchanged.

    Raises
    ------
    FileNotFoundError
        If the file or folder path doesn't exist.
    """
    if isinstance(v, str) and v.startswith("file:"):
        # removeprefix keeps any further colons in the path intact (e.g.
        # Windows drive letters like "file:C:\\data.csv"); the previous
        # split(":")[1] truncated such paths at the second colon.
        path = pathlib.Path(v.removeprefix("file:")).expanduser().resolve()
        # is_file() is False for missing paths, so it covers exists() too.
        if not path.is_file():
            raise FileNotFoundError(f"CSV file not found: {v}")
        return df_to_dict_validator(pd.read_csv(path))
    if isinstance(v, str) and v.startswith("folder:"):
        path = pathlib.Path(v.removeprefix("folder:")).expanduser().resolve()
        if not path.is_dir():
            raise FileNotFoundError(f"Folder not found: {v}")
        return {
            "time_series": df_to_dict_validator(pd.read_csv(path / "time_series.csv")),
            "steps": df_to_dict_validator(pd.read_csv(path / "steps.csv")),
        }
    return v
# --- Pipeline composition helpers ------------------------------------------ #

# A validator takes a value and returns the (possibly transformed) value.
Validator = Callable[[Any], Any]


def _apply_pipeline(value: Any, validators: Iterable[Validator]) -> Any:
    """Run *value* through each validator in order, feeding each output into
    the next."""
    transformed = value
    for validator in validators:
        transformed = validator(transformed)
    return transformed


def _apply_recursive(value: Any, validators: Iterable[Validator]) -> Any:
    """Apply validators to *value* and, recursively, to nested containers.

    Dicts recurse into values, lists recurse into items, tuples are first run
    through the pipeline (e.g. to convert bounds tuples to lists) and then
    recursed into as lists; leaves get the pipeline directly.
    """
    if isinstance(value, dict):
        return {key: _apply_recursive(val, validators) for key, val in value.items()}
    if isinstance(value, tuple):
        # Apply validators to the tuple itself first (e.g., to convert bounds
        # tuples to lists), then recurse into the items. The original code had
        # two byte-identical branches here (list and tuple both produced a
        # list), so a single path suffices; guard against a validator
        # returning a non-iterable scalar.
        transformed = _apply_pipeline(value, validators)
        if isinstance(transformed, (list, tuple)):
            return [_apply_recursive(item, validators) for item in transformed]
        return transformed
    if isinstance(value, list):
        return [_apply_recursive(item, validators) for item in value]
    return _apply_pipeline(value, validators)


# --- Public pipelines ------------------------------------------------------- #


def _time_series_row_count_validator(v: Any) -> Any:
    """Block inline DataFrames that exceed the row limit.

    Raises
    ------
    MeasurementValidationError
        If the DataFrame has more than 1000 rows.
    """
    if isinstance(v, pd.DataFrame | pl.DataFrame):
        errors = validate_time_series_row_count(v)
        if errors:
            raise MeasurementValidationError(
                errors[0],
                errors=errors,
            )
    return v


# Outbound: sanitize values before serialization/upload. Order matters:
# file/folder schemes are expanded into DataFrames before the row-count gate
# and dict conversion run.
validators_outbound: list[Validator] = [
    float_sanitizer,
    bounds_tuple_validator,
    file_scheme_validator,
    _time_series_row_count_validator,
    df_to_dict_validator,
    parameter_validator,
]

# Inbound: rebuild DataFrames from serialized dict payloads.
validators_inbound: list[Validator] = [
    dict_to_df_validator,
]
def run_validators_outbound(v: Any) -> Any:
    """Recursively apply outbound validators to values and nested containers."""
    # Delegate to the shared recursive walker with the outbound pipeline.
    pipeline = validators_outbound
    return _apply_recursive(v, pipeline)
def run_validators_inbound(v: Any) -> Any:
    """Recursively apply inbound validators to values and nested containers."""
    # Delegate to the shared recursive walker with the inbound pipeline.
    pipeline = validators_inbound
    return _apply_recursive(v, pipeline)