# Source code for ionworks.cell_measurement

"""Cell measurement client for managing time series test data.

This module provides the :class:`CellMeasurementClient` for uploading,
retrieving, and managing measurement data from battery cell testing. It
supports efficient upload of large datasets using signed URLs and parquet
format.
"""

from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
import io
import math
from typing import Any

import numpy as np
import pandas as pd
import polars as pl
import requests

from .errors import IonworksError
from .models import (
    CellMeasurement,
    CellMeasurementBundleResponse,
    CellMeasurementDetail,
    ConfirmUploadResponse,
    InitiateUploadResponse,
    StepsAndCycles,
)
from .validators import (
    DataFrame,
    df_to_dict_validator,
    dict_to_df_validator,
    get_dataframe_backend,
    validate_measurement_data,
)


def _to_polars(df: DataFrame | dict) -> pl.DataFrame:
    """Normalize tabular input to a polars DataFrame.

    Parameters
    ----------
    df : DataFrame | dict
        A polars DataFrame (returned unchanged), a pandas DataFrame
        (converted with ``pl.from_pandas``), or a dict (passed to the
        ``pl.DataFrame`` constructor).

    Returns
    -------
    pl.DataFrame
        The data as a polars DataFrame.

    Raises
    ------
    TypeError
        If ``df`` is none of the supported types.
    """
    if isinstance(df, pl.DataFrame):
        # Already polars: hand back the same object, no copy.
        frame = df
    elif isinstance(df, pd.DataFrame):
        frame = pl.from_pandas(df)
    elif isinstance(df, dict):
        frame = pl.DataFrame(df)
    else:
        raise TypeError(f"Expected DataFrame or dict, got {type(df).__name__}")
    return frame


def _sanitize_dict_values(d: dict[str, list[Any]]) -> dict[str, list[Any]]:
    """Return a copy of ``d`` whose list values are safe to encode as JSON.

    Every ``float`` / ``numpy.floating`` element that is infinite or NaN is
    replaced with ``None``; all other elements are kept as they are. Values
    that are not lists are passed through untouched.

    Parameters
    ----------
    d : dict[str, list[Any]]
        Mapping of column name to column values.

    Returns
    -------
    dict[str, list[Any]]
        New dict with the same keys; list values are rebuilt as new lists.
    """

    def _json_safe(item: Any) -> Any:
        # Only real floating-point scalars can carry inf/nan; ints, strings,
        # None, etc. are returned as-is.
        if isinstance(item, (float, np.floating)) and not math.isfinite(item):
            return None
        return item

    return {
        column: (
            [_json_safe(item) for item in column_values]
            if isinstance(column_values, list)
            else column_values
        )
        for column, column_values in d.items()
    }


class CellMeasurementClient:
    """Client for managing cell measurement data."""

    #: Default timeout for signed URL uploads as (connect, read) in seconds.
    #: The same value is also applied to signed URL downloads.
    UPLOAD_TIMEOUT: tuple[float, float] = (10, 300)

    def __init__(self, client: Any) -> None:
        """Initialize the CellMeasurementClient.

        Parameters
        ----------
        client : Any
            The HTTP client instance for making API calls.
        """
        self.client = client

    def list(self, cell_instance_id: str) -> list[CellMeasurement]:
        """List all cell measurements for a cell instance by instance ID.

        Parameters
        ----------
        cell_instance_id : str
            The ID of the cell instance.

        Returns
        -------
        list[CellMeasurement]
            One model per item returned by the API.
        """
        endpoint = f"/cell_instances/{cell_instance_id}/cell_measurements"
        response_data = self.client.get(endpoint)
        return [CellMeasurement(**item) for item in response_data]

    def get(self, measurement_id: str) -> CellMeasurement:
        """Get a specific cell measurement by its ID only.

        Parameters
        ----------
        measurement_id : str
            The ID of the cell measurement.

        Returns
        -------
        CellMeasurement
            The requested measurement.
        """
        endpoint = f"/cell_measurements/{measurement_id}"
        response_data = self.client.get(endpoint)
        return CellMeasurement(**response_data)

    def detail(
        self,
        measurement_id: str,
        use_signed_url: bool = True,
        include_steps: bool = True,
        include_cycles: bool = True,
        include_time_series: bool = True,
    ) -> CellMeasurementDetail:
        """Fetch measurement data.

        By default uses flat endpoints with parallel requests and downloads
        time series via signed URL. Set ``use_signed_url=False`` to use the
        legacy bundled ``/detail`` endpoint instead.

        Use the ``include_*`` flags to skip fetching data you don't need,
        which avoids unnecessary requests.

        Parameters
        ----------
        measurement_id : str
            The ID of the cell measurement.
        use_signed_url : bool, optional
            If True (default), uses flat endpoints with parallel requests
            and signed-URL download. If False, uses the legacy ``/detail``
            endpoint.
        include_steps : bool, optional
            Whether to fetch step data. Defaults to True.
        include_cycles : bool, optional
            Whether to fetch cycle metrics. Defaults to True.
        include_time_series : bool, optional
            Whether to fetch time series data. Defaults to True.

        Returns
        -------
        CellMeasurementDetail
            Measurement details with requested data fields. Fields not
            requested will be None.
        """
        if not use_signed_url:
            return self._detail_legacy(
                measurement_id,
                include_steps=include_steps,
                include_cycles=include_cycles,
                include_time_series=include_time_series,
            )

        base = f"/cell_measurements/{measurement_id}"
        futures: dict[str, Any] = {}

        # Issue the (up to three) independent GETs concurrently; leaving the
        # ``with`` block waits for all of them to finish.
        with ThreadPoolExecutor(max_workers=3) as pool:
            futures["meta"] = pool.submit(self.client.get, base)
            if include_steps or include_cycles:
                # Steps and cycles share a single endpoint, so one request
                # serves either or both flags.
                futures["sc"] = pool.submit(
                    self.client.get,
                    f"{base}/steps_and_cycles",
                )
            if include_time_series:
                futures["url"] = pool.submit(
                    self.client.get,
                    f"{base}/time_series/signed_url",
                )

        meta = futures["meta"].result()
        measurement = CellMeasurement(**meta)

        steps = None
        cycles = None
        if "sc" in futures:
            sc = futures["sc"].result()
            if include_steps:
                steps = sc.get("steps")
            if include_cycles:
                cycles = sc.get("cycles")

        time_series = None
        if "url" in futures:
            signed_url = futures["url"].result()["signed_url"]
            time_series = self._download_from_signed_url(signed_url)

        return CellMeasurementDetail(
            measurement=measurement,
            instance_id=measurement.cell_instance_id,
            steps=steps,
            cycles=cycles,
            time_series=time_series,
        )

    def _detail_legacy(
        self,
        measurement_id: str,
        include_steps: bool = True,
        include_cycles: bool = True,
        include_time_series: bool = True,
    ) -> CellMeasurementDetail:
        """Fetch detail via the deprecated /detail endpoint.

        The bundled response is fetched in full; the ``include_*`` flags only
        control which of its fields are copied into the result (the others
        are set to None).

        Parameters
        ----------
        measurement_id : str
            The ID of the cell measurement.
        include_steps : bool, optional
            Whether to include step data. Defaults to True.
        include_cycles : bool, optional
            Whether to include cycle metrics. Defaults to True.
        include_time_series : bool, optional
            Whether to include time series data. Defaults to True.

        Returns
        -------
        CellMeasurementDetail
            Measurement details built from the bundled response.
        """
        endpoint = f"/cell_measurements/{measurement_id}/detail"
        data = self.client.get(endpoint)

        # Extract IDs from nested objects; ``or {}`` tolerates a missing key
        # as well as an explicit null value.
        instance = data.get("instance") or {}
        spec = data.get("specification") or {}

        return CellMeasurementDetail(
            measurement=data["measurement"],
            instance_id=instance.get("id"),
            specification_id=spec.get("id"),
            steps=data.get("steps") if include_steps else None,
            time_series=(data.get("time_series") if include_time_series else None),
            cycles=(data.get("cycles") if include_cycles else None),
        )

    def steps(self, measurement_id: str) -> DataFrame:
        """Get step data for a measurement.

        Parameters
        ----------
        measurement_id : str
            The ID of the cell measurement.

        Returns
        -------
        DataFrame
            Step data (polars or pandas based on config).
        """
        endpoint = f"/cell_measurements/{measurement_id}/steps"
        data = self.client.get(endpoint)
        return dict_to_df_validator(data["steps"])

    def cycles(self, measurement_id: str) -> DataFrame:
        """Get cycle metrics for a measurement.

        Parameters
        ----------
        measurement_id : str
            The ID of the cell measurement.

        Returns
        -------
        DataFrame
            Cycle metrics (polars or pandas based on config).
        """
        endpoint = f"/cell_measurements/{measurement_id}/cycles"
        data = self.client.get(endpoint)
        return dict_to_df_validator(data["cycles"])

    def steps_and_cycles(self, measurement_id: str) -> StepsAndCycles:
        """Get steps and cycles in one call.

        More efficient than calling :meth:`steps` and :meth:`cycles`
        separately since cycles are derived from steps on the server.

        Parameters
        ----------
        measurement_id : str
            The ID of the cell measurement.

        Returns
        -------
        StepsAndCycles
            Object with ``steps`` and ``cycles`` DataFrames.
        """
        endpoint = f"/cell_measurements/{measurement_id}/steps_and_cycles"
        data = self.client.get(endpoint)
        return StepsAndCycles(**data)

    def time_series(self, measurement_id: str) -> DataFrame:
        """Download full time series via signed URL.

        Downloads the raw parquet file directly from storage. Use this for
        full dataset access without backend processing.

        Parameters
        ----------
        measurement_id : str
            The ID of the cell measurement.

        Returns
        -------
        DataFrame
            Full time series data.
        """
        endpoint = f"/cell_measurements/{measurement_id}/time_series/signed_url"
        data = self.client.get(endpoint)
        return self._download_from_signed_url(data["signed_url"])

    def _download_from_signed_url(self, signed_url: str) -> DataFrame:
        """Download parquet file from a signed URL.

        Returns a polars or pandas DataFrame depending on the configured
        dataframe backend.

        Parameters
        ----------
        signed_url : str
            The signed URL to download from.

        Returns
        -------
        DataFrame
            The parquet content as a DataFrame.

        Raises
        ------
        IonworksError
            If the HTTP request fails or returns an error status.
        """
        # Handle Docker-internal URLs when running locally
        url = signed_url.replace("host.docker.internal", "localhost")

        try:
            response = requests.get(url, timeout=self.UPLOAD_TIMEOUT)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            # ``from None`` drops the chained requests traceback; the
            # original error text is still embedded in the message.
            raise IonworksError(f"Failed to download from signed URL: {e}") from None

        buffer = io.BytesIO(response.content)
        df = pl.read_parquet(buffer)
        if get_dataframe_backend() == "pandas":
            return df.to_pandas()
        return df

    def update(self, measurement_id: str, data: dict[str, Any]) -> CellMeasurement:
        """Update an existing cell measurement.

        Parameters
        ----------
        measurement_id : str
            The ID of the cell measurement to update.
        data : dict[str, Any]
            Dictionary containing the fields to update.

        Returns
        -------
        CellMeasurement
            The updated cell measurement.
        """
        endpoint = f"/cell_measurements/{measurement_id}"
        response_data = self.client.put(endpoint, data)
        return CellMeasurement(**response_data)

    def delete(self, measurement_id: str) -> None:
        """Delete a cell measurement by measurement ID only.

        Parameters
        ----------
        measurement_id : str
            The ID of the cell measurement to delete.
        """
        endpoint = f"/cell_measurements/{measurement_id}"
        self.client.delete(endpoint)

    def _dataframe_to_parquet(self, df: pl.DataFrame) -> bytes:
        """Convert a polars DataFrame to parquet bytes.

        Parameters
        ----------
        df : pl.DataFrame
            The DataFrame to convert.

        Returns
        -------
        bytes
            Parquet file content (uses zstd compression).
        """
        buffer = io.BytesIO()
        df.write_parquet(buffer, compression="zstd")
        return buffer.getvalue()

    def _initiate_upload(
        self,
        cell_instance_id: str,
        name: str,
        notes: str | None = None,
    ) -> InitiateUploadResponse:
        """Initiate a signed URL upload for a new measurement.

        Parameters
        ----------
        cell_instance_id : str
            The ID of the cell instance.
        name : str
            Name for the new measurement.
        notes : str | None
            Optional notes for the measurement. Omitted from the request
            when empty or None.

        Returns
        -------
        InitiateUploadResponse
            Response containing measurement_id, signed_url, token, path.
        """
        endpoint = (
            f"/cell_instances/{cell_instance_id}/cell_measurements/initiate-upload"
        )
        measurement_data: dict[str, Any] = {"name": name}
        if notes:
            measurement_data["notes"] = notes
        payload = {"measurement": measurement_data}
        response_data = self.client.post(endpoint, payload)
        return InitiateUploadResponse(**response_data)

    def _upload_to_signed_url(self, signed_url: str, parquet_bytes: bytes) -> None:
        """Upload parquet bytes to a signed URL.

        Parameters
        ----------
        signed_url : str
            The signed URL to upload to.
        parquet_bytes : bytes
            The parquet file content to upload.

        Raises
        ------
        IonworksError
            If the upload fails.
        """
        # Handle Docker-internal URLs when running locally
        url = signed_url.replace("host.docker.internal", "localhost")

        try:
            response = requests.put(
                url,
                data=parquet_bytes,
                headers={"Content-Type": "application/octet-stream"},
                timeout=self.UPLOAD_TIMEOUT,
            )
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            # ``from None`` drops the chained requests traceback; the
            # original error text is still embedded in the message.
            raise IonworksError(f"Failed to upload to signed URL: {e}") from None

    def _confirm_upload(
        self,
        measurement_id: str,
        cell_instance_id: str,
        measurement_data: dict[str, Any],
        steps: dict[str, list[Any]] | None = None,
    ) -> ConfirmUploadResponse:
        """Confirm a signed URL upload after successful file upload.

        This creates the measurement record in the database. The measurement
        data must be passed again (same as initiate) to ensure no orphaned
        records are created if the file upload fails.

        Parameters
        ----------
        measurement_id : str
            The pre-generated ID for the measurement (from initiate).
        cell_instance_id : str
            The ID of the parent cell instance.
        measurement_data : dict[str, Any]
            Measurement metadata (name, notes, etc.) — same as passed to
            initiate.
        steps : dict[str, list[Any]] | None
            Pre-calculated steps data. If not provided, will be
            auto-calculated. An empty dict is treated as not provided.

        Returns
        -------
        ConfirmUploadResponse
            Response containing instance, measurement, steps_created, and
            file_path.
        """
        endpoint = f"/cell_measurements/{measurement_id}/confirm-upload"
        payload: dict[str, Any] = {
            "cell_instance_id": cell_instance_id,
            "measurement": measurement_data,
        }
        if steps:
            payload["steps"] = steps
        response_data = self.client.post(endpoint, payload)
        return ConfirmUploadResponse(**response_data)

    def create(
        self,
        cell_instance_id: str,
        measurement_detail: dict[str, Any],
        validate_strict: bool = False,
    ) -> CellMeasurementBundleResponse:
        """Create a new cell measurement with steps and time series data.

        Uses signed URL upload for better performance with large datasets.
        Data is uploaded directly to storage as parquet, bypassing backend
        JSON parsing. No database record is created until the upload is
        confirmed, preventing orphaned records if upload fails.

        Parameters
        ----------
        cell_instance_id : str
            The ID of the cell instance to create the measurement for.
        measurement_detail : dict[str, Any]
            Dictionary containing 'measurement', 'steps', and 'time_series'.

            - measurement: dict with 'name' (required) and 'notes'
            - time_series: pandas DataFrame or dict with time series data
            - steps: optional dict with pre-calculated steps data
        validate_strict : bool, optional
            If False (default), skips strict validation. If True, runs
            strict validation including minimum points per step.

        Returns
        -------
        CellMeasurementBundleResponse
            Response containing the created measurement, steps count, and
            file path.

        Raises
        ------
        MeasurementValidationError
            If data validation fails. Validation checks include:

            - Positive current should correspond to discharge
            - Cumulative values should reset at each step
        """
        measurement_info = measurement_detail["measurement"]
        name = measurement_info["name"]
        notes = measurement_info.get("notes")

        # Step 1: Convert time_series to polars DataFrame
        # Accepts pandas DataFrame, polars DataFrame, or dict
        time_series = _to_polars(measurement_detail["time_series"])

        # Step 2: Validate the data before any upload
        data_type = measurement_info.get("data_type")
        validate_measurement_data(
            time_series, strict=validate_strict, data_type=data_type
        )

        # Step 3: Initiate upload (validates metadata, returns signed URL, no DB record)
        initiate_result = self._initiate_upload(cell_instance_id, name, notes)

        # Step 4: Convert to parquet
        parquet_bytes = self._dataframe_to_parquet(time_series)

        # Step 5: Upload to signed URL
        self._upload_to_signed_url(initiate_result.signed_url, parquet_bytes)

        # Step 6: Confirm upload (creates the measurement record)
        steps = measurement_detail.get("steps")
        if steps is not None:
            # Convert DataFrame to dict if needed, and sanitize inf/nan values
            steps = df_to_dict_validator(steps)
            # If steps was already a dict, sanitize it for JSON compatibility
            if isinstance(steps, dict):
                steps = _sanitize_dict_values(steps)

        # Strip internal-only keys before sending to the API
        api_measurement_info = {
            k: v for k, v in measurement_info.items() if k != "data_type"
        }

        confirm_result = self._confirm_upload(
            measurement_id=initiate_result.measurement_id,
            cell_instance_id=cell_instance_id,
            measurement_data=api_measurement_info,
            steps=steps,
        )

        return CellMeasurementBundleResponse(
            measurement=confirm_result.measurement,
            steps_created=confirm_result.steps_created,
            file_path=confirm_result.file_path,
        )

    def create_or_get(
        self,
        cell_instance_id: str,
        measurement_detail: dict[str, Any],
        validate_strict: bool = False,
    ) -> CellMeasurement:
        """Create a new cell measurement or get existing.

        Always returns a :class:`CellMeasurement` regardless of whether the
        measurement was newly created or already existed.

        Parameters
        ----------
        cell_instance_id : str
            The ID of the cell instance.
        measurement_detail : dict[str, Any]
            Dictionary containing ``measurement`` and ``time_series`` (same
            as :meth:`create`).
        validate_strict : bool, optional
            If False (default), skips strict validation. If True, runs
            strict validation including minimum points per step.

        Returns
        -------
        CellMeasurement
            The measurement (newly created or existing).

        Raises
        ------
        ValueError
            If the API reports a conflict but no existing measurement can be
            located by ID or by name.
        IonworksError
            Re-raised unchanged for any non-conflict API error.
        """
        try:
            bundle = self.create(
                cell_instance_id,
                measurement_detail,
                validate_strict,
            )
            return bundle.measurement
        except IonworksError as e:
            if e.error_code == "CONFLICT" or e.status_code == 409:
                # Try to get existing measurement by ID from error detail
                if e.data is not None:
                    detail = e.data.get("detail", {})
                    existing_id = (
                        detail.get("existing_id") if isinstance(detail, dict) else None
                    )
                    if existing_id:
                        return self.get(existing_id)

                # Fall back to listing and matching by name
                measurement_name = measurement_detail["measurement"]["name"]
                measurements = self.list(cell_instance_id)
                for m in measurements:
                    if m.name == measurement_name:
                        return m
                raise ValueError(
                    f"Measurement '{measurement_name}' reported "
                    "as duplicate but could not be found"
                ) from e
            raise