
propagation

Trajectory propagation for trajectory optimization.

This module provides implementations of trajectory propagation methods that simulate the nonlinear system dynamics forward in time. Propagation is used to evaluate solution quality, verify constraint satisfaction, and generate high-fidelity trajectories from optimized control sequences.

Current Implementations

Forward Simulation: The default propagation method that integrates the nonlinear dynamics forward in time using adaptive or fixed-step numerical integration (via Diffrax). Supports both ZOH and FOH control interpolation schemes.

Planned Architecture (ABC-based):

A base class will be introduced so that users can plug in custom propagation methods. Future propagators will implement the Propagator interface:

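# propagation/base.py (planned):
# Integrator and Array are placeholder type names in this planned interface.
from __future__ import annotations

from abc import ABC, abstractmethod


class Propagator(ABC):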
    def __init__(self, integrator: Integrator):
        '''Initialize with a numerical integrator.'''
        self.integrator = integrator

    @abstractmethod
    def propagate(self, dynamics, x0, u_traj, time_grid) -> Array:
        '''Propagate trajectory forward in time.

        Args:
            dynamics: Continuous-time dynamics object
            x0: Initial state
            u_traj: Control trajectory
            time_grid: Time points for dense output

        Returns:
            State trajectory evaluated at time_grid points
        '''
        ...
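To show how the planned interface could be used, here is a minimal sketch of a concrete subclass. It assumes the integrator is a plain callable `integrator(f, t, x, u, h)` that advances one step, that `dynamics` is a plain function `f(t, x, u)` rather than an openscvx `Dynamics` object, and that controls are held constant (ZOH) over each grid interval. `ZOHPropagator` and `rk4_step` are hypothetical names, not part of openscvx.

```python
from abc import ABC, abstractmethod
from typing import Callable

import numpy as np


class Propagator(ABC):
    """Mirror of the planned interface above (not yet part of openscvx)."""

    def __init__(self, integrator: Callable):
        self.integrator = integrator

    @abstractmethod
    def propagate(self, dynamics, x0, u_traj, time_grid) -> np.ndarray:
        ...


def rk4_step(f, t, x, u, h):
    """One classical Runge-Kutta 4 step with the control u held constant."""
    k1 = f(t, x, u)
    k2 = f(t + h / 2, x + h / 2 * k1, u)
    k3 = f(t + h / 2, x + h / 2 * k2, u)
    k4 = f(t + h, x + h * k3, u)
    return x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


class ZOHPropagator(Propagator):
    """Hypothetical propagator: ZOH controls, one integrator step per grid interval."""

    def propagate(self, dynamics, x0, u_traj, time_grid):
        xs = [np.asarray(x0, dtype=float)]
        for k in range(len(time_grid) - 1):
            h = time_grid[k + 1] - time_grid[k]
            # Hold u_traj[k] constant over [time_grid[k], time_grid[k + 1]]
            xs.append(self.integrator(dynamics, time_grid[k], xs[-1], u_traj[k], h))
        return np.stack(xs)  # shape (len(time_grid), n_states)


# Usage: single integrator dx/dt = u with u = 1 held over two half-second steps
prop = ZOHPropagator(integrator=rk4_step)
traj = prop.propagate(lambda t, x, u: u, x0=[0.0], u_traj=np.array([[1.0], [1.0]]),
                      time_grid=np.array([0.0, 0.5, 1.0]))
# traj[:, 0] -> [0.0, 0.5, 1.0]
```

Keeping the integrator as a constructor argument, as the planned interface does, lets the same propagation loop be reused with a different step rule.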

get_propagation_solver(state_dot: Dynamics, settings: Config, discretizer: Discretizer) -> callable

Create a propagation solver function.

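This function creates a solver that propagates the system state using the specified dynamics and settings. The returned function integrates one segment, from tau_grid[0] to tau_grid[1], with the solver and tolerances given in settings.prp, and evaluates the state at the requested save times.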

Parameters:

Name | Type | Description | Default
state_dot | Dynamics | Dynamics object containing the state derivative function. | required
settings | Config | Configuration settings for propagation. | required
discretizer | Discretizer | Discretizer instance (used for dis_type). | required

Returns:

Type | Description
callable | A function that solves the propagation problem.

Source code in openscvx/propagation/propagation.py
def get_propagation_solver(
    state_dot: Dynamics, settings: Config, discretizer: Discretizer
) -> callable:
    """Create a propagation solver function.

    This function creates a solver that propagates the system state using the
    specified dynamics and settings.

    Args:
        state_dot: Dynamics object containing state derivative function.
        settings: Configuration settings for propagation.
        discretizer: Discretizer instance (used for ``dis_type``).

    Returns:
        callable: A function that solves the propagation problem.
    """

    u_foh_mask = getattr(settings.sim.u, "foh_mask", None)
    foh_mask = _resolve_foh_mask(discretizer.dis_type, settings.sim.n_controls, u_foh_mask)

    def propagation_solver(V0, tau_grid, u_cur, u_next, tau_init, node, save_time, mask, params):
        param_map_update = params
        return solve_ivp_diffrax_prop(
            f=prop_aug_dy,
            tau_final=tau_grid[1],  # scalar
            y_0=V0,  # shape (n_states,)
            args=(
                u_cur,  # shape (1, n_controls)
                u_next,  # shape (1, n_controls)
                tau_init,  # shape (1, 1)
                node,  # shape (1, 1)
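                state_dot,  # callable: (x, u, node, params) -> time-dilated derivatives
                foh_mask,  # per-control hold mask: 1.0 = FOH, 0.0 = ZOH
                settings.sim.n,  # number of nodes N
                param_map_update,  # parameter dict forwarded to state_dot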
            ),
            tau_0=tau_grid[0],  # scalar
            solver_name=settings.prp.solver,
            rtol=settings.prp.rtol,
            atol=settings.prp.atol,
            extra_kwargs=settings.prp.args,
            save_time=save_time,  # shape (MAX_TAU_LEN,)
            mask=mask,  # shape (MAX_TAU_LEN,), dtype=bool
        )

    return propagation_solver
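The returned solver is called once per segment with a `save_time` array and a boolean `mask`, both of fixed length `MAX_TAU_LEN` (settings.prp.max_tau_len) according to the shape comments above; `simulate_nonlinear_time` fills unused slots by repeating the segment end time. A plausible reason for the fixed length is that a jitted JAX function is retraced for every new input shape. The helper below, `pad_save_times`, is a hypothetical sketch of that padding step, not a library function; the explicit `ValueError` is an addition for clarity.

```python
import numpy as np


def pad_save_times(tau_seg, tau_end, max_tau_len):
    """Pad one segment's sample times to a fixed length and build the validity mask."""
    tau_seg = np.asarray(tau_seg, dtype=float)
    count = tau_seg.size
    if count > max_tau_len:
        # Added guard: np.pad would otherwise fail on a negative pad width
        raise ValueError(f"{count} samples exceed max_tau_len={max_tau_len}")
    pad_len = max_tau_len - count
    # Padding repeats the segment end time, so every save time stays inside the interval
    save_time = np.pad(tau_seg, (0, pad_len), constant_values=tau_end)
    # True marks real samples; False marks padding the caller should discard
    mask = np.concatenate([np.ones(count), np.zeros(pad_len)]).astype(bool)
    return save_time, mask


save_time, mask = pad_save_times([0.1, 0.2], tau_end=0.25, max_tau_len=5)
# save_time -> [0.1, 0.2, 0.25, 0.25, 0.25]; mask -> [True, True, False, False, False]
```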

prop_aug_dy(tau: float, x: np.ndarray, u_current: np.ndarray, u_next: np.ndarray, tau_init: float, node: int, state_dot: callable, foh_mask: np.ndarray, N: int, params: dict) -> np.ndarray

Compute the augmented dynamics for propagation.

This function computes the time-dilated dynamics used for propagation, interpolating each control according to its hold type (ZOH or FOH). state_dot already includes the multiplication by the time-dilation factor (applied symbolically), so no extra scaling is applied here.
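prop_aug_dy forms the control as u = u_current + beta * (u_next - u_current), with beta multiplied elementwise by foh_mask so that ZOH controls keep their node value. The sketch below isolates that masking idea as a standalone helper. `interp_controls` is hypothetical, and it writes the segment fraction explicitly as (tau - tau_k) / (tau_{k+1} - tau_k) instead of using the library's (tau - tau_init) * N expression.

```python
import numpy as np


def interp_controls(tau, tau_k, tau_k1, u_k, u_k1, foh_mask):
    """Interpolate node controls inside one segment, control by control.

    foh_mask[i] == 1.0 -> first-order hold: linear blend from u_k to u_k1
    foh_mask[i] == 0.0 -> zero-order hold: u_k held across the segment
    """
    u_k = np.asarray(u_k, dtype=float)
    u_k1 = np.asarray(u_k1, dtype=float)
    frac = (tau - tau_k) / (tau_k1 - tau_k)  # 0 at the left node, 1 at the right node
    beta = frac * np.asarray(foh_mask, dtype=float)  # zeroed for ZOH controls
    return u_k + beta * (u_k1 - u_k)


# First control is FOH, second is ZOH; evaluate halfway through the segment [0.0, 0.5]
u_mid = interp_controls(0.25, 0.0, 0.5, u_k=[0.0, 10.0], u_k1=[1.0, 20.0], foh_mask=[1.0, 0.0])
# u_mid -> [0.5, 10.0]
```

Using a float mask rather than branching keeps the computation a single vectorized expression, which is convenient inside a JAX-traced right-hand side.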

Parameters:

Name | Type | Description | Default
tau | float | Current normalized time in [0, 1]. | required
x | ndarray | Current state vector. | required
u_current | ndarray | Control input at the current node. | required
u_next | ndarray | Control input at the next node. | required
tau_init | float | Normalized time at the start of the current segment. | required
node | int | Current node index. | required
state_dot | callable | Function computing time-dilated state derivatives. | required
foh_mask | ndarray | Float array of shape (n_u,): 1.0 for FOH controls, 0.0 for ZOH controls. | required
N | int | Number of nodes in the trajectory. | required
params | dict | Dictionary of additional parameters passed to state_dot. | required

Returns:

Type | Description
ndarray | Time-dilated state derivatives.

Source code in openscvx/propagation/propagation.py
def prop_aug_dy(
    tau: float,
    x: np.ndarray,
    u_current: np.ndarray,
    u_next: np.ndarray,
    tau_init: float,
    node: int,
    state_dot: callable,
    foh_mask: np.ndarray,
    N: int,
    params: dict,
) -> np.ndarray:
    """Compute the augmented dynamics for propagation.

    This function computes the time-dilated dynamics for propagating the system
    state, taking into account the per-control hold type (ZOH or FOH). The
    time-dilation multiplication is already included in ``state_dot``
    symbolically.

    Args:
        tau (float): Current normalized time in [0,1].
        x (np.ndarray): Current state vector.
        u_current (np.ndarray): Control input at current node.
        u_next (np.ndarray): Control input at next node.
        tau_init (float): Normalized time at the start of the current segment.
        node (int): Current node index.
        state_dot (callable): Function computing time-dilated state derivatives.
        foh_mask (np.ndarray): Float array of shape ``(n_u,)`` — ``1.0`` for
            FOH controls, ``0.0`` for ZOH controls.
        N (int): Number of nodes in trajectory.
        params (dict): Dictionary of additional parameters passed to state_dot.

    Returns:
        np.ndarray: Time-dilated state derivatives.
    """
    x = x[None, :]

    beta = (tau - tau_init) * N * foh_mask
    u = u_current + beta * (u_next - u_current)

    return state_dot(x, u, node, params).squeeze()

propagate_trajectory_results(params: dict, settings: Config, result: OptimizationResults, propagation_solver: callable, dynamics_discrete: Optional[Callable] = None, algebraic_prop: Optional[dict] = None, discretizer: Optional[Discretizer] = None) -> OptimizationResults

Propagate the optimal trajectory and compute additional results.

This function takes the optimal control solution and propagates it through the nonlinear dynamics to compute the actual state trajectory and other metrics.

When states_prop includes propagation-only states (e.g., via dynamics_prop / states_prop), x_full has shape (n_times, n_prop_states) with n_prop_states > n_opt_states. The discrete dynamics and the cost use only the optimization-state portion; propagation-only states are not updated by the final discrete step, so their last propagated values are kept and included in trajectory.
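The split between optimization states and propagation-only states can be sketched as follows. This assumes a simplified discrete map that takes only the state (the library's dynamics_discrete also takes the control, node index, and params); `apply_final_update` is a hypothetical helper.

```python
import numpy as np


def apply_final_update(x_last, n_opt_states, discrete_map=None):
    """Apply a discrete map to the optimization-state prefix of x_last only.

    Propagation-only states (the tail beyond n_opt_states) pass through unchanged.
    """
    x_last = np.asarray(x_last, dtype=float)
    x_opt = x_last[:n_opt_states]
    if discrete_map is None:
        x_plus = x_opt
    else:
        x_plus = np.asarray(discrete_map(x_opt), dtype=float).reshape(-1)
    return np.concatenate([x_plus, x_last[n_opt_states:]])


# Two optimization states followed by two propagation-only states
x_final = apply_final_update([1.0, 2.0, 3.0, 4.0], n_opt_states=2, discrete_map=lambda x: 2.0 * x)
# x_final -> [2.0, 4.0, 3.0, 4.0]
```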

Parameters:

Name | Type | Description | Default
params | dict | System parameters. | required
settings | Config | Configuration settings. | required
result | OptimizationResults | Optimization results object. | required
propagation_solver | callable | Function for propagating the system state. | required
dynamics_discrete | callable | Discrete dynamics map used to apply node-wise impulsive/discrete updates before continuous propagation. | None
algebraic_prop | dict | Dictionary mapping output names to vmapped JAX functions. | None
discretizer | Optional[Discretizer] | Discretizer instance (used for dis_type). Defaults to None, which uses FOH. | None

Returns:

Type | Description
OptimizationResults | Updated results object containing:
  - t_full: Full time vector
  - x_full: Full state trajectory
  - u_full: Full control trajectory
  - cost: Computed cost
  - ctcs_violation: CTCS constraint violation
  - trajectory: Dict containing each variable's values at full propagation fidelity
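The trajectory entry is assembled by slicing the time-major arrays x_full and u_full with each variable's column slice. A minimal sketch of that layout follows; the `Var` dataclass is a hypothetical stand-in for openscvx's state and control objects that models only the `name` and `_slice` attributes used in the source, and the variable names are illustrative.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Var:
    """Hypothetical stand-in for a state or control: a name and a column slice."""

    name: str
    _slice: slice


def build_trajectory(x_full, u_full, states, controls):
    """Split full-fidelity state/control arrays into a name -> array dict."""
    traj = {}
    for state in states:
        traj[state.name] = x_full[:, state._slice]  # shape (n_times, state_dim)
    for control in controls:
        traj[control.name] = u_full[:, control._slice]  # shape (n_times, control_dim)
    return traj


x_full = np.zeros((5, 4))  # 5 time samples, 4 propagated states
u_full = np.ones((5, 2))  # 5 time samples, 2 controls
traj = build_trajectory(
    x_full,
    u_full,
    states=[Var("position", slice(0, 3)), Var("time", slice(3, 4))],
    controls=[Var("thrust", slice(0, 2))],
)
# traj["position"].shape -> (5, 3); traj["time"].shape -> (5, 1)
```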

Source code in openscvx/propagation/post_processing.py
def propagate_trajectory_results(
    params: dict,
    settings: Config,
    result: OptimizationResults,
    propagation_solver: callable,
    dynamics_discrete: Optional[Callable] = None,
    algebraic_prop: Optional[dict] = None,
    discretizer: Optional[Discretizer] = None,
) -> OptimizationResults:
    """Propagate the optimal trajectory and compute additional results.

    This function takes the optimal control solution and propagates it through the
    nonlinear dynamics to compute the actual state trajectory and other metrics.

    When ``states_prop`` includes propagation-only states (e.g. via ``dynamics_prop`` /
    ``states_prop``), ``x_full`` has shape ``(n_times, n_prop_states)`` with
    ``n_prop_states > n_opt_states``. The discrete dynamics and cost use only the
    optimization-state portion; propagation-only states are preserved from the last
    propagated step and included in ``trajectory``.

    Args:
        params (dict): System parameters.
        settings (Config): Configuration settings.
        result (OptimizationResults): Optimization results object.
        propagation_solver (callable): Function for propagating the system state.
        dynamics_discrete (callable, optional): Discrete dynamics map used to apply
            node-wise impulsive/discrete updates before continuous propagation.
        algebraic_prop (dict, optional): Dictionary mapping output names to vmapped JAX functions.
        discretizer: Discretizer instance (used for ``dis_type``).
            Defaults to ``None`` which uses FOH.

    Returns:
        OptimizationResults: Updated results object containing:
            - t_full: Full time vector
            - x_full: Full state trajectory
            - u_full: Full control trajectory
            - cost: Computed cost
            - ctcs_violation: CTCS constraint violation
            - trajectory: Dict containing each variable's values at full propagation fidelity
    """
    # Get arrays from result
    x = result.x
    u = result.u

    t = np.array(s_to_t(x, u, settings, discretizer)).squeeze()

    # Build dense output times and always include the exact terminal time.
    # This ensures trajectory[..., -1] corresponds to the true final state.
    t_full = np.arange(t[0], t[-1], settings.prp.dt)
    if t_full.size == 0 or not np.isclose(t_full[-1], t[-1]):
        t_full = np.concatenate([t_full, np.array([t[-1]])])

    tau_vals, u_full = t_to_tau(u, t_full, t, settings, discretizer)

    # Create a copy of x_prop for propagation to avoid mutating settings
    # Match free values from initial state to the initial value from the result
    x_prop_for_propagation = copy.copy(settings.sim.x_prop)

    # Only copy for states that exist in optimization (propagation may have extra states at the end)
    n_opt_states = x.shape[1]
    n_prop_states = settings.sim.x_prop.initial.shape[0]

    # Use metadata from settings (immutable configuration) to find free initial states
    mask = jnp.array([kind == "Free" for kind in settings.sim.x.initial_type], dtype=bool)

    if n_opt_states == n_prop_states:
        # Same size - copy all
        x_prop_for_propagation.initial = jnp.where(mask, x[0, :], settings.sim.x_prop.initial)
    else:
        # Propagation has extra states - only copy the overlapping portion
        x_prop_initial_updated = settings.sim.x_prop.initial.copy()
        x_prop_initial_updated[:n_opt_states] = jnp.where(
            mask, x[0, :], settings.sim.x_prop.initial[:n_opt_states]
        )
        x_prop_for_propagation.initial = x_prop_initial_updated

    # Temporarily replace x_prop with our modified copy for propagation
    # Save original to restore after propagation
    original_x_prop = settings.sim.x_prop
    settings.sim.x_prop = x_prop_for_propagation

    try:
        x_full = simulate_nonlinear_time(
            params,
            x,
            u,
            tau_vals,
            t,
            settings,
            propagation_solver,
            dynamics_discrete=dynamics_discrete,
        )
    finally:
        # Always restore original x_prop, even if propagation fails
        settings.sim.x_prop = original_x_prop

    # Calculate cost using utility function and metadata from settings
    # dynamics_discrete operates on optimization states only; when propagation has
    # extra states, pass only the opt-state portion and then reattach the prop-only tail
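    x_minus = np.asarray(x_full[-1, :n_opt_states])
    if dynamics_discrete is not None:
        x_plus = np.asarray(
            dynamics_discrete(
                x_minus,
                np.asarray(u[-1]),
                int(settings.sim.n - 1),
                params,
            )
        ).reshape(-1)
    else:
        # No discrete map was supplied (the default): keep the last propagated state
        x_plus = x_minus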
    if n_prop_states > n_opt_states:
        # Preserve propagation-only states (not updated by discrete dynamics)
        full_final = np.concatenate([x_plus, np.asarray(x_full[-1, n_opt_states:])], axis=0)
    else:
        full_final = x_plus
    x_for_cost = np.concatenate([x_full[:-1], full_final[None, :]], axis=0)

    cost = calculate_cost_from_boundaries(
        x_for_cost[:, :n_opt_states],
        settings.sim.x.initial_type,
        settings.sim.x.final_type,
    )

    # Calculate CTCS constraint violation (use state after final impulse when applicable)
    if dynamics_discrete is not None and np.any(settings.sim.u._impulsive_mask()):
        ctcs_violation = full_final[settings.sim.ctcs_slice_prop]
    else:
        ctcs_violation = x_full[-1, settings.sim.ctcs_slice_prop]

    # Build trajectory dictionary with all states and controls.
    # result._states is states_prop (opt + propagation-only); each state._slice
    # indexes into the full propagation state, so propagation-only states are included.
    trajectory_dict = {}

    # Add all states (user-defined and augmented)
    for state in result._states:
        trajectory_dict[state.name] = x_full[:, state._slice]

    # Add all controls (user-defined and augmented)
    for control in result._controls:
        trajectory_dict[control.name] = u_full[:, control._slice]

    # Compute algebraic outputs (vmapped over time)
    if algebraic_prop:
        for name, output_fn in algebraic_prop.items():
            # output_fn is vmapped: (T, n_x), (T, n_u), node, params -> (T, output_dim)
            # Pass node=0 since algebraic outputs shouldn't depend on node index
            output_values = output_fn(x_full, u_full, 0, params)
            trajectory_dict[name] = np.asarray(output_values)

    # Update the results object with post-processing data
    result.t_full = t_full
    result.x_full = x_full
    result.u_full = u_full
    result.cost = cost
    result.ctcs_violation = ctcs_violation
    result.trajectory = trajectory_dict

    return result

s_to_t(x: np.ndarray, u: np.ndarray, settings: Config, discretizer: Discretizer) -> list[float]

Convert the time-dilation control s to real node times t.

This function integrates dt/dtau = s over the uniform normalized-time grid to obtain the real time at each node, using the hold type (ZOH or FOH) of the time-dilation control.
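Since dt/dtau = s, each node time is the previous one plus the integral of s over the segment. A minimal sketch, assuming a uniform tau grid and that the dilation values s are passed directly rather than read from u through the time-dilation index; `node_times` is a hypothetical name.

```python
import numpy as np


def node_times(s, t0=0.0, foh=False):
    """Accumulate real node times from per-node dilation factors s on a uniform tau grid."""
    n = len(s)
    tau = np.linspace(0.0, 1.0, n)
    t = [t0]
    for k in range(1, n):
        dtau = tau[k] - tau[k - 1]
        if foh:
            # s varies linearly across the segment, so the trapezoid rule is exact
            t.append(t[-1] + 0.5 * (s[k - 1] + s[k]) * dtau)
        else:
            # s is held at its left-node value across the segment
            t.append(t[-1] + s[k - 1] * dtau)
    return np.array(t)


# node_times([2.0, 4.0, 6.0]) -> [0.0, 1.0, 3.0]; with foh=True -> [0.0, 1.5, 4.0]
```

Under ZOH each segment uses its left-node dilation, matching the else branch of s_to_t; under FOH the average of the two node values is used.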

Parameters:

Name | Type | Description | Default
x | ndarray | State trajectory array, shape (N, n_states). | required
u | ndarray | Control trajectory array, shape (N, n_controls). | required
settings | Config | Configuration settings. | required
discretizer | Discretizer | Discretizer instance (used for dis_type). | required

Returns:

Type | Description
list[float] | List of real time points.

Source code in openscvx/propagation/propagation.py
def s_to_t(x: np.ndarray, u: np.ndarray, settings: Config, discretizer: Discretizer) -> list[float]:
    """Convert the time-dilation control s to real node times t.

    This function integrates dt/dtau = s over the normalized-time grid, using
    the hold type (ZOH or FOH) of the time-dilation control.

    Args:
        x: State trajectory array, shape (N, n_states).
        u: Control trajectory array, shape (N, n_controls).
        settings (Config): Configuration settings.
        discretizer: Discretizer instance (used for ``dis_type``).

    Returns:
        list[float]: List of real time points.
    """
    t = [x[:, settings.sim.time_slice][0]]
    tau = np.linspace(0, 1, settings.sim.n)
    idx_s = _time_dilation_index(settings, u.shape[1])
    u_foh_mask = getattr(settings.sim.u, "foh_mask", None)
    foh_mask = _resolve_foh_mask(discretizer.dis_type, u.shape[1], u_foh_mask)
    td_is_foh = foh_mask[idx_s] > 0.5
    for k in range(1, settings.sim.n):
        s_kp = u[k - 1, idx_s]
        s_k = u[k, idx_s]
        if td_is_foh:
            t.append(t[k - 1] + 0.5 * (s_k + s_kp) * (tau[k] - tau[k - 1]))
        else:
            t.append(t[k - 1] + (tau[k] - tau[k - 1]) * (s_kp))
    return t

simulate_nonlinear_time(params: dict, x: np.ndarray, u: np.ndarray, tau_vals: np.ndarray, t: np.ndarray, settings: Config, propagation_solver: callable, dynamics_discrete: Optional[Callable] = None) -> np.ndarray

Simulate the nonlinear system dynamics over time.

This function simulates the system dynamics using the optimal control sequence and returns the resulting state trajectory.
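Before integrating, simulate_nonlinear_time assigns each dense sample tau to the segment containing it using np.digitize on the node grid, then folds tau = 1 (which digitize places past the last node) into the final segment. A sketch of just that binning step; `bin_into_segments` is a hypothetical helper.

```python
import numpy as np


def bin_into_segments(tau_vals, n_nodes):
    """Return the segment index k with tau_k <= tau < tau_{k+1} for each sample."""
    tau_nodes = np.linspace(0.0, 1.0, n_nodes)
    seg = np.digitize(tau_vals, tau_nodes) - 1
    # tau == 1.0 falls in bin n_nodes - 1, which is not a segment; use the last segment
    return np.where(seg == n_nodes - 1, n_nodes - 2, seg)


seg = bin_into_segments(np.array([0.0, 0.2, 0.5, 0.7, 1.0]), n_nodes=3)
# seg -> [0, 0, 1, 1, 1]
```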

Parameters:

Name | Type | Description | Default
params | dict | System parameters. | required
x | ndarray | State trajectory array, shape (N, n_states). | required
u | ndarray | Control trajectory array, shape (N, n_controls). | required
tau_vals | ndarray | Normalized time points for simulation. | required
t | ndarray | Real time at each node (e.g., from s_to_t). | required
settings | Config | Configuration settings. | required
propagation_solver | callable | Function for propagating the system state. | required
dynamics_discrete | Optional[Callable] | Optional discrete dynamics map f_discrete(x, u, node, params) used to apply impulsive/discrete updates at each node before continuous propagation; it is applied only when at least one control is impulsive. | None

Returns:

Type | Description
ndarray | Simulated state trajectory.

Source code in openscvx/propagation/propagation.py
def simulate_nonlinear_time(
    params: dict,
    x: np.ndarray,
    u: np.ndarray,
    tau_vals: np.ndarray,
    t: np.ndarray,
    settings: Config,
    propagation_solver: callable,
    dynamics_discrete: Optional[Callable] = None,
) -> np.ndarray:
    """Simulate the nonlinear system dynamics over time.

    This function simulates the system dynamics using the optimal control sequence
    and returns the resulting state trajectory.

    Args:
        params: System parameters.
        x: State trajectory array, shape (N, n_states).
        u: Control trajectory array, shape (N, n_controls).
        tau_vals (np.ndarray): Normalized time points for simulation.
        t (np.ndarray): Real time points.
        settings: Configuration settings.
        propagation_solver (callable): Function for propagating the system state.
        dynamics_discrete: Optional discrete dynamics map f_discrete(x, u, node, params)
            used to apply impulsive/discrete updates at each node before continuous propagation.

    Returns:
        np.ndarray: Simulated state trajectory.
    """
    x_0 = settings.sim.x_prop.initial

    n_segments = settings.sim.n - 1
    n_states = x_0.shape[0]
    n_tau = len(tau_vals)

    states = np.empty((n_states, n_tau))
    tau = np.linspace(0, 1, settings.sim.n)

    # Controls at the node times (for strictly increasing t, interpolating u at t returns u itself)
    u_interp = np.stack([np.interp(t, t, u[:, i]) for i in range(u.shape[1])], axis=-1)
    _time_dilation_index(settings, u.shape[1])

    has_u_d = np.any(settings.sim.u._impulsive_mask())

    # Bin tau_vals into segments of tau
    tau_inds = np.digitize(tau_vals, tau) - 1
    tau_inds = np.where(tau_inds == settings.sim.n - 1, settings.sim.n - 2, tau_inds)

    prev_count = 0
    out_idx = 0

    for k in range(n_segments):
        controls_current = u_interp[k][None, :]
        controls_next = u_interp[k + 1][None, :]

        # Mask for tau_vals in current segment
        mask = (tau_inds >= k) & (tau_inds < k + 1)
        count = np.sum(mask)

        tau_cur = tau_vals[prev_count : prev_count + count]
        # Ensure integration reaches the segment endpoint.
        # If tau_cur already contains tau[k+1], avoid duplicating it.
        append_endpoint = tau_cur.size == 0 or not np.isclose(tau_cur[-1], tau[k + 1])
        if append_endpoint:
            tau_cur = np.concatenate([tau_cur, np.array([tau[k + 1]])])
            count += 1

        # Pad to fixed length
        pad_len = settings.prp.max_tau_len - count
        tau_cur_padded = np.pad(tau_cur, (0, pad_len), constant_values=tau[k + 1])
        mask_padded = np.concatenate([np.ones(count), np.zeros(pad_len)]).astype(bool)

        # Map prior node state to posterior using discrete dynamics when available.
        if has_u_d and dynamics_discrete is not None:
            x_post = np.asarray(
                dynamics_discrete(
                    np.asarray(x_0),
                    np.asarray(u[k]),
                    int(k),
                    params,
                )
            ).reshape(-1)
        else:
            x_post = x_0

        # Call the continuous propagation solver with padded tau_cur and mask
        sol = _invoke_solver(
            propagation_solver,
            x_post,
            (tau[k], tau[k + 1]),
            controls_current,
            controls_next,
            np.array([[tau[k]]]),
            np.array([[k]]),
            tau_cur_padded,
            mask_padded,
            params,
        )

        # Store requested samples; exclude endpoint only when it was appended
        # solely for continuity propagation to the next segment.
        n_store = count - 1 if append_endpoint else count
        states[:, out_idx : out_idx + n_store] = sol[:n_store].T
        out_idx += n_store
        x_0 = sol[count - 1]  # Last value used as next x_0

        prev_count += n_store

    return states.T

t_to_tau(u: np.ndarray, t: np.ndarray, t_nodal: np.ndarray, settings: Config, discretizer: Discretizer) -> tuple[np.ndarray, np.ndarray]

Convert real time t to normalized time tau.

This function converts real time t to normalized time tau and interpolates the control inputs according to each control's hold type.
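Within a ZOH segment, t grows linearly in tau with slope s_k, so the inversion tau = tau_k + (t - t_k) / s_k is exact. The FOH branch of t_to_tau divides by the average of the two node dilations; this recovers the node times exactly, but between nodes it appears to be an approximation, because with s linear in tau the map t(tau) is quadratic. A sketch of the inversion (ignoring the control interpolation), using a hypothetical helper `tau_from_t` that takes the dilation values s directly:

```python
import numpy as np


def tau_from_t(t_query, t_nodal, s, foh=False):
    """Map real times back to normalized time tau on a uniform node grid."""
    n = len(t_nodal)
    tau_nodal = np.linspace(0.0, 1.0, n)
    tau = np.zeros(len(t_query))
    for j, tq in enumerate(t_query):
        # Segment start: last node strictly before tq (node 0 when tq <= t_nodal[0])
        k = max(int(np.searchsorted(t_nodal, tq, side="left")) - 1, 0)
        k = min(k, n - 2)
        rate = 0.5 * (s[k] + s[k + 1]) if foh else s[k]  # dt/dtau used for the inversion
        tau[j] = tau_nodal[k] + (tq - t_nodal[k]) / rate
    return tau


# ZOH: node times [0, 1, 3] come from s = [2, 4, 6] on tau nodes [0, 0.5, 1]
tau = tau_from_t([0.0, 0.5, 1.0, 2.0, 3.0], t_nodal=[0.0, 1.0, 3.0], s=[2.0, 4.0, 6.0])
# tau -> [0.0, 0.25, 0.5, 0.75, 1.0]
```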

Parameters:

Name | Type | Description | Default
u | ndarray | Control trajectory array, shape (N, n_controls). | required
t | ndarray | Real times at which tau and the controls are evaluated. | required
t_nodal | ndarray | Real time at each node. | required
settings | Config | Configuration settings. | required
discretizer | Discretizer | Discretizer instance (used for dis_type). | required

Returns:

Type | Description
tuple[ndarray, ndarray] | (tau, u_interp), where tau is the normalized time and u_interp holds the interpolated controls.

Source code in openscvx/propagation/propagation.py
def t_to_tau(
    u: np.ndarray, t: np.ndarray, t_nodal: np.ndarray, settings: Config, discretizer: Discretizer
) -> tuple[np.ndarray, np.ndarray]:
    """Convert real time t to normalized time tau.

    This function converts real time t to normalized time tau and interpolates
    the control inputs according to each control's hold type.

    Args:
        u (np.ndarray): Control trajectory array, shape (N, n_controls).
        t (np.ndarray): Real time points.
        t_nodal (np.ndarray): Nodal time points.
        settings (Config): Configuration settings.
        discretizer: Discretizer instance (used for ``dis_type``).

    Returns:
        tuple[np.ndarray, np.ndarray]: (tau, u_interp) where tau is normalized time and u_interp is
            interpolated controls.
    """
    u_foh_mask = getattr(settings.sim.u, "foh_mask", None)
    foh_mask = _resolve_foh_mask(discretizer.dis_type, u.shape[1], u_foh_mask)
    foh_mask_bool = foh_mask > 0.5

    def u_lam(new_t):
        idx = np.searchsorted(t_nodal, new_t, side="right") - 1
        idx = np.clip(idx, 0, len(t_nodal) - 1)
        zoh_vals = u[idx, :]
        foh_vals = np.array([np.interp(new_t, t_nodal, u[:, i]) for i in range(u.shape[1])])
        return np.where(foh_mask_bool, foh_vals, zoh_vals)

    u_interp = np.array([u_lam(t_i) for t_i in t])

    tau = np.zeros(len(t))
    tau_nodal = np.linspace(0, 1, settings.sim.n)
    idx_s = _time_dilation_index(settings, u.shape[1])
    td_is_foh = foh_mask[idx_s] > 0.5
    for k in range(1, len(t)):
        k_nodal = np.where(t_nodal < t[k])[0][-1]
        s_kp = u[k_nodal, idx_s]
        tp = t_nodal[k_nodal]
        tau_p = tau_nodal[k_nodal]

        s_k = u[k_nodal + 1, idx_s]
        if td_is_foh:
            tau[k] = tau_p + 2 * (t[k] - tp) / (s_k + s_kp)
        else:
            tau[k] = tau_p + (t[k] - tp) / s_kp
    return tau, u_interp