Skip to content

linearize_discretize

LinearizeDiscretize

Bases: Discretizer

Linearize-then-discretize via augmented ODE integration.

Computes continuous-time Jacobians (df/dx, df/du) via JAX forward-mode autodiff, then integrates them alongside the nonlinear dynamics through an augmented state vector using a multi-shooting approach to produce discrete-time matrices.

Supports ZOH (zero-order hold) and FOH (first-order hold) control interpolation between nodes.

This is the default discretization scheme in OpenSCvx.

Parameters:

Name Type Description Default
dis_type Union[str, Sequence[str]]

Control hold type. "FOH" (first-order hold) or "ZOH" (zero-order hold) applies the same hold to every control. A per-control sequence (e.g. ["FOH", "ZOH", "FOH"]) sets the hold independently for each control, merged with any per-control parameterization ("FOH" / "ZOH"). Defaults to "FOH".

'FOH'
ode_solver str

Diffrax solver name. Any solver from Diffrax (https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/) is valid. Defaults to "Tsit5".

'Tsit5'
diffrax_kwargs Optional[dict[str, Any]]

Preferred Diffrax keyword overrides. These map to the keyword arguments of openscvx.integrators.solve_ivp_diffrax, and unknown keys are forwarded to diffrax.diffeqsolve (e.g. stepsize_controller). Set rtol/atol here when using the default PID controller. Defaults to {}.

None
Source code in openscvx/discretization/linearize_discretize.py
class LinearizeDiscretize(Discretizer):
    """Linearize-then-discretize via augmented ODE integration.

    The continuous-time Jacobians ``df/dx`` and ``df/du`` are obtained with
    JAX forward-mode autodiff (``jax.jacfwd``). They are then integrated
    together with the nonlinear dynamics as a single augmented state vector,
    multi-shooting style, to produce the discrete-time matrices.

    Controls may be interpolated between nodes with either ZOH (zero-order
    hold) or FOH (first-order hold).

    This is the default discretization scheme in OpenSCvx.

    Args:
        dis_type: Control hold type. ``"FOH"`` (first-order hold) or
            ``"ZOH"`` (zero-order hold) applies the same hold to every
            control.  A per-control sequence (e.g. ``["FOH", "ZOH", "FOH"]``)
            sets the hold independently for each control, merged with any
            per-control ``parameterization`` (``"FOH"`` / ``"ZOH"``).
            Defaults to ``"FOH"``.
        ode_solver: Diffrax solver name. Any solver from
            `Diffrax <https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/>`_
            is valid. Defaults to ``"Tsit5"``.
        diffrax_kwargs: Preferred Diffrax keyword overrides. These map to
            :func:`openscvx.integrators.solve_ivp_diffrax` kwargs, and unknown
            keys are forwarded to :func:`diffrax.diffeqsolve` (e.g.
            ``stepsize_controller``). Set ``rtol``/``atol`` here when using
            the default PID controller. Defaults to ``{}``.
    """

    def __init__(
        self,
        dis_type: Union[str, Sequence[str]] = "FOH",
        ode_solver: str = "Tsit5",
        diffrax_kwargs: Optional[dict[str, Any]] = None,
    ):
        self.dis_type = dis_type
        self.ode_solver = ode_solver
        # Store a shallow copy so later edits to the caller's dict do not leak in.
        self.diffrax_kwargs = {} if diffrax_kwargs is None else dict(diffrax_kwargs)

    def _resolve_diffrax_kwargs(self) -> dict[str, Any]:
        """Assemble the keyword arguments for :func:`solve_ivp_diffrax`.

        The result starts from ``self.ode_solver`` and the module default
        tolerances. Keys in ``self.diffrax_kwargs`` that ``solve_ivp_diffrax``
        takes directly (``tau_0``, ``num_substeps``, ``solver_name``,
        ``rtol``, ``atol``) replace those entries. Every other key is
        collected into ``extra_kwargs`` and forwarded to
        :func:`diffrax.diffeqsolve`. That collection also includes the
        contents of an optional nested ``"extra_kwargs"`` mapping.

        Precedence: a top-level forwarded key overrides the same key taken
        from the nested mapping.
        """
        resolved: dict[str, Any] = {
            "solver_name": self.ode_solver,
            "rtol": DEFAULT_DIFFRAX_RTOL,
            "atol": DEFAULT_DIFFRAX_ATOL,
        }

        overrides = dict(self.diffrax_kwargs)
        nested = overrides.pop("extra_kwargs", None)
        forwarded: dict[str, Any] = {} if nested is None else dict(nested)

        accepted = frozenset(("tau_0", "num_substeps", "solver_name", "rtol", "atol"))
        for name, value in overrides.items():
            if name not in accepted:
                # Not a solve_ivp_diffrax parameter: pass through to diffeqsolve.
                forwarded[name] = value
                continue
            resolved[name] = value

        resolved["extra_kwargs"] = forwarded
        return resolved

    def _resolve_rk45_kwargs(self, *, is_not_compiled: bool) -> dict[str, Any]:
        """Assemble the keyword arguments for :func:`solve_ivp_rk45`.

        Only ``tau_0``, ``num_substeps`` and ``is_not_compiled`` are picked
        up from ``self.diffrax_kwargs``. If ``is_not_compiled`` is present
        there, it overrides the argument. All other keys are ignored.
        """
        passthrough = {"tau_0", "num_substeps", "is_not_compiled"}
        picked = {k: v for k, v in self.diffrax_kwargs.items() if k in passthrough}
        return {"is_not_compiled": is_not_compiled, **picked}

    def get_solver(self, dynamics: "Dynamics", settings: "Config") -> callable:
        """Create a multi-shoot discretization solver.

        Differentiates ``dynamics.f`` with ``jax.jacfwd`` with respect to
        its first (state) and second (control) arguments. ``f`` and both
        Jacobians are vmapped over the first three arguments, with the
        fourth (params) broadcast. The returned callable integrates the
        augmented variational equations.

        Args:
            dynamics: System dynamics. Only ``dynamics.f`` is used; Jacobians
                are computed internally via JAX autodiff.
            settings: Problem configuration.

        Returns:
            Callable ``(x, u, params) -> (A_d, B_d, C_d, x_prop, V)``.
        """
        # Batch over (x, u, node); params is shared across every node.
        batch_axes = (0, 0, 0, None)
        state_dot = jax.vmap(dynamics.f, in_axes=batch_axes)
        jac_x = jax.vmap(jax.jacfwd(dynamics.f, argnums=0), in_axes=batch_axes)
        jac_u = jax.vmap(jax.jacfwd(dynamics.f, argnums=1), in_axes=batch_axes)

        def solve(x, u, params):
            # `self` is captured so the discretizer's settings travel with the solver.
            return _calculate_discretization(
                x=x,
                u=u,
                state_dot=state_dot,
                A=jac_x,
                B=jac_u,
                settings=settings,
                discretizer=self,
                params=params,
            )

        return solve

    def citation(self) -> List[str]:
        """Return BibTeX citations for the linearize-then-discretize method.

        Returns:
            A one-element list holding the BibTeX entry for Kamath et al. (2023).
        """
        kamath_2023 = r"""@article{kamath2023real,
  title={Real-time sequential conic optimization for multi-phase rocket landing guidance},
  author={Kamath, Abhinav G and Elango, Purnanand and Yu, Yue and Mceowen, Skye and Chari, Govind M
    and Carson III, John M and A{\c{c}}{\i}kme{\c{s}}e, Beh{\c{c}}et},
  journal={IFAC-PapersOnLine},
  volume={56},
  number={2},
  pages={3118--3125},
  year={2023},
  publisher={Elsevier}
}"""
        return [kamath_2023]
citation() -> List[str]

Return BibTeX citations for the linearize-then-discretize discretization method.

Returns:

Type Description
List[str]

List containing the BibTeX entries

Source code in openscvx/discretization/linearize_discretize.py
    def citation(self) -> List[str]:
        """Return BibTeX citations for the linearize-then-discretize discretization method.

        Returns:
            List containing the BibTeX entries
        """
        return [
            r"""@article{kamath2023real,
  title={Real-time sequential conic optimization for multi-phase rocket landing guidance},
  author={Kamath, Abhinav G and Elango, Purnanand and Yu, Yue and Mceowen, Skye and Chari, Govind M
    and Carson III, John M and A{\c{c}}{\i}kme{\c{s}}e, Beh{\c{c}}et},
  journal={IFAC-PapersOnLine},
  volume={56},
  number={2},
  pages={3118--3125},
  year={2023},
  publisher={Elsevier}
}""",
        ]
get_solver(dynamics: Dynamics, settings: Config) -> callable

Create a multi-shoot discretization solver.

Computes Jacobians of dynamics.f via jax.jacfwd, vmaps all functions for batch evaluation across nodes, and returns a callable that integrates the augmented variational equations.

Parameters:

Name Type Description Default
dynamics Dynamics

System dynamics. Only dynamics.f is used; Jacobians are computed internally via JAX autodiff.

required
settings Config

Problem configuration.

required

Returns:

Type Description
callable

Callable (x, u, params) -> (A_d, B_d, C_d, x_prop, V).

Source code in openscvx/discretization/linearize_discretize.py
def get_solver(self, dynamics: "Dynamics", settings: "Config") -> callable:
    """Create a multi-shoot discretization solver.

    Differentiates ``dynamics.f`` with ``jax.jacfwd`` with respect to its
    first (state) and second (control) arguments. ``f`` and both Jacobians
    are vmapped over the first three arguments, with the fourth (params)
    broadcast. The returned callable integrates the augmented variational
    equations.

    Args:
        dynamics: System dynamics. Only ``dynamics.f`` is used; Jacobians
            are computed internally via JAX autodiff.
        settings: Problem configuration.

    Returns:
        Callable ``(x, u, params) -> (A_d, B_d, C_d, x_prop, V)``.
    """
    # Batch over (x, u, node); params is shared across every node.
    batch_axes = (0, 0, 0, None)
    state_dot = jax.vmap(dynamics.f, in_axes=batch_axes)
    jac_x = jax.vmap(jax.jacfwd(dynamics.f, argnums=0), in_axes=batch_axes)
    jac_u = jax.vmap(jax.jacfwd(dynamics.f, argnums=1), in_axes=batch_axes)

    def solve(x, u, params):
        # `self` is captured so the discretizer's settings travel with the solver.
        return _calculate_discretization(
            x=x,
            u=u,
            state_dot=state_dot,
            A=jac_x,
            B=jac_u,
            settings=settings,
            discretizer=self,
            params=params,
        )

    return solve

calculate_impulsive_discretization(x_nodes: np.ndarray, u_nodes: np.ndarray, state_dot_discrete: callable, A_discrete: callable, B_discrete: callable, params: dict) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]

Evaluate discrete/impulsive dynamics and Jacobians at all nodes.

Source code in openscvx/discretization/linearize_discretize.py
def calculate_impulsive_discretization(
    x_nodes: np.ndarray,
    u_nodes: np.ndarray,
    state_dot_discrete: callable,
    A_discrete: callable,
    B_discrete: callable,
    params: dict,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Evaluate discrete/impulsive dynamics and Jacobians at all nodes.

    Each of the three callables is invoked once with
    ``(x_nodes, u_nodes, node_indices, params)``, where ``node_indices`` is
    ``jnp.arange(N)`` and ``N = x_nodes.shape[0]``.

    Args:
        x_nodes: Node states; row ``i`` belongs to node ``i``. Must be at
            least 2-D, since ``shape[1]`` is read as the state dimension.
        u_nodes: Node controls; ``shape[1]`` is read as the control dimension.
        state_dot_discrete: Batched discrete map producing the post-jump states.
        A_discrete: Batched callable whose output must reshape to
            ``(N, n_x * n_x)``. Presumably this is the state Jacobian of shape
            ``(N, n_x, n_x)``; confirm against the caller.
        B_discrete: Batched callable whose output must reshape to
            ``(N, n_x * n_u)``. Presumably this is the control Jacobian of
            shape ``(N, n_x, n_u)``; confirm against the caller.
        params: Passed unchanged to every callable.

    Returns:
        ``(x_prop_plus, D_d, E_d, W)``. The first three are the raw callable
        outputs. ``W`` is an ``(N * (n_x + n_x*n_x + n_x*n_u), 1)`` column:
        for each node in turn it holds that node's ``x_prop_plus`` row, then
        its flattened ``D_d``, then its flattened ``E_d`` (row-major order).
    """
    num_nodes = x_nodes.shape[0]
    node_index = jnp.arange(num_nodes)
    state_dim = x_nodes.shape[1]
    control_dim = u_nodes.shape[1]

    call_args = (x_nodes, u_nodes, node_index, params)
    x_plus = state_dot_discrete(*call_args)
    jac_state = A_discrete(*call_args)
    jac_control = B_discrete(*call_args)

    # One row per node: [x_plus | vec(D) | vec(E)], then flatten to a column.
    per_node_rows = jnp.concatenate(
        [
            x_plus,
            jac_state.reshape(num_nodes, state_dim * state_dim),
            jac_control.reshape(num_nodes, state_dim * control_dim),
        ],
        axis=1,
    )
    return x_plus, jac_state, jac_control, per_node_rows.reshape(-1, 1)

get_impulsive_discretization_solver(dyn_discrete: Dynamics) -> callable

Create a solver for discrete/impulsive dynamics linearization.

Source code in openscvx/discretization/linearize_discretize.py
def get_impulsive_discretization_solver(dyn_discrete: "Dynamics") -> callable:
    """Create a solver for discrete/impulsive dynamics linearization.

    Differentiates ``dyn_discrete.f`` with ``jax.jacfwd`` with respect to
    its first and second arguments. ``f`` and both Jacobians are vmapped
    over the first three arguments, with the fourth (params) broadcast.

    Args:
        dyn_discrete: Discrete dynamics; only ``dyn_discrete.f`` is used.

    Returns:
        Callable ``(x_nodes, u_nodes, params)`` that forwards to
        :func:`calculate_impulsive_discretization` with the batched
        functions and returns its ``(x_prop_plus, D_d, E_d, W)`` tuple.
    """
    # Batch over (x, u, node); params is shared across every node.
    node_axes = (0, 0, 0, None)
    batched_step = jax.vmap(dyn_discrete.f, in_axes=node_axes)
    batched_jac_x = jax.vmap(jax.jacfwd(dyn_discrete.f, argnums=0), in_axes=node_axes)
    batched_jac_u = jax.vmap(jax.jacfwd(dyn_discrete.f, argnums=1), in_axes=node_axes)

    def solve(x_nodes, u_nodes, params):
        return calculate_impulsive_discretization(
            x_nodes=x_nodes,
            u_nodes=u_nodes,
            state_dot_discrete=batched_step,
            A_discrete=batched_jac_x,
            B_discrete=batched_jac_u,
            params=params,
        )

    return solve