Skip to content

plotting

Trajectory visualization and plotting utilities.

This module provides reusable building blocks for visualizing trajectory optimization results. It is intentionally minimal - we provide common utilities that can be composed together, not a complete solution that tries to do everything for you.

2D Plots (plotly-based): Two-layer API for time series visualization::

    from openscvx.plotting import (
        plot_controls,
        plot_projections_2d,
        plot_state_component,
        plot_states,
        plot_vector_norm,
    )

    # High-level: subplot grid with individual scaling per component
    plot_states(results, ["position", "velocity"]).show()
    plot_controls(results, ["thrust"]).show()

    # Low-level: single component
    plot_state_component(results, "position", component=2).show()  # z only

    # Specialized plots
    plot_vector_norm(results, "thrust", bounds=(rho_min, rho_max)).show()
    plot_projections_2d(results, velocity_var_name="velocity").show()

3D Visualization (viser-based): The viser submodule provides composable primitives for building interactive 3D visualizations. See openscvx.plotting.viser for details::

    from openscvx.plotting import viser
    server = viser.create_server(positions)
    viser.add_gates(server, gate_vertices)
    server.sleep_forever()

For problem-specific visualization examples (drones, rockets, etc.), see examples/plotting_viser.py.

plot_control_component(result: OptimizationResults, control_name: str, component: int = 0) -> go.Figure

Plot a single component of a control variable vs time.

This is the low-level function for plotting one scalar control over time. For plotting all components of a control, use plot_controls().

Parameters:

Name Type Description Default
result OptimizationResults

Optimization results containing control trajectories

required
control_name str

Name of the control variable

required
component int

Component index (0-indexed). For scalar controls, use 0.

0

Returns:

Type Description
Figure

Plotly figure with single plot

Example

plot_control_component(result, "thrust", 0) # Plot thrust_x

Source code in openscvx/plotting/plotting.py
def plot_control_component(
    result: OptimizationResults,
    control_name: str,
    component: int = 0,
) -> go.Figure:
    """Draw one scalar component of a control variable against time.

    Low-level building block: it produces a single-axes figure for exactly one
    component. To plot every component of a control, use plot_controls().

    Args:
        result: Optimization results containing control trajectories
        control_name: Name of the control variable
        component: Component index (0-indexed). For scalar controls, use 0.

    Returns:
        Plotly figure with single plot

    Raises:
        ValueError: If the control name is unknown or the component index is
            outside the control's dimension.

    Example:
        >>> plot_control_component(result, "thrust", 0)  # Plot thrust_x
    """
    known_controls = {ctrl.name for ctrl in result._controls}
    if control_name not in known_controls:
        raise ValueError(
            f"Control '{control_name}' not found. Available: {sorted(known_controls)}"
        )

    n_components = _get_var_dim(result, control_name, result._controls)
    if not 0 <= component < n_components:
        raise ValueError(
            f"Component {component} out of range for '{control_name}' (dim={n_components})"
        )

    # Node times are always read; the dense time axis only when the control
    # is present in the propagated trajectory.
    node_times = result.nodes["time"].flatten()
    has_dense = bool(result.trajectory) and control_name in result.trajectory
    dense_times = result.trajectory["time"].flatten() if has_dense else None

    title = control_name if n_components <= 1 else f"{control_name}_{component}"

    def pick_component(samples):
        # 1-D arrays already hold a single scalar series.
        if samples.ndim == 1:
            return samples
        return samples[:, component]

    fig = go.Figure()
    fig.update_layout(title_text=title, template="plotly_dark")

    if has_dense:
        dense_trace = go.Scatter(
            x=dense_times,
            y=pick_component(result.trajectory[control_name]),
            mode="lines",
            name="Trajectory",
            line={"color": "green", "width": 2},
        )
        fig.add_trace(dense_trace)

    if control_name in result.nodes:
        node_trace = go.Scatter(
            x=node_times,
            y=pick_component(result.nodes[control_name]),
            mode="markers",
            name="Nodes",
            marker={"color": "cyan", "size": 6},
        )
        fig.add_trace(node_trace)

    fig.update_xaxes(title_text="Time (s)")
    fig.update_yaxes(title_text=title)
    return fig

plot_controls(result: OptimizationResults, control_names: list[str] | None = None, include_private: bool = False, cols: int = 3) -> go.Figure

Plot control variables in a subplot grid.

Each component of each control gets its own subplot with individual y-axis scaling. This is the primary function for visualizing control trajectories.

Parameters:

Name Type Description Default
result OptimizationResults

Optimization results containing control trajectories

required
control_names list[str] | None

List of control names to plot. If None, plots all controls.

None
include_private bool

Whether to include private controls (names starting with '_')

False
cols int

Maximum number of columns in subplot grid

3

Returns:

Type Description
Figure

Plotly figure with subplot grid

Examples:

>>> plot_controls(result, ["thrust"])  # 3 subplots for x, y, z
>>> plot_controls(result)  # All controls
Source code in openscvx/plotting/plotting.py
def plot_controls(
    result: OptimizationResults,
    control_names: list[str] | None = None,
    include_private: bool = False,
    cols: int = 3,
) -> go.Figure:
    """Plot control variables in a subplot grid.

    Each component of each control gets its own subplot with individual y-axis
    scaling. This is the primary function for visualizing control trajectories.

    Args:
        result: Optimization results containing control trajectories
        control_names: List of control names to plot. If None, plots all controls.
        include_private: Whether to include private controls (names starting with '_')
        cols: Maximum number of columns in subplot grid (must be >= 1)

    Returns:
        Plotly figure with subplot grid

    Raises:
        ValueError: If ``cols`` is not positive, if any requested control is
            not available, or if there is nothing to plot.

    Examples:
        >>> plot_controls(result, ["thrust"])  # 3 subplots for x, y, z
        >>> plot_controls(result)  # All controls
    """
    # Reject a non-positive column count up front: it would otherwise surface
    # as a ZeroDivisionError (cols=0) or an opaque subplot-grid error (cols<0).
    if cols < 1:
        raise ValueError(f"cols must be a positive integer, got {cols}")

    controls = result._controls
    if not include_private:
        controls = [c for c in controls if not c.name.startswith("_")]

    if control_names is not None:
        available = {c.name for c in controls}
        missing = set(control_names) - available
        if missing:
            raise ValueError(f"Controls not found in result: {missing}")
        # Preserve the caller's ordering; the dict doubles as an O(1)
        # membership test.
        control_order = {name: i for i, name in enumerate(control_names)}
        controls = sorted(
            (c for c in controls if c.name in control_order),
            key=lambda c: control_order[c.name],
        )

    # Build list of (display_name, var_name, component_idx)
    components = []
    for c in controls:
        dim = _get_var_dim(result, c.name, result._controls)
        if dim == 1:
            components.append((c.name, c.name, 0))
        else:
            components.extend((f"{c.name}_{i}", c.name, i) for i in range(dim))

    if not components:
        raise ValueError("No control components to plot")

    n_cols = min(cols, len(components))
    n_rows = (len(components) + n_cols - 1) // n_cols  # ceil division

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=[c[0] for c in components])
    fig.update_layout(title_text="Control Trajectories", template="plotly_dark")

    for idx, (_, var_name, comp_idx) in enumerate(components):
        row = (idx // n_cols) + 1
        col = (idx % n_cols) + 1

        # Get bounds for this component
        var = _get_var(result, var_name, result._controls)
        min_val = var.min[comp_idx] if var.min is not None else None
        max_val = var.max[comp_idx] if var.max is not None else None

        _add_component_traces(
            fig,
            result,
            var_name,
            comp_idx,
            row,
            col,
            show_legend=(idx == 0),  # one legend entry for the whole grid
            min_val=min_val,
            max_val=max_val,
        )

    # Add x-axis labels to bottom row
    for col_idx in range(1, n_cols + 1):
        fig.update_xaxes(title_text="Time (s)", row=n_rows, col=col_idx)

    return fig

plot_projections_2d(result: OptimizationResults, var_name: str = 'position', velocity_var_name: str | None = None, cmap: str = 'viridis') -> go.Figure

Plot XY, XZ, YZ projections of a 3D variable.

Useful for visualizing 3D trajectories in 2D plane views.

Parameters:

Name Type Description Default
result OptimizationResults

Optimization results containing trajectories

required
var_name str

Name of the 3D variable to plot (default: "position")

'position'
velocity_var_name str | None

Optional name of velocity variable for coloring by speed. If provided, trajectory points are colored by velocity magnitude.

None
cmap str

Plotly colorscale name for velocity coloring (default: "viridis")

'viridis'

Returns:

Type Description
Figure

Plotly figure with three subplots (XY, XZ, YZ planes)

Source code in openscvx/plotting/plotting.py
def plot_projections_2d(
    result: OptimizationResults,
    var_name: str = "position",
    velocity_var_name: str | None = None,
    cmap: str = "viridis",
) -> go.Figure:
    """Plot XY, XZ, YZ projections of a 3D variable.

    Useful for visualizing 3D trajectories in 2D plane views. The first three
    columns of the variable are used as X, Y and Z.

    Args:
        result: Optimization results containing trajectories
        var_name: Name of the 3D variable to plot (default: "position")
        velocity_var_name: Optional name of velocity variable for coloring by speed.
            If provided, trajectory points are colored by velocity magnitude.
            A name that is present in neither the trajectory nor the nodes is
            ignored and the plot falls back to fixed colors.
        cmap: Plotly colorscale name used for velocity coloring (default: "viridis")

    Returns:
        Plotly figure with three subplots (XY, XZ, YZ planes)

    Raises:
        ValueError: If ``var_name`` is found in neither the trajectory nor the
            nodes, or if its data is not 2-D with at least three columns.
    """
    import numpy as np

    has_trajectory = bool(result.trajectory) and var_name in result.trajectory
    has_nodes = var_name in result.nodes

    if not has_trajectory and not has_nodes:
        available_traj = set(result.trajectory.keys()) if result.trajectory else set()
        available_nodes = set(result.nodes.keys())
        raise ValueError(
            f"Variable '{var_name}' not found. "
            f"Available in trajectory: {sorted(available_traj)}, nodes: {sorted(available_nodes)}"
        )

    # Fail early with a readable message instead of an IndexError from
    # data[:, 2] deep inside the plotting loops.
    sources = []
    if has_trajectory:
        sources.append(("trajectory", result.trajectory[var_name]))
    if has_nodes:
        sources.append(("nodes", result.nodes[var_name]))
    for where, values in sources:
        shape = np.shape(values)
        if len(shape) != 2 or shape[1] < 3:
            raise ValueError(
                f"Variable '{var_name}' in {where} must have shape (N, >=3) "
                f"for XY/XZ/YZ projections, got {shape}"
            )

    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=("XY Plane", "XZ Plane", "YZ Plane"),
        specs=[[{}, {}], [{}, None]],
    )

    # Subplot positions: (x_idx, y_idx, row, col)
    subplots = [(0, 1, 1, 1), (0, 2, 1, 2), (1, 2, 2, 1)]

    # Compute velocity norms if velocity variable is provided
    traj_vel_norm = None
    node_vel_norm = None
    if velocity_var_name is not None:
        if has_trajectory and velocity_var_name in result.trajectory:
            traj_vel_norm = np.linalg.norm(result.trajectory[velocity_var_name], axis=1)
        if has_nodes and velocity_var_name in result.nodes:
            node_vel_norm = np.linalg.norm(result.nodes[velocity_var_name], axis=1)

    # Colorbar config (only shown once)
    colorbar_cfg = {"title": "‖velocity‖", "x": 1.02, "y": 0.5, "len": 0.9}

    # Plot trajectory if available
    if has_trajectory:
        data = result.trajectory[var_name]
        for i, (xi, yi, row, col) in enumerate(subplots):
            first = i == 0
            if traj_vel_norm is not None:
                # Speed-colored markers; the colorbar is attached to the first
                # subplot's trace only.
                style = {
                    "mode": "markers",
                    "marker": {
                        "size": 4,
                        "color": traj_vel_norm,
                        "colorscale": cmap,
                        "showscale": first,
                        "colorbar": colorbar_cfg if first else None,
                    },
                }
            else:
                style = {"mode": "lines", "line": {"color": "green", "width": 2}}
            fig.add_trace(
                go.Scatter(
                    x=data[:, xi],
                    y=data[:, yi],
                    name="Trajectory",
                    legendgroup="trajectory",
                    showlegend=first,
                    **style,
                ),
                row=row,
                col=col,
            )

    # Plot nodes if available
    if has_nodes:
        data = result.nodes[var_name]
        # Only show colorbar on nodes if trajectory doesn't have one
        show_node_colorbar = (traj_vel_norm is None) and (node_vel_norm is not None)
        for i, (xi, yi, row, col) in enumerate(subplots):
            first = i == 0
            if node_vel_norm is not None:
                marker = {
                    "size": 8,
                    "color": node_vel_norm,
                    "colorscale": cmap,
                    "showscale": show_node_colorbar and first,
                    "colorbar": colorbar_cfg if (show_node_colorbar and first) else None,
                    "line": {"color": "white", "width": 1},
                }
            else:
                marker = {"color": "cyan", "size": 6}
            fig.add_trace(
                go.Scatter(
                    x=data[:, xi],
                    y=data[:, yi],
                    mode="markers",
                    marker=marker,
                    name="Nodes",
                    legendgroup="nodes",
                    showlegend=first,
                ),
                row=row,
                col=col,
            )

    # Set axis titles
    fig.update_xaxes(title_text="X", row=1, col=1)
    fig.update_yaxes(title_text="Y", row=1, col=1)
    fig.update_xaxes(title_text="X", row=1, col=2)
    fig.update_yaxes(title_text="Z", row=1, col=2)
    fig.update_xaxes(title_text="Y", row=2, col=1)
    fig.update_yaxes(title_text="Z", row=2, col=1)

    # Set equal aspect ratio for each subplot
    layout_opts = {
        "title": f"{var_name} - XY, XZ, YZ Projections",
        "template": "plotly_dark",
        "xaxis": {"scaleanchor": "y"},
        "xaxis2": {"scaleanchor": "y2"},
        "xaxis3": {"scaleanchor": "y3"},
    }
    # Move the legend below the plots (horizontal, centered) when velocity
    # coloring was requested, so it does not overlap the colorbar on the right.
    if velocity_var_name is not None:
        layout_opts["legend"] = {
            "orientation": "h",
            "yanchor": "bottom",
            "y": -0.15,
            "xanchor": "center",
            "x": 0.5,
        }
    fig.update_layout(**layout_opts)

    return fig

plot_scp_convergence_histories(result: OptimizationResults) -> go.Figure

Plot SCP convergence histories: trust region weight, reduction histories, and acceptance ratio.

Creates three separate plots: 1. Trust region weight (lam_prox) history 2. Actual and predicted reduction histories (overlaid) 3. Acceptance ratio history

Parameters:

Name Type Description Default
result OptimizationResults

OptimizationResults containing the convergence histories.

required

Returns:

Type Description
Figure

Plotly figure with three subplots

Example

problem.initialize()
result = problem.solve()
plot_scp_convergence_histories(result).show()

Source code in openscvx/plotting/scp_iteration.py
def plot_scp_convergence_histories(result: OptimizationResults) -> go.Figure:
    """Visualize how the SCP solve progressed, one stacked panel per quantity.

    The figure has three rows sharing the iteration index as x-axis:
    1. Trust region weight (lam_prox) history, on a log y-axis
    2. Actual and predicted reduction histories (overlaid)
    3. Acceptance ratio history with dashed reference lines at 1e-6 and 0.9

    A panel whose history is empty shows a text annotation instead of a trace.

    Args:
        result: OptimizationResults containing the convergence histories.

    Returns:
        Plotly figure with three subplots

    Raises:
        TypeError: If ``result`` is not an OptimizationResults instance.

    Example:
        >>> problem.initialize()
        >>> result = problem.solve()
        >>> plot_scp_convergence_histories(result).show()
    """
    if not isinstance(result, OptimizationResults):
        raise TypeError(f"Expected OptimizationResults, got {type(result)}")

    panel_titles = (
        "Trust Region Weight History",
        "Actual vs Predicted Reduction History",
        "Acceptance Ratio History",
    )
    fig = make_subplots(
        rows=3,
        cols=1,
        subplot_titles=panel_titles,
        vertical_spacing=0.12,
    )

    def add_series(panel, xs, ys, label, colour):
        # Every history shares the same line+marker styling and hover layout.
        fig.add_trace(
            go.Scatter(
                x=xs,
                y=ys,
                mode="lines+markers",
                name=label,
                line={"color": colour, "width": 2},
                marker={"size": 6},
                hovertemplate=f"Iteration: %{{x}}<br>{label}: %{{y:.3g}}<extra></extra>",
            ),
            row=panel,
            col=1,
        )

    def add_placeholder(panel, message):
        # Text shown in place of a trace when a history is empty.
        fig.add_annotation(
            text=message,
            xref=f"x{panel}",
            yref=f"y{panel}",
            x=0.5,
            y=0.5,
            showarrow=False,
            row=panel,
            col=1,
        )

    lam_prox = result.lam_prox_history
    actual = result.actual_reduction_history
    ratio = result.acceptance_ratio_history

    # Iteration indices are 0-based.
    lam_prox_steps = np.arange(len(lam_prox))
    ratio_steps = np.arange(len(ratio))

    # Panel 1: trust region weight
    if len(lam_prox) > 0:
        add_series(1, lam_prox_steps, lam_prox, "lam_prox", "cyan")
    else:
        add_placeholder(1, "No trust region weight history available")

    # Panel 2: actual vs predicted reduction, truncated to their common length
    predicted = result.pred_reduction_history
    if len(actual) > 0 and len(predicted) > 0:
        shared_len = min(len(actual), len(predicted))
        reduction_steps = np.arange(shared_len)
        add_series(2, reduction_steps, actual[:shared_len], "Actual Reduction", "green")
        add_series(
            2, reduction_steps, predicted[:shared_len], "Predicted Reduction", "orange"
        )
    else:
        add_placeholder(2, "No reduction history available")

    # Panel 3: acceptance ratio plus reference thresholds
    if len(ratio) > 0:
        add_series(3, ratio_steps, ratio, "Acceptance Ratio", "magenta")
        reference_lines = (
            (1e-6, "red", "η₁ = 1e-6"),
            (0.9, "yellow", "η₂ = 0.9"),
        )
        for level, colour, caption in reference_lines:
            fig.add_hline(
                y=level,
                line_dash="dash",
                line_color=colour,
                opacity=0.5,
                annotation_text=caption,
                row=3,
                col=1,
            )
    else:
        add_placeholder(3, "No acceptance ratio history available")

    legend_cfg = {
        "yanchor": "top",
        "y": 0.99,
        "xanchor": "left",
        "x": 1.02,
        "bgcolor": "rgba(0, 0, 0, 0.5)",
    }
    fig.update_layout(
        title_text="SCP Convergence Histories",
        template="plotly_dark",
        showlegend=True,
        height=900,
        legend=legend_cfg,
    )

    # Axis labels: same x label everywhere, per-panel y settings.
    for panel in (1, 2, 3):
        fig.update_xaxes(title_text="Iteration", row=panel, col=1)
    fig.update_yaxes(title_text="lam_prox", type="log", row=1, col=1)
    fig.update_yaxes(title_text="Reduction", row=2, col=1)  # linear scale
    fig.update_yaxes(title_text="Acceptance Ratio (ρ)", row=3, col=1, range=[-0.5, 1.5])

    return fig

plot_scp_iterations(result: OptimizationResults, state_names: list[str] | None = None, control_names: list[str] | None = None, cmap_name: str = 'viridis', show_propagation: bool = True) -> go.Figure

Plot all SCP iterations overlaid with colormap-based coloring.

Shows the evolution of states and controls across SCP iterations. Early iterations are dark, later iterations are bright (following the colormap). This makes convergence visible at a glance.

Parameters:

Name Type Description Default
result OptimizationResults

Optimization results containing iteration history

required
state_names list[str] | None

Optional list of state names to include. If None, plots all states.

None
control_names list[str] | None

Optional list of control names to include. If None, plots all controls.

None
cmap_name str

Matplotlib colormap name (default: "viridis")

'viridis'
show_propagation bool

If True, show multi-shot propagation lines (default: True)

True

Returns:

Type Description
Figure

Plotly figure with all iterations overlaid

Example

results = problem.solve()
plot_scp_iterations(results, ["position", "velocity"]).show()

Source code in openscvx/plotting/scp_iteration.py
def plot_scp_iterations(
    result: OptimizationResults,
    state_names: list[str] | None = None,
    control_names: list[str] | None = None,
    cmap_name: str = "viridis",
    show_propagation: bool = True,
) -> go.Figure:
    """Plot all SCP iterations overlaid with colormap-based coloring.

    Shows the evolution of states and controls across SCP iterations. Each
    iteration is colored by its position along the colormap, from the first
    iteration at the start of the map to the last at its end. State subplots
    occupy the top rows of the grid and control subplots the rows below;
    finite variable bounds are drawn as dotted red lines.

    Args:
        result: Optimization results containing iteration history
        state_names: Optional list of state names to include. If None, plots all
            states (unless only ``control_names`` is given, in which case no
            states are plotted).
        control_names: Optional list of control names to include. If None, plots
            all controls (unless only ``state_names`` is given, in which case no
            controls are plotted).
        cmap_name: Matplotlib colormap name (default: "viridis")
        show_propagation: If True, show multi-shot propagation lines (default: True).
            Propagation lines need a state named "time" for their x-axis and are
            skipped when no such state exists.

    Returns:
        Plotly figure with all iterations overlaid

    Raises:
        ValueError: If there is no iteration history, if a filter matches no
            variable, or if nothing is left to plot.

    Example:
        >>> results = problem.solve()
        >>> plot_scp_iterations(results, ["position", "velocity"]).show()
    """
    import matplotlib.pyplot as plt

    if not result.X:
        raise ValueError("No iteration history available in result.X")

    # Derive dimensions from result data
    n_x = result.X[0].shape[1]
    n_u = result.U[0].shape[1]

    # Find time slice by looking for "time" state
    time_slice = None
    for state in result._states:
        if state.name.lower() == "time":
            time_slice = state._slice
            break

    # Extract multi-shot propagation trajectories. Without a time state there
    # is no x-axis for the propagated samples (indexing with a None slice
    # would silently produce wrong data), so propagation is skipped then.
    V_history = result.discretization_history if result.discretization_history else []
    X_prop_history = []
    if V_history and show_propagation and time_slice is not None:
        # Width of one flattened block; the first n_x entries are the state.
        i4 = n_x + n_x * n_x + 2 * n_x * n_u
        for V in V_history:
            X_prop_history.append(
                np.array([V[:, k].reshape(-1, i4)[:, :n_x] for k in range(V.shape[1])])
            )

    n_iterations = len(result.X)
    if X_prop_history:
        n_iterations = min(n_iterations, len(X_prop_history))

    # Filter states and controls (exclude ctcs_aug and time)
    states = [
        s for s in result._states if "ctcs_aug" not in s.name.lower() and s.name.lower() != "time"
    ]
    controls = list(result._controls) if result._controls else []

    state_filter = set(state_names) if state_names else None
    control_filter = set(control_names) if control_names else None

    # Giving only one of the two filters hides the other category entirely.
    if state_filter and control_filter is None:
        controls = []
    if control_filter and state_filter is None:
        states = []
    if state_filter:
        states = [s for s in states if s.name in state_filter]
        if not states:
            available = {s.name for s in result._states if "ctcs_aug" not in s.name.lower()}
            raise ValueError(
                f"No states matched filter {state_names}. Available: {sorted(available)}"
            )
    if control_filter:
        controls = [c for c in controls if c.name in control_filter]
        if not controls:
            available = {c.name for c in result._controls}
            raise ValueError(
                f"No controls matched filter {control_names}. Available: {sorted(available)}"
            )

    if not states and not controls:
        raise ValueError("No states or controls to plot")

    # Expand multi-dimensional variables to individual components
    def expand_variables(variables):
        expanded = []
        for var in variables:
            s = var._slice
            if isinstance(s, slice):
                start = s.start or 0  # slice(None, k) starts at column 0
                stop = s.stop or start + 1
            else:
                start = s
                stop = start + 1
            n_comp = stop - start

            for i in range(n_comp):
                expanded.append(
                    {
                        "name": f"{var.name}_{i}" if n_comp > 1 else var.name,
                        "idx": start + i,
                        "parent": var.name,
                        "comp": i,
                    }
                )
        return expanded

    expanded_states = expand_variables(states)
    expanded_controls = expand_variables(controls)

    # Grid layout: state rows on top, control rows underneath
    n_states = len(expanded_states)
    n_controls = len(expanded_controls)
    n_state_cols = min(7, n_states) if n_states > 0 else 1
    n_control_cols = min(3, n_controls) if n_controls > 0 else 1
    n_state_rows = (n_states + n_state_cols - 1) // n_state_cols if n_states > 0 else 0
    n_control_rows = (n_controls + n_control_cols - 1) // n_control_cols if n_controls > 0 else 0
    total_rows = n_state_rows + n_control_rows
    max_cols = max(n_state_cols, n_control_cols)

    def grid_titles(expanded, n_used_cols, n_rows):
        # make_subplots assigns titles row-major over the full max_cols-wide
        # grid, so unused cells must be padded with "" to keep each title on
        # the subplot that actually shows the variable.
        titles = [""] * (n_rows * max_cols)
        for k, item in enumerate(expanded):
            titles[(k // n_used_cols) * max_cols + (k % n_used_cols)] = item["name"]
        return titles

    subplot_titles = grid_titles(expanded_states, n_state_cols, n_state_rows) + grid_titles(
        expanded_controls, n_control_cols, n_control_rows
    )
    fig = make_subplots(
        rows=total_rows,
        cols=max_cols,
        subplot_titles=subplot_titles,
        vertical_spacing=0.08,
        horizontal_spacing=0.05,
    )

    # Get colormap
    cmap = plt.get_cmap(cmap_name)

    def iter_color(iter_idx):
        rgba = cmap(iter_idx / max(n_iterations - 1, 1))
        return f"rgb({int(rgba[0] * 255)},{int(rgba[1] * 255)},{int(rgba[2] * 255)})"

    # Plot all iterations
    for iter_idx in range(n_iterations):
        X_nodes = result.X[iter_idx]
        U_iter = result.U[iter_idx]
        color = iter_color(iter_idx)
        legend_group = f"iter_{iter_idx}"
        # Only the first trace of each iteration gets a legend entry; the
        # shared legendgroup makes it toggle all of that iteration's traces.
        show_legend_for_iter = True

        t_nodes = (
            X_nodes[:, time_slice].flatten()
            if time_slice is not None
            else np.linspace(0, result.t_final, X_nodes.shape[0])
        )

        pos_traj = X_prop_history[iter_idx] if X_prop_history else None

        # States
        for state_idx, state in enumerate(expanded_states):
            row = (state_idx // n_state_cols) + 1
            col = (state_idx % n_state_cols) + 1
            idx = state["idx"]

            # Multi-shot propagation lines
            if pos_traj is not None:
                for j in range(pos_traj.shape[1]):
                    segment_times = pos_traj[:, j, time_slice].flatten()
                    segment_states = pos_traj[:, j, idx]
                    fig.add_trace(
                        go.Scatter(
                            x=segment_times,
                            y=segment_states,
                            mode="lines",
                            line={"color": color, "width": 1.5},
                            legendgroup=legend_group,
                            showlegend=show_legend_for_iter,
                            name=f"Iter {iter_idx}" if show_legend_for_iter else None,
                            hoverinfo="skip",
                        ),
                        row=row,
                        col=col,
                    )
                    show_legend_for_iter = False

            # Nodes
            fig.add_trace(
                go.Scatter(
                    x=t_nodes,
                    y=X_nodes[:, idx],
                    mode="markers",
                    marker={"color": color, "size": 5},
                    legendgroup=legend_group,
                    showlegend=show_legend_for_iter,
                    name=f"Iter {iter_idx}" if show_legend_for_iter else None,
                    hovertemplate=f"iter {iter_idx}<br>t=%{{x:.2f}}<br>y=%{{y:.3g}}<extra></extra>",
                ),
                row=row,
                col=col,
            )
            show_legend_for_iter = False

        # Controls
        for control_idx, control in enumerate(expanded_controls):
            row = n_state_rows + (control_idx // n_control_cols) + 1
            col = (control_idx % n_control_cols) + 1
            idx = control["idx"]

            fig.add_trace(
                go.Scatter(
                    x=t_nodes,
                    y=U_iter[:, idx],
                    mode="markers",
                    marker={"color": color, "size": 5},
                    legendgroup=legend_group,
                    showlegend=show_legend_for_iter,
                    name=f"Iter {iter_idx}" if show_legend_for_iter else None,
                    hovertemplate=f"iter {iter_idx}<br>t=%{{x:.2f}}<br>y=%{{y:.3g}}<extra></extra>",
                ),
                row=row,
                col=col,
            )
            show_legend_for_iter = False

    # Add bounds (once, using final iteration's time range)
    t_nodes_final = (
        result.X[-1][:, time_slice].flatten()
        if time_slice is not None
        else np.linspace(0, result.t_final, result.X[-1].shape[0])
    )
    t_min, t_max = t_nodes_final.min(), t_nodes_final.max()

    def add_bound_lines(expanded, variables, n_used_cols, row_offset):
        # One dotted horizontal line per finite min/max bound of a component.
        for k, item in enumerate(expanded):
            row = row_offset + (k // n_used_cols) + 1
            col = (k % n_used_cols) + 1
            parent = _get_var(result, item["parent"], variables)
            comp_idx = item["comp"]

            for bound_val in (parent.min, parent.max):
                if bound_val is not None and np.isfinite(bound_val[comp_idx]):
                    fig.add_trace(
                        go.Scatter(
                            x=[t_min, t_max],
                            y=[bound_val[comp_idx], bound_val[comp_idx]],
                            mode="lines",
                            line={"color": "red", "width": 1.5, "dash": "dot"},
                            showlegend=False,
                            hoverinfo="skip",
                        ),
                        row=row,
                        col=col,
                    )

    add_bound_lines(expanded_states, result._states, n_state_cols, 0)
    add_bound_lines(expanded_controls, result._controls, n_control_cols, n_state_rows)

    # Layout
    fig.update_layout(
        title_text="SCP Iterations",
        template="plotly_dark",
        showlegend=True,
        legend={
            "title": "Iterations",
            "yanchor": "top",
            "y": 0.99,
            "xanchor": "left",
            "x": 1.02,
            "bgcolor": "rgba(0, 0, 0, 0.5)",
            "itemclick": "toggle",
            "itemdoubleclick": "toggleothers",
        },
    )

    for col_idx in range(1, max_cols + 1):
        fig.update_xaxes(title_text="Time (s)", row=total_rows, col=col_idx)

    return fig

plot_state_component(result: OptimizationResults, state_name: str, component: int = 0) -> go.Figure

Plot a single component of a state variable vs time.

This is the low-level function for plotting one scalar value over time. For plotting all components of a state, use plot_states().

Parameters:

Name Type Description Default
result OptimizationResults

Optimization results containing state trajectories

required
state_name str

Name of the state variable

required
component int

Component index (0-indexed). For scalar states, use 0.

0

Returns:

Type Description
Figure

Plotly figure with single plot

Example

plot_state_component(result, "position", 2) # Plot z-component

Source code in openscvx/plotting/plotting.py
def plot_state_component(
    result: OptimizationResults,
    state_name: str,
    component: int = 0,
) -> go.Figure:
    """Build a figure showing one scalar channel of a state against time.

    This is the low-level building block for a single scalar time series;
    use plot_states() to lay out every component of one or more states.

    The propagated trajectory (when ``result.trajectory`` is non-empty and
    contains the state) is drawn as a green line, and the node values (when
    ``result.nodes`` contains the state) are drawn as cyan markers.

    Args:
        result: Optimization results containing state trajectories
        state_name: Name of the state variable
        component: Component index (0-indexed). For scalar states, use 0.

    Returns:
        Plotly figure with single plot

    Raises:
        ValueError: If ``state_name`` is not one of the result's states, or
            if ``component`` lies outside ``[0, dim)``.

    Example:
        >>> plot_state_component(result, "position", 2)  # Plot z-component
    """
    known_names = {state.name for state in result._states}
    if state_name not in known_names:
        raise ValueError(f"State '{state_name}' not found. Available: {sorted(known_names)}")

    dim = _get_var_dim(result, state_name, result._states)
    if component >= dim or component < 0:
        raise ValueError(f"Component {component} out of range for '{state_name}' (dim={dim})")

    def _select(series):
        # 1-D arrays already hold a single channel; otherwise pick the column.
        if series.ndim == 1:
            return series
        return series[:, component]

    node_times = result.nodes["time"].flatten()
    has_trajectory = bool(result.trajectory) and state_name in result.trajectory
    dense_times = None
    if has_trajectory:
        dense_times = result.trajectory["time"].flatten()

    # Only multi-dimensional states get the component index as a suffix.
    if dim > 1:
        label = f"{state_name}_{component}"
    else:
        label = state_name

    fig = go.Figure()
    fig.update_layout(title_text=label, template="plotly_dark")

    if has_trajectory:
        trajectory_trace = go.Scatter(
            x=dense_times,
            y=_select(result.trajectory[state_name]),
            mode="lines",
            name="Trajectory",
            line=dict(color="green", width=2),
        )
        fig.add_trace(trajectory_trace)

    if state_name in result.nodes:
        node_trace = go.Scatter(
            x=node_times,
            y=_select(result.nodes[state_name]),
            mode="markers",
            name="Nodes",
            marker=dict(color="cyan", size=6),
        )
        fig.add_trace(node_trace)

    fig.update_xaxes(title_text="Time (s)")
    fig.update_yaxes(title_text=label)
    return fig

plot_states(result: OptimizationResults, state_names: list[str] | None = None, include_private: bool = False, cols: int = 4) -> go.Figure

Plot state variables in a subplot grid.

Each component of each state gets its own subplot with individual y-axis scaling. This is the primary function for visualizing state trajectories.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `result` | `OptimizationResults` | Optimization results containing state trajectories | required |
| `state_names` | `list[str] \| None` | List of state names to plot. If None, plots all states. | `None` |
| `include_private` | `bool` | Whether to include private states (names starting with '_') | `False` |
| `cols` | `int` | Maximum number of columns in subplot grid | `4` |

Returns:

| Type | Description |
| --- | --- |
| `Figure` | Plotly figure with subplot grid |

Examples:

>>> plot_states(result, ["position"])  # 3 subplots for x, y, z
>>> plot_states(result, ["position", "velocity"])  # 6 subplots
>>> plot_states(result)  # All states
Source code in openscvx/plotting/plotting.py
def plot_states(
    result: OptimizationResults,
    state_names: list[str] | None = None,
    include_private: bool = False,
    cols: int = 4,
) -> go.Figure:
    """Plot state variables in a subplot grid.

    Each component of each state gets its own subplot with individual y-axis
    scaling. This is the primary function for visualizing state trajectories.

    Args:
        result: Optimization results containing state trajectories
        state_names: List of state names to plot, in the order they should
            appear. If None, plots all states.
        include_private: Whether to include private states (names starting with '_')
        cols: Maximum number of columns in subplot grid. Must be at least 1.

    Returns:
        Plotly figure with subplot grid

    Raises:
        ValueError: If ``cols`` is smaller than 1, if a requested name is not
            among the candidate states (private states are only candidates
            when ``include_private`` is True), or if there are no state
            components to plot.

    Examples:
        >>> plot_states(result, ["position"])  # 3 subplots for x, y, z
        >>> plot_states(result, ["position", "velocity"])  # 6 subplots
        >>> plot_states(result)  # All states
    """
    # Reject a non-positive column count up front: cols == 0 would otherwise
    # surface as a ZeroDivisionError in the row computation below, and a
    # negative value would yield a nonsensical grid.
    if cols < 1:
        raise ValueError(f"cols must be >= 1, got {cols}")

    states = result._states
    if not include_private:
        states = [s for s in states if not s.name.startswith("_")]

    if state_names is not None:
        available = {s.name for s in states}
        missing = set(state_names) - available
        if missing:
            raise ValueError(f"States not found in result: {missing}")
        # Preserve order from state_names (a repeated name keeps its last position)
        state_order = {name: i for i, name in enumerate(state_names)}
        states = sorted(
            (s for s in states if s.name in state_order),
            key=lambda s: state_order[s.name],
        )

    # Build list of (display_name, var_name, component_idx)
    components = []
    for s in states:
        dim = _get_var_dim(result, s.name, result._states)
        if dim == 1:
            components.append((s.name, s.name, 0))
        else:
            components.extend((f"{s.name}_{i}", s.name, i) for i in range(dim))

    if not components:
        raise ValueError("No state components to plot")

    # Never allocate more columns than there are subplots; rows use ceil division.
    n_cols = min(cols, len(components))
    n_rows = (len(components) + n_cols - 1) // n_cols

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=[c[0] for c in components])
    fig.update_layout(title_text="State Trajectories", template="plotly_dark")

    for idx, (_, var_name, comp_idx) in enumerate(components):
        # Fill the grid row by row; plotly subplot indices are 1-based.
        row, col = divmod(idx, n_cols)
        row += 1
        col += 1

        # Get bounds for this component
        var = _get_var(result, var_name, result._states)
        min_val = var.min[comp_idx] if var.min is not None else None
        max_val = var.max[comp_idx] if var.max is not None else None

        _add_component_traces(
            fig,
            result,
            var_name,
            comp_idx,
            row,
            col,
            show_legend=(idx == 0),  # one legend entry set is enough for the whole grid
            min_val=min_val,
            max_val=max_val,
        )

    # Add x-axis labels to bottom row
    for col_idx in range(1, n_cols + 1):
        fig.update_xaxes(title_text="Time (s)", row=n_rows, col=col_idx)

    return fig

plot_trust_region_heatmap(result: OptimizationResults)

Plot heatmap of the final trust-region deltas (TR_history[-1]).

Source code in openscvx/plotting/plotting.py
def plot_trust_region_heatmap(result: OptimizationResults):
    """Render the last entry of ``result.TR_history`` as a heatmap.

    Rows are labelled with one entry per state/control component (states
    first, then controls); columns are labelled with node times when their
    count matches the matrix width, otherwise with plain column indices.

    Args:
        result: Optimization results providing ``TR_history``, the state and
            control definitions, and the node times.

    Returns:
        Plotly figure containing the heatmap.

    Raises:
        ValueError: If ``TR_history`` is empty, or if neither axis of the
            matrix matches the number of state/control components.
    """
    if not result.TR_history:
        raise ValueError("Result has no TR_history to plot")

    final_tr = result.TR_history[-1]

    # One label per scalar component; vector variables get an index suffix.
    row_labels = []
    for group in (result._states, result._controls):
        for variable in group:
            width = _get_var_dim(result, variable.name, group)
            if width == 1:
                row_labels.append(variable.name)
                continue
            row_labels += [f"{variable.name}_{k}" for k in range(width)]

    # Orient the matrix so that rows = variables and columns = nodes.
    n_labels = len(row_labels)
    if final_tr.shape[0] == n_labels:
        heat = final_tr
    elif final_tr.shape[1] == n_labels:
        heat = final_tr.T
    else:
        raise ValueError("TR matrix dimensions do not align with state/control components")

    n_columns = heat.shape[1]
    node_times = result.nodes["time"].flatten()
    if len(node_times) == n_columns:
        x_axis = node_times
    else:
        # Fall back to column indices when the node times do not line up.
        x_axis = list(range(n_columns))

    heatmap = go.Heatmap(z=heat, x=x_axis, y=row_labels, colorscale="Viridis")
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title="Trust Region Delta Magnitudes (last iteration)",
        template="plotly_dark",
    )
    fig.update_xaxes(title_text="Node / Time", side="bottom")
    fig.update_yaxes(title_text="State / Control component", side="left")
    return fig

plot_vector_norm(result: OptimizationResults, var_name: str, bounds: tuple[float, float] | None = None) -> go.Figure

Plot the 2-norm of a vector variable over time.

Useful for visualizing thrust magnitude, velocity magnitude, etc.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `result` | `OptimizationResults` | Optimization results containing trajectories | required |
| `var_name` | `str` | Name of the vector variable (state or control) | required |
| `bounds` | `tuple[float, float] \| None` | Optional (min, max) bounds to show as horizontal dashed lines | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Figure` | Plotly figure |

Source code in openscvx/plotting/plotting.py
def plot_vector_norm(
    result: OptimizationResults,
    var_name: str,
    bounds: tuple[float, float] | None = None,
) -> go.Figure:
    """Plot the 2-norm of a vector variable over time.

    Useful for visualizing thrust magnitude, velocity magnitude, etc.

    The norm is taken along axis 1 of the stored data. A variable stored as
    a 1-D array is treated as a scalar time series and its absolute value is
    plotted, matching the 1-D handling in plot_state_component().

    Args:
        result: Optimization results containing trajectories
        var_name: Name of the vector variable (state or control)
        bounds: Optional (min, max) bounds to show as horizontal dashed lines

    Returns:
        Plotly figure

    Raises:
        ValueError: If ``var_name`` is present in neither ``result.trajectory``
            nor ``result.nodes``.
    """
    import numpy as np

    def _norm_over_time(values):
        # axis=1 does not exist for 1-D data; the 2-norm of a scalar sample
        # is its absolute value.
        arr = np.asarray(values)
        if arr.ndim == 1:
            return np.abs(arr)
        return np.linalg.norm(arr, axis=1)

    has_trajectory = bool(result.trajectory) and var_name in result.trajectory
    has_nodes = var_name in result.nodes

    if not has_trajectory and not has_nodes:
        available_traj = set(result.trajectory.keys()) if result.trajectory else set()
        available_nodes = set(result.nodes.keys())
        raise ValueError(
            f"Variable '{var_name}' not found. "
            f"Available in trajectory: {sorted(available_traj)}, nodes: {sorted(available_nodes)}"
        )

    fig = go.Figure()

    # Plot trajectory norm if available
    if has_trajectory:
        t_full = result.trajectory["time"].flatten()
        fig.add_trace(
            go.Scatter(
                x=t_full,
                y=_norm_over_time(result.trajectory[var_name]),
                mode="lines",
                line={"color": "green", "width": 2},
                name="Trajectory",
                legendgroup="trajectory",
            )
        )

    # Plot node norms if available
    if has_nodes:
        t_nodes = result.nodes["time"].flatten()
        fig.add_trace(
            go.Scatter(
                x=t_nodes,
                y=_norm_over_time(result.nodes[var_name]),
                mode="markers",
                marker={"color": "cyan", "size": 6},
                name="Nodes",
                legendgroup="nodes",
            )
        )

    # Add bounds if provided: one dashed red line each for the min and the max
    if bounds is not None:
        min_bound, max_bound = bounds
        for annotation, level in (("Min", min_bound), ("Max", max_bound)):
            fig.add_hline(
                y=level,
                line={"color": "red", "width": 2, "dash": "dash"},
                annotation_text=annotation,
                annotation_position="right",
            )

    fig.update_layout(
        title=f"‖{var_name}‖₂",
        xaxis_title="Time (s)",
        yaxis_title="Norm",
        template="plotly_dark",
    )

    return fig

plot_virtual_control_heatmap(result: OptimizationResults)

Plot heatmap of the final virtual control magnitudes (VC_history[-1]).

Source code in openscvx/plotting/plotting.py
def plot_virtual_control_heatmap(result: OptimizationResults):
    """Render the last entry of ``result.VC_history`` as a heatmap.

    Rows are labelled with one entry per state component. Columns are
    labelled with node times (dropping the final node when there is exactly
    one more node than columns), or with plain indices when the node times
    do not line up with the matrix width.

    Args:
        result: Optimization results providing ``VC_history``, the state
            definitions, and the node times.

    Returns:
        Plotly figure containing the heatmap.

    Raises:
        ValueError: If ``VC_history`` is empty, or if neither axis of the
            matrix matches the number of state components.
    """
    if not result.VC_history:
        raise ValueError("Result has no VC_history to plot")

    final_vc = result.VC_history[-1]

    # One label per scalar component; vector states get an index suffix.
    row_labels = []
    for state in result._states:
        width = _get_var_dim(result, state.name, result._states)
        if width == 1:
            row_labels.append(state.name)
            continue
        row_labels += [f"{state.name}_{k}" for k in range(width)]

    # Orient the matrix so that rows = states and columns = nodes; the
    # column axis is checked first, so a matching second axis is transposed.
    n_labels = len(row_labels)
    if final_vc.shape[1] == n_labels:
        heat = final_vc.T
    elif final_vc.shape[0] == n_labels:
        heat = final_vc
    else:
        raise ValueError("VC matrix shape does not align with state components")

    n_columns = heat.shape[1]
    node_times = result.nodes["time"].flatten()
    n_times = len(node_times)

    # Virtual control uses N-1 intervals, so one extra node time is expected.
    if n_times == n_columns + 1:
        x_axis = node_times[:-1]
    elif n_times == n_columns:
        x_axis = node_times
    else:
        x_axis = list(range(n_columns))

    heatmap = go.Heatmap(z=heat, x=x_axis, y=row_labels, colorscale="Magma")
    fig = go.Figure(data=heatmap)
    fig.update_layout(title="Virtual Control Magnitudes (last iteration)", template="plotly_dark")
    fig.update_xaxes(title_text="Node Interval (N-1)")
    fig.update_yaxes(title_text="State component")
    return fig