"""
Simulation client for running battery simulations.
This module provides the :class:`SimulationClient` for running battery
simulations using the Universal Cycler Protocol (UCP) format. It supports
single simulations, batch simulations with design of experiments (DOE),
and PyBaMM-based modeling.
"""
from __future__ import annotations
from datetime import datetime, timedelta
import time
from typing import Any, cast
from pydantic import BaseModel, Field, ValidationError
import requests
from .errors import IonworksError
class QuickModelConfig(BaseModel):
    """Minimal model description used by protocol-based simulations.

    Attributes
    ----------
    capacity : float
        Cell capacity in ampere-hours (default ``1.0``).
    chemistry : str
        Chemistry label (default ``"NMC/Graphite"``).
    """

    capacity: float = Field(description="Cell capacity in Ah", default=1.0)
    chemistry: str = Field(description="Chemistry name", default="NMC/Graphite")
class ProtocolExperimentConfig(BaseModel):
    """A cycling protocol together with the name used for its template.

    Attributes
    ----------
    protocol : str
        Protocol text as YAML in Universal Cycler Protocol (UCP) format.
    name : str
        Name used when naming the experiment template.
    """

    protocol: str = Field(
        description="YAML protocol string (UCP format)",
    )
    name: str = Field(
        description="Protocol name for template naming",
    )
class ProtocolSimulationRequest(BaseModel):
    """Payload for submitting a single protocol-based simulation.

    Attributes
    ----------
    parameterized_model : Any
        A quick-model dict, a full model dict, or a model ID string.
    protocol_experiment : ProtocolExperimentConfig
        Protocol text and template name.
    experiment_parameters : dict[str, float] | None
        Values for inputs referenced by the protocol.
    design_parameters : dict[str, float] | None
        Design parameters for this simulation.
    max_backward_jumps : int | None
        Upper bound on backward jumps (``goto`` statements).
    study_id : str | None
        Optional study UUID.
    extra_variables : list[str] | None
        Extra output variables; when given they override the template's.
    """

    parameterized_model: Any = Field(
        description=(
            "Model can be: quick_model dict, full model dict, "
            "or model ID string"
        ),
    )
    protocol_experiment: ProtocolExperimentConfig = Field(
        description="Protocol experiment configuration",
    )
    experiment_parameters: dict[str, float] | None = Field(
        description="Experiment parameters for any inputs in the protocol.",
        default=None,
    )
    design_parameters: dict[str, float] | None = Field(
        description="Design parameters for the simulation",
        default=None,
    )
    max_backward_jumps: int | None = Field(
        description="Maximum backward jumps allowed (for goto statements)",
        default=None,
    )
    study_id: str | None = Field(
        description="Optional study UUID",
        default=None,
    )
    extra_variables: list[str] | None = Field(
        description=(
            "Optional list of extra variables to include in simulation "
            "output (e.g., ['Negative electrode potential [V]', "
            "'Positive electrode potential [V]']). If provided, these "
            "override any extra variables defined in the experiment "
            "template."
        ),
        default=None,
    )
class DOERow(BaseModel):
    """One parameter axis of a design-of-experiments specification.

    Which optional fields are relevant depends on ``type``:
    ``min``/``max``/``count`` for ``'range'``, ``values`` for
    ``'discrete'``, and ``mean``/``std`` for ``'normal'``. This model
    performs no cross-field validation of those combinations.
    """

    type: str = Field(description="Type: 'range', 'discrete', or 'normal'")
    name: str = Field(description="Parameter name")

    # --- 'range' rows ---
    min: float | None = Field(description="Minimum value", default=None)
    max: float | None = Field(description="Maximum value", default=None)
    count: int | None = Field(
        description="Number of samples (for grid/random)",
        default=None,
    )

    # --- 'discrete' rows ---
    values: list[float] | None = Field(description="Discrete values", default=None)

    # --- 'normal' rows ---
    mean: float | None = Field(description="Mean value", default=None)
    std: float | None = Field(description="Standard deviation", default=None)
class DesignParametersDOE(BaseModel):
    """Design-of-experiments sampling specification.

    Attributes
    ----------
    sampling : str
        One of ``'grid'``, ``'random'``, or ``'latin_hypercube'``.
    rows : list[DOERow]
        One row per parameter being varied.
    count : int | None
        Total number of samples for non-grid strategies.
    """

    sampling: str = Field(
        description="Sampling strategy: 'grid', 'random', or 'latin_hypercube'",
    )
    rows: list[DOERow] = Field(description="DOE row configurations")
    count: int | None = Field(
        description="Total count for non-grid sampling",
        default=None,
    )
class ProtocolSimulationBatchRequest(BaseModel):
    """Payload for submitting a batch of protocol simulations driven by a DOE.

    Attributes
    ----------
    parameterized_model : Any
        A quick-model dict, a full model dict, or a model ID string.
    protocol_experiment : ProtocolExperimentConfig
        Protocol text and template name.
    design_parameters_doe : DesignParametersDOE
        Sampling specification that expands into individual simulations.
    experiment_parameters : dict[str, float] | None
        Values for inputs referenced by the protocol.
    max_backward_jumps : int | None
        Upper bound on backward jumps (``goto`` statements).
    study_id : str | None
        Optional study UUID.
    extra_variables : list[str] | None
        Extra output variables; when given they override the template's.
    """

    parameterized_model: Any = Field(
        description=(
            "Model can be: quick_model dict, full model dict, "
            "or model ID string"
        ),
    )
    protocol_experiment: ProtocolExperimentConfig = Field(
        description="Protocol experiment configuration",
    )
    design_parameters_doe: DesignParametersDOE = Field(
        description="Design of experiments configuration",
    )
    experiment_parameters: dict[str, float] | None = Field(
        description="Experiment parameters for any inputs in the protocol.",
        default=None,
    )
    max_backward_jumps: int | None = Field(
        description="Maximum backward jumps allowed (for goto statements)",
        default=None,
    )
    study_id: str | None = Field(
        description="Optional study UUID",
        default=None,
    )
    extra_variables: list[str] | None = Field(
        description=(
            "Optional list of extra variables to include in simulation "
            "output (e.g., ['Negative electrode potential [V]', "
            "'Positive electrode potential [V]']). If provided, these "
            "override any extra variables defined in the experiment "
            "template."
        ),
        default=None,
    )
class SimulationResponse(BaseModel):
    """Identifiers returned by the API when a simulation is created.

    Attributes
    ----------
    simulation_id : str
        UUID of the created simulation.
    job_id : str
        UUID of the associated job.
    """

    simulation_id: str = Field(description="Simulation UUID")
    job_id: str = Field(description="Job UUID")
class SimulationClient:
    """Client for the ``/simulations`` API endpoints.

    Wraps an HTTP client object and exposes methods to submit single or
    batch (DOE) protocol simulations, list and fetch simulations, fetch
    results, and poll until simulations finish.
    """

    def __init__(self, client: Any) -> None:
        """Initialize the SimulationClient.

        Parameters
        ----------
        client : Any
            The HTTP client instance for making API requests. It is used via
            ``client.get(endpoint)`` and
            ``client.post(endpoint, json_payload=...)``.
        """
        self.client = client

    @staticmethod
    def _parse_simulation_response(item: Any, endpoint: str) -> SimulationResponse:
        """Validate one raw response object as a :class:`SimulationResponse`.

        Parameters
        ----------
        item : Any
            Decoded response payload, expected to be a mapping containing
            ``simulation_id`` and ``job_id``.
        endpoint : str
            Endpoint that produced ``item``; included in error messages.

        Returns
        -------
        SimulationResponse
            The validated response.

        Raises
        ------
        ValueError
            If ``item`` is not a mapping or does not match the response
            model. This is kept separate from request validation so that a
            malformed server response is not misreported as an invalid
            configuration.
        """
        # Local import: only this helper needs it.
        from collections.abc import Mapping

        if not isinstance(item, Mapping):
            msg = (
                f"Unexpected response format from {endpoint}: expected a "
                f"mapping, got {type(item).__name__}"
            )
            raise ValueError(msg)
        try:
            return SimulationResponse(**item)
        except ValidationError as e:
            msg = f"Unexpected response format from {endpoint}: {e}"
            raise ValueError(msg) from e

    def protocol(self, config: dict[str, Any]) -> SimulationResponse:
        """Create a single protocol-based simulation.

        Parameters
        ----------
        config : dict[str, Any]
            Configuration dictionary containing:

            - parameterized_model: quick_model dict, full model dict,
              or model ID string
            - protocol_experiment: ProtocolExperimentConfig dict with
              protocol and name fields
            - experiment_parameters: Optional dict[str, float] of values for
              inputs used in the protocol (e.g. initial_soc,
              initial_temperature)
            - design_parameters: Optional dict[str, float]
            - max_backward_jumps: Optional int
            - study_id: Optional str
            - extra_variables: Optional list[str] — extra variables to
              include in simulation output

        Returns
        -------
        SimulationResponse
            Response containing simulation_id and job_id.

        Raises
        ------
        ValueError
            If the configuration is invalid, or if the server response does
            not have the expected shape.
        """
        endpoint = "/simulations/protocol"
        # Keep the try narrow: only request-model validation failures are
        # configuration errors; response problems are reported separately.
        try:
            validated_config = ProtocolSimulationRequest(**config)
        except ValidationError as e:
            raise ValueError(f"Invalid protocol simulation configuration: {e}") from e
        response_data = self.client.post(
            endpoint, json_payload=validated_config.model_dump(exclude_none=True)
        )
        return self._parse_simulation_response(response_data, endpoint)

    def protocol_batch(self, config: dict[str, Any]) -> list[SimulationResponse]:
        """Create multiple protocol-based simulations using DOE.

        Parameters
        ----------
        config : dict[str, Any]
            Configuration dictionary containing:

            - parameterized_model: quick_model dict, full model dict,
              or model ID string
            - protocol_experiment: ProtocolExperimentConfig dict with
              protocol and name fields
            - design_parameters_doe: DesignParametersDOE dict
            - experiment_parameters: Optional dict
            - max_backward_jumps: Optional int
            - study_id: Optional str
            - extra_variables: Optional list[str] — extra variables to
              include in simulation output

        Returns
        -------
        list[SimulationResponse]
            List of responses, each containing simulation_id and job_id.

        Raises
        ------
        ValueError
            If the configuration is invalid, or if the server response is not
            a list of well-formed simulation responses.
        """
        endpoint = "/simulations/protocol/batch"
        try:
            validated_config = ProtocolSimulationBatchRequest(**config)
        except ValidationError as e:
            raise ValueError(
                f"Invalid batch protocol simulation configuration: {e}"
            ) from e
        response_data = self.client.post(
            endpoint, json_payload=validated_config.model_dump(exclude_none=True)
        )
        if not isinstance(response_data, list):
            msg = (
                f"Unexpected response format from {endpoint}: expected a "
                f"list, got {type(response_data).__name__}"
            )
            raise ValueError(msg)
        return [
            self._parse_simulation_response(item, endpoint) for item in response_data
        ]

    def list(self) -> list[dict[str, Any]]:
        """List all simulations for the current user.

        Returns
        -------
        list[dict[str, Any]]
            List of simulation objects with joined model and experiment data.

        Raises
        ------
        ValueError
            If the endpoint does not return a list.
        """
        endpoint = "/simulations"
        response_data = self.client.get(endpoint)
        if not isinstance(response_data, list):
            msg = (
                f"Unexpected response format from {endpoint}: expected a list, "
                f"got {type(response_data).__name__}"
            )
            raise ValueError(msg)
        return response_data

    def get(self, simulation_id: str) -> dict[str, Any]:
        """Get a specific simulation by ID.

        Parameters
        ----------
        simulation_id : str
            The UUID of the simulation to retrieve.

        Returns
        -------
        dict[str, Any]
            Simulation object with full joined data including model,
            experiment, and simulation_data (null if not completed). The
            payload is returned as-is, without validation.
        """
        endpoint = f"/simulations/{simulation_id}"
        response_data = self.client.get(endpoint)
        return cast(dict[str, Any], response_data)

    def get_result(self, simulation_id: str) -> dict[str, Any]:
        """Get simulation data/result for a completed simulation.

        Parameters
        ----------
        simulation_id : str
            The UUID of the simulation.

        Returns
        -------
        dict[str, Any]
            Simulation data object containing time_series, steps, and metrics.

        Raises
        ------
        Exception
            Whatever the underlying HTTP client raises when the endpoint
            answers 404, which happens when no result exists yet (for
            example, the simulation has not completed).
        """
        endpoint = f"/simulations/{simulation_id}/result"
        response_data = self.client.get(endpoint)
        return cast(dict[str, Any], response_data)

    def _poll_simulations(
        self,
        simulation_ids: list[str],
        timeout: int,
        poll_interval: int,
        verbose: bool,
    ) -> dict[str, dict[str, Any]]:
        """Poll simulations until all complete or timeout is reached.

        Every pending simulation is polled at least once, even when
        ``timeout`` is 0. A final poll is made when the deadline is reached.
        Sleeps never extend past the deadline.

        Parameters
        ----------
        simulation_ids : list[str]
            Simulation IDs to poll. Duplicates are polled once.
        timeout : int
            Maximum time to wait in seconds.
        poll_interval : int
            Time between polls in seconds.
        verbose : bool
            Whether to print status updates.

        Returns
        -------
        dict[str, dict[str, Any]]
            Dict mapping simulation IDs to their completed simulation dicts.

        Raises
        ------
        TimeoutError
            If no simulations complete within the timeout.
        """
        # Deduplicate (order-preserving). Otherwise repeated IDs could never
        # satisfy the "all completed" check and we would always wait out the
        # full timeout.
        pending_ids = list(dict.fromkeys(simulation_ids))
        total = len(pending_ids)
        completed: dict[str, dict[str, Any]] = {}
        # Monotonic clock: unaffected by wall-clock adjustments.
        start = time.monotonic()
        deadline = start + timeout
        if verbose:
            print(f"Polling for {total} simulation(s) completion...")
        while True:
            for sim_id in pending_ids:
                if sim_id in completed:
                    continue
                try:
                    simulation = self.get(sim_id)
                except (IonworksError, requests.exceptions.RequestException) as exc:
                    # Transient failures are tolerated; the ID is retried on
                    # the next pass.
                    if verbose:
                        print(f" Warning: Error polling simulation {sim_id}: {exc}")
                    continue
                if (
                    simulation.get("storage_folder")
                    or simulation.get("simulation_data")  # Legacy fallback
                ):
                    completed[sim_id] = simulation
            elapsed = int(time.monotonic() - start)
            if verbose:
                print(
                    f" Status: {len(completed)}/{total} completed "
                    f"(elapsed: {elapsed}s)"
                )
            if len(completed) == total:
                if verbose:
                    print("All simulations completed!")
                return completed
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Never sleep past the deadline, so the last poll lands on it.
            time.sleep(max(0.0, min(poll_interval, remaining)))
        # Timeout reached
        if verbose:
            print(
                f"Timeout: Only {len(completed)}/{total} "
                f"simulations completed within {timeout} seconds"
            )
        if not completed:
            msg = f"No simulations completed within {timeout} seconds"
            raise TimeoutError(msg)
        return completed

    def wait_for_completion(
        self,
        simulation_id: str | list[str],
        timeout: int = 60,
        poll_interval: int = 2,
        verbose: bool = True,
    ) -> dict[str, Any] | list[dict[str, Any]]:
        """Wait for simulation(s) to complete by polling until done or timeout.

        Parameters
        ----------
        simulation_id : str | list[str]
            Single simulation ID, or a list (any iterable) of simulation IDs
            to wait for.
        timeout : int
            Maximum time to wait in seconds (default: 60).
        poll_interval : int
            Time between polls in seconds (default: 2).
        verbose : bool
            Whether to print status updates (default: True).

        Returns
        -------
        dict[str, Any] | list[dict[str, Any]]
            For a single ID, the completed simulation dict. For multiple
            IDs, the completed simulations in input order. Simulations
            still pending at the timeout are omitted from the list.

        Raises
        ------
        TimeoutError
            If a single ID was given and it did not complete in time. Also
            raised if several IDs were given and none of them completed in
            time.
        """
        if isinstance(simulation_id, str):
            simulation_ids = [simulation_id]
            is_single = True
        else:
            # Materialize so one-shot iterables can be polled repeatedly.
            simulation_ids = list(simulation_id)
            is_single = False
        completed = self._poll_simulations(
            simulation_ids, timeout, poll_interval, verbose
        )
        if is_single:
            sim_id = simulation_ids[0]
            if sim_id not in completed:
                msg = (
                    f"Simulation {sim_id} did not complete within "
                    f"{timeout} seconds"
                )
                raise TimeoutError(msg)
            return completed[sim_id]
        return [completed[sim_id] for sim_id in simulation_ids if sim_id in completed]