Source code for ionworks.simulation

"""
Simulation client for running battery simulations.

This module provides the :class:`SimulationClient` for running battery
simulations using the Universal Cycler Protocol (UCP) format. It supports
single simulations, batch simulations with design of experiments (DOE),
and PyBaMM-based modeling.
"""

from __future__ import annotations

from datetime import datetime, timedelta
import time
from typing import Any, cast

from pydantic import BaseModel, Field, ValidationError
import requests

from .errors import IonworksError


class QuickModelConfig(BaseModel):
    """Minimal cell description used to build a model for protocol simulations.

    Both fields carry defaults, so an empty mapping is a valid configuration.
    """

    # Nominal cell capacity, in amp-hours.
    capacity: float = Field(1.0, description="Cell capacity in Ah")
    # Chemistry label; defaults to "NMC/Graphite".
    chemistry: str = Field("NMC/Graphite", description="Chemistry name")
class ProtocolExperimentConfig(BaseModel):
    """A cycling protocol together with the name used for its template."""

    # Raw YAML text of the protocol, written in Universal Cycler Protocol form.
    protocol: str = Field(description="YAML protocol string (UCP format)")
    # Name applied when the protocol's template is created.
    name: str = Field(description="Protocol name for template naming")
class ProtocolSimulationRequest(BaseModel):
    """Payload for creating one simulation from a protocol.

    Only ``parameterized_model`` and ``protocol_experiment`` are required;
    every other field defaults to ``None``.
    """

    # Accepted forms: a quick_model dict, a full model dict, or a model ID.
    parameterized_model: Any = Field(
        description="Model can be: quick_model dict, full model dict, or model ID string"
    )
    protocol_experiment: ProtocolExperimentConfig = Field(
        description="Protocol experiment configuration"
    )
    # Values for inputs referenced inside the protocol text.
    experiment_parameters: dict[str, float] | None = Field(
        None, description="Experiment parameters for any inputs in the protocol."
    )
    design_parameters: dict[str, float] | None = Field(
        None, description="Design parameters for the simulation"
    )
    # Limit on how many times goto statements may jump backwards.
    max_backward_jumps: int | None = Field(
        None, description="Maximum backward jumps allowed (for goto statements)"
    )
    study_id: str | None = Field(None, description="Optional study UUID")
    # When set, replaces the extra variables defined by the experiment template.
    extra_variables: list[str] | None = Field(
        None,
        description=(
            "Optional list of extra variables to include in simulation "
            "output (e.g., ['Negative electrode potential [V]', "
            "'Positive electrode potential [V]']). If provided, these "
            "override any extra variables defined in the experiment template."
        ),
    )
class DOERow(BaseModel):
    """One parameter axis of a design of experiments.

    The ``type`` string selects which optional fields are meaningful:
    ``min``/``max``/``count`` for ``'range'``, ``values`` for ``'discrete'``
    and ``mean``/``std`` for ``'normal'``. All optional fields default to
    ``None``; this model performs no cross-field validation.
    """

    type: str = Field(description="Type: 'range', 'discrete', or 'normal'")
    name: str = Field(description="Parameter name")

    # --- used by 'range' rows ---
    min: float | None = Field(None, description="Minimum value")
    max: float | None = Field(None, description="Maximum value")
    count: int | None = Field(
        None, description="Number of samples (for grid/random)"
    )

    # --- used by 'discrete' rows ---
    values: list[float] | None = Field(None, description="Discrete values")

    # --- used by 'normal' rows ---
    mean: float | None = Field(None, description="Mean value")
    std: float | None = Field(None, description="Standard deviation")
class DesignParametersDOE(BaseModel):
    """Sampling strategy plus the parameter rows that make up a DOE."""

    sampling: str = Field(
        description="Sampling strategy: 'grid', 'random', or 'latin_hypercube'"
    )
    rows: list[DOERow] = Field(description="DOE row configurations")
    # Overall sample count; described as applying to non-grid strategies.
    count: int | None = Field(None, description="Total count for non-grid sampling")
class ProtocolSimulationBatchRequest(BaseModel):
    """Payload for creating many protocol simulations from a DOE.

    ``parameterized_model``, ``protocol_experiment`` and
    ``design_parameters_doe`` are required; the rest default to ``None``.
    """

    # Accepted forms: a quick_model dict, a full model dict, or a model ID.
    parameterized_model: Any = Field(
        description="Model can be: quick_model dict, full model dict, or model ID string"
    )
    protocol_experiment: ProtocolExperimentConfig = Field(
        description="Protocol experiment configuration"
    )
    # Describes how the design parameters are sampled across the batch.
    design_parameters_doe: DesignParametersDOE = Field(
        description="Design of experiments configuration"
    )
    experiment_parameters: dict[str, float] | None = Field(
        None, description="Experiment parameters for any inputs in the protocol."
    )
    max_backward_jumps: int | None = Field(
        None, description="Maximum backward jumps allowed (for goto statements)"
    )
    study_id: str | None = Field(None, description="Optional study UUID")
    # When set, replaces the extra variables defined by the experiment template.
    extra_variables: list[str] | None = Field(
        None,
        description=(
            "Optional list of extra variables to include in simulation "
            "output (e.g., ['Negative electrode potential [V]', "
            "'Positive electrode potential [V]']). If provided, these "
            "override any extra variables defined in the experiment template."
        ),
    )
class SimulationResponse(BaseModel):
    """Identifiers returned by the API after a simulation is created."""

    simulation_id: str = Field(description="Simulation UUID")
    job_id: str = Field(description="Job UUID")
class SimulationClient:
    """Client for creating, listing, fetching and polling simulations."""

    def __init__(self, client: Any) -> None:
        """Initialize the SimulationClient.

        Parameters
        ----------
        client : Any
            The HTTP client instance for making API requests. It is used via
            ``get(endpoint)`` and ``post(endpoint, json_payload=...)``.
        """
        self.client = client

    @staticmethod
    def _parse_simulation_response(endpoint: str, item: Any) -> SimulationResponse:
        """Convert one server response object into a SimulationResponse.

        Raises
        ------
        ValueError
            If ``item`` is not a JSON object or lacks the expected fields.
        """
        if not isinstance(item, dict):
            msg = (
                f"Unexpected response format from {endpoint}: expected an "
                f"object, got {type(item).__name__}"
            )
            raise ValueError(msg)
        try:
            return SimulationResponse(**item)
        except ValidationError as e:
            msg = f"Unexpected response format from {endpoint}: {e}"
            raise ValueError(msg) from e

    def protocol(self, config: dict[str, Any]) -> SimulationResponse:
        """Create a single protocol-based simulation.

        Parameters
        ----------
        config : dict[str, Any]
            Configuration dictionary containing:

            - parameterized_model: quick_model dict, full model dict, or model
              ID string
            - protocol_experiment: ProtocolExperimentConfig dict with protocol
              and name fields
            - experiment_parameters: Optional dict[str, float] of values for
              inputs in the protocol (e.g. initial_soc, initial_temperature)
            - design_parameters: Optional dict[str, float]
            - max_backward_jumps: Optional int
            - study_id: Optional str
            - extra_variables: Optional list[str] — extra variables to include
              in simulation output

        Returns
        -------
        SimulationResponse
            Response containing simulation_id and job_id.

        Raises
        ------
        ValueError
            If the configuration is invalid, or if the server response does
            not have the expected shape.
        """
        endpoint = "/simulations/protocol"
        # Validate only the caller's config here, so that a malformed server
        # response is not misreported as an invalid configuration.
        try:
            validated_config = ProtocolSimulationRequest(**config)
        except ValidationError as e:
            raise ValueError(f"Invalid protocol simulation configuration: {e}") from e

        response_data = self.client.post(
            endpoint, json_payload=validated_config.model_dump(exclude_none=True)
        )
        return self._parse_simulation_response(endpoint, response_data)

    def protocol_batch(self, config: dict[str, Any]) -> list[SimulationResponse]:
        """Create multiple protocol-based simulations using DOE.

        Parameters
        ----------
        config : dict[str, Any]
            Configuration dictionary containing:

            - parameterized_model: quick_model dict, full model dict, or model
              ID string
            - protocol_experiment: ProtocolExperimentConfig dict with protocol
              and name fields
            - design_parameters_doe: DesignParametersDOE dict
            - experiment_parameters: Optional dict[str, float]
            - max_backward_jumps: Optional int
            - study_id: Optional str
            - extra_variables: Optional list[str] — extra variables to include
              in simulation output

        Returns
        -------
        list[SimulationResponse]
            List of responses, each containing simulation_id and job_id.

        Raises
        ------
        ValueError
            If the configuration is invalid, or if the server response is not
            a list of objects with the expected fields.
        """
        endpoint = "/simulations/protocol/batch"
        try:
            validated_config = ProtocolSimulationBatchRequest(**config)
        except ValidationError as e:
            raise ValueError(
                f"Invalid batch protocol simulation configuration: {e}"
            ) from e

        response_data = self.client.post(
            endpoint, json_payload=validated_config.model_dump(exclude_none=True)
        )
        if not isinstance(response_data, list):
            msg = (
                f"Unexpected response format from {endpoint}: expected a "
                f"list, got {type(response_data).__name__}"
            )
            raise ValueError(msg)
        return [
            self._parse_simulation_response(endpoint, item) for item in response_data
        ]

    def list(self) -> list[dict[str, Any]]:
        """List all simulations for the current user.

        Returns
        -------
        list[dict[str, Any]]
            List of simulation objects with joined model and experiment data.

        Raises
        ------
        ValueError
            If the server response is not a list.
        """
        endpoint = "/simulations"
        response_data = self.client.get(endpoint)
        if not isinstance(response_data, list):
            msg = (
                f"Unexpected response format from {endpoint}: expected a list, "
                f"got {type(response_data).__name__}"
            )
            raise ValueError(msg)
        return response_data

    def get(self, simulation_id: str) -> dict[str, Any]:
        """Get a specific simulation by ID.

        Parameters
        ----------
        simulation_id : str
            The UUID of the simulation to retrieve.

        Returns
        -------
        dict[str, Any]
            Simulation object with full joined data including model,
            experiment, and simulation_data (null if not completed).
        """
        endpoint = f"/simulations/{simulation_id}"
        response_data = self.client.get(endpoint)
        return cast(dict[str, Any], response_data)

    def get_result(self, simulation_id: str) -> dict[str, Any]:
        """Get simulation data/result for a completed simulation.

        Parameters
        ----------
        simulation_id : str
            The UUID of the simulation.

        Returns
        -------
        dict[str, Any]
            Simulation data object containing time_series, steps, and metrics.

        Raises
        ------
        Exception
            Whatever error the underlying HTTP client raises for a
            non-success response — notably when the API answers 404 because
            the simulation has not completed yet.
        """
        endpoint = f"/simulations/{simulation_id}/result"
        response_data = self.client.get(endpoint)
        return cast(dict[str, Any], response_data)

    def _poll_simulations(
        self,
        simulation_ids: list[str],
        timeout: int,
        poll_interval: int,
        verbose: bool,
    ) -> dict[str, dict[str, Any]]:
        """Poll simulations until all complete or timeout is reached.

        A simulation counts as completed once its record has a truthy
        ``storage_folder`` (or, as a legacy fallback, ``simulation_data``).
        Errors while fetching a simulation are tolerated and retried on the
        next round.

        Parameters
        ----------
        simulation_ids : list[str]
            List of simulation IDs to poll. Duplicates are polled once.
        timeout : int
            Maximum time to wait in seconds.
        poll_interval : int
            Time between polls in seconds (never sleeps past the deadline).
        verbose : bool
            Whether to print status updates.

        Returns
        -------
        dict[str, dict[str, Any]]
            Dict mapping simulation IDs to their completed simulation records.

        Raises
        ------
        TimeoutError
            If no simulations complete within the timeout.
        """
        # Track unique IDs (order preserved): completion is keyed by ID, so a
        # list containing repeats could otherwise never be "all complete".
        pending = list(dict.fromkeys(simulation_ids))
        completed: dict[str, dict[str, Any]] = {}
        # Monotonic clock: unaffected by wall-clock adjustments while waiting.
        start = time.monotonic()
        deadline = start + timeout

        if verbose:
            print(f"Polling for {len(pending)} simulation(s) completion...")

        while True:
            for sim_id in pending:
                if sim_id in completed:
                    continue
                try:
                    simulation = self.get(sim_id)
                except (IonworksError, requests.exceptions.RequestException) as exc:
                    # Best-effort polling: report and retry on the next round.
                    if verbose:
                        print(f"  Warning: Error polling simulation {sim_id}: {exc}")
                    continue
                if (
                    simulation.get("storage_folder")
                    or simulation.get("simulation_data")  # Legacy fallback
                ):
                    completed[sim_id] = simulation

            now = time.monotonic()
            if verbose:
                print(
                    f"  Status: {len(completed)}/{len(pending)} completed "
                    f"(elapsed: {int(now - start)}s)"
                )

            if len(completed) == len(pending):
                if verbose:
                    print("All simulations completed!")
                return completed

            remaining = deadline - now
            if remaining <= 0:
                break
            # Cap the sleep at the deadline; the next round is a final poll.
            time.sleep(min(poll_interval, remaining))

        # Timeout reached
        if verbose:
            print(
                f"Timeout: Only {len(completed)}/{len(pending)} "
                f"simulations completed within {timeout} seconds"
            )
        if not completed:
            msg = f"No simulations completed within {timeout} seconds"
            raise TimeoutError(msg)
        return completed

    def wait_for_completion(
        self,
        simulation_id: str | list[str],
        timeout: int = 60,
        poll_interval: int = 2,
        verbose: bool = True,
    ) -> dict[str, Any] | list[dict[str, Any]]:
        """Wait for simulation(s) to complete by polling until done or timeout.

        Parameters
        ----------
        simulation_id : str | list[str]
            Single simulation ID or list of simulation IDs to wait for.
        timeout : int
            Maximum time to wait in seconds (default: 60).
        poll_interval : int
            Time between polls in seconds (default: 2).
        verbose : bool
            Whether to print status updates (default: True).

        Returns
        -------
        dict[str, Any] | list[dict[str, Any]]
            For a single ID, the completed simulation dict. For a list, the
            completed simulations in input order (repeated IDs are repeated);
            if the timeout is reached, only those that completed are returned.

        Raises
        ------
        TimeoutError
            For a single ID, if it does not complete within the timeout. For a
            list, only if none of the simulations complete within the timeout.
        """
        if isinstance(simulation_id, str):
            try:
                completed = self._poll_simulations(
                    [simulation_id], timeout, poll_interval, verbose
                )
            except TimeoutError as e:
                msg = (
                    f"Simulation {simulation_id} did not complete within "
                    f"{timeout} seconds"
                )
                raise TimeoutError(msg) from e
            return completed[simulation_id]

        simulation_ids = [*simulation_id]
        completed = self._poll_simulations(
            simulation_ids, timeout, poll_interval, verbose
        )
        return [completed[sim_id] for sim_id in simulation_ids if sim_id in completed]