
discretization

Discretization methods for trajectory optimization.

This module provides implementations of discretization schemes that convert continuous-time optimal control problems into discrete-time approximations suitable for numerical optimization.

Discretization and linearization are combined into a single interface (Discretizer) because different schemes may linearize then discretize, discretize then linearize, or use other approaches. The ordering changes the intermediate types, but the input (continuous nonlinear dynamics plus a reference trajectory) and the output (discrete-time linear matrices A_d, B_d, C_d) are always the same.
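As background on how these outputs are typically consumed, successive convexification usually replaces the nonlinear dynamics between nodes k and k+1 with an affine model about the reference trajectory (x̄, ū). The block below states the standard first-order-hold form; it is general background, not an excerpt of this module's code:

```latex
x_{k+1} \approx A_{d,k}\,x_k + B_{d,k}\,u_k + C_{d,k}\,u_{k+1} + z_k,
\qquad
z_k = x^{\mathrm{prop}}_k - A_{d,k}\,\bar{x}_k - B_{d,k}\,\bar{u}_k - C_{d,k}\,\bar{u}_{k+1},
\qquad k = 0, \dots, N-2.
```

Here x_k^prop is the reference state x̄_k propagated through the nonlinear dynamics over segment k. For a control held with ZOH, the corresponding columns of C_{d,k} are zero.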

Problem uses LinearizeDiscretizeSparse by default (sparse Jacobians and compact variational integration when sparsity patterns exist). LinearizeDiscretize is the dense linearize-then-discretize scheme.

DiscretizeLinearizeVectorize

Bases: Discretizer

Discretization by differentiating through the integrator, one trajectory segment at a time.

Propagates the nonlinear dynamics for each trajectory segment on its own, then directly computes Jacobians of the propagated solutions (dF/dx, dF/du) via JAX forward-mode autodiff to produce discrete-time Jacobians for each node.
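Differentiating through an integrator can be illustrated in isolation. The sketch below makes three substitutions: a hypothetical double integrator replaces dynamics.f, a hand-written fixed-step RK4 loop replaces Diffrax (or the built-in RK45), and the segment duration T is set explicitly. Taking jax.jacfwd of the propagated state with respect to (x0, u_cur, u_next) then yields A_d, B_d and C_d for that segment:

```python
import jax
import jax.numpy as jnp

T = 0.1  # segment duration (assumed); the class above integrates over 1/(N-1)


def f(x, u):
    # Hypothetical double integrator standing in for dynamics.f: x = [position, velocity]
    return jnp.stack([x[1], u[0]])


def propagate_segment(x0, u_cur, u_next, steps=20):
    """Fixed-step RK4 over one segment with first-order-hold (linear) control."""
    h = T / steps

    def u_at(t):
        s = t / T  # normalized position in the segment: 0 at start, 1 at end
        return u_cur + s * (u_next - u_cur)

    def rk4_step(i, x):
        t = i * h
        k1 = f(x, u_at(t))
        k2 = f(x + 0.5 * h * k1, u_at(t + 0.5 * h))
        k3 = f(x + 0.5 * h * k2, u_at(t + 0.5 * h))
        k4 = f(x + h * k3, u_at(t + h))
        return x + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

    return jax.lax.fori_loop(0, steps, rk4_step, x0)


# Forward-mode derivatives of the propagated state w.r.t. (x0, u_cur, u_next)
# are the discrete-time matrices for this segment.
discrete_jacobians = jax.jacfwd(propagate_segment, argnums=(0, 1, 2))

x0 = jnp.array([0.0, 1.0])
u_cur = jnp.array([0.5])
u_next = jnp.array([-0.5])
A_d, B_d, C_d = discrete_jacobians(x0, u_cur, u_next)
print(A_d.shape, B_d.shape, C_d.shape)  # (2, 2) (2, 1) (2, 1)
```

For this linear system, RK4 reproduces the trajectory exactly. The result therefore equals the closed form A_d = [[1, T], [0, 1]], B_d = [T²/3, T/2]ᵀ, C_d = [T²/6, T/2]ᵀ.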

Supports ZOH (zero-order hold) and FOH (first-order hold) control interpolation between nodes, including independent per-control selection.
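Per-control hold selection amounts to a 0/1 mask on the interpolation weight. In single_dxdt, beta = tau * N * foh_mask multiplies (u_next - u_cur). The sketch below uses a normalized segment position s in [0, 1]. Its resolve_hold_mask is a hypothetical helper and does not reproduce the library's _resolve_foh_mask, including that function's merging with Control.parameterization:

```python
import numpy as np


def resolve_hold_mask(dis_type, n_u):
    """Hypothetical helper: turn "FOH"/"ZOH" (or a per-control list) into a 0/1 mask."""
    holds = [dis_type] * n_u if isinstance(dis_type, str) else list(dis_type)
    if len(holds) != n_u:
        raise ValueError(f"expected {n_u} hold types, got {len(holds)}")
    if any(h not in ("FOH", "ZOH") for h in holds):
        raise ValueError(f"hold types must be 'FOH' or 'ZOH', got {holds}")
    return np.array([1.0 if h == "FOH" else 0.0 for h in holds])


def interpolate_control(s, u_cur, u_next, foh_mask):
    """Control at normalized segment position s in [0, 1]."""
    beta = s * foh_mask  # FOH entries ramp from 0 to 1; ZOH entries stay at 0
    return u_cur + beta * (u_next - u_cur)


mask = resolve_hold_mask(["FOH", "ZOH", "FOH"], n_u=3)  # 1.0, 0.0, 1.0
u_cur = np.array([0.0, 1.0, 2.0])
u_next = np.array([1.0, 3.0, 0.0])
u_mid = interpolate_control(0.5, u_cur, u_next, mask)  # evaluates to 0.5, 1.0, 1.0
```

The ZOH control stays at its current-node value for the whole segment, so the next-node value u_next has no influence on it.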

Use this scheme when the nonlinear dynamics are challenging (e.g. stiff or highly sensitive, badly scaled, or integrated over long time horizons) and require very tight tolerances. A prototypical example is atmospheric entry of a spacecraft.

Parameters:

  • dis_type (Union[str, Sequence[str]], default "FOH"): Control hold type. "FOH" (first-order hold) or "ZOH" (zero-order hold) applies the same hold to every control. A per-control sequence (e.g. ["FOH", "ZOH", "FOH"]) sets the hold independently for each control and is merged with any per-control parameterization ("FOH" / "ZOH").
  • ode_solver (str, default "Tsit5"): Diffrax solver name. Any solver from Diffrax (https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/) is valid.
  • custom_integrator (bool, default False): Use the built-in fixed-step RK45 integrator instead of Diffrax. Faster but less robust.
  • diffrax_kwargs (Optional[dict[str, Any]], default None, treated as {}): Preferred Diffrax keyword overrides. These map to openscvx.integrators.solve_ivp_diffrax kwargs, and unknown keys are forwarded to diffrax.diffeqsolve (e.g. stepsize_controller). Because each segment is integrated separately, rtol and atol (default or user-supplied) are divided by the number of segments, N - 1.
  • args (Optional[dict], default None): Deprecated alias for diffrax_kwargs, kept for backward compatibility. If both are given, keys from args override those from diffrax_kwargs.
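The routing performed by _resolve_diffrax_kwargs can be mirrored in plain Python. In this sketch the tolerance defaults are placeholders, since the real DEFAULT_DIFFRAX_RTOL / DEFAULT_DIFFRAX_ATOL values are not shown on this page. A string also stands in for dfx.ForwardMode() so the code runs without Diffrax. Keys listed in DIRECT_KEYS go to the solver wrapper, and everything else ends up in extra_kwargs:

```python
from typing import Any

# Placeholder tolerances: the real DEFAULT_DIFFRAX_RTOL / DEFAULT_DIFFRAX_ATOL are not shown here.
DEFAULT_RTOL = 1e-6
DEFAULT_ATOL = 1e-8
DIRECT_KEYS = {"tau_0", "num_substeps", "solver_name", "rtol", "atol"}


def route_diffrax_kwargs(user: dict[str, Any], solver: str, n_segments: int) -> dict[str, Any]:
    """Split user overrides into direct solver kwargs and pass-through extras."""
    user = dict(user)
    kwargs: dict[str, Any] = {"solver_name": solver, "rtol": DEFAULT_RTOL, "atol": DEFAULT_ATOL}
    # A string stands in for dfx.ForwardMode() so the sketch runs without Diffrax.
    extra: dict[str, Any] = {"adjoint": "ForwardMode"}
    extra.update(user.pop("extra_kwargs", None) or {})
    for key, value in user.items():
        target = kwargs if key in DIRECT_KEYS else extra
        target[key] = value
    # Each segment is integrated on its own, so the error budget is split across segments.
    kwargs["rtol"] = kwargs["rtol"] / n_segments
    kwargs["atol"] = kwargs["atol"] / n_segments
    kwargs["extra_kwargs"] = extra
    return kwargs


resolved = route_diffrax_kwargs(
    {"rtol": 1e-9, "num_substeps": 50, "stepsize_controller": "pid-placeholder"},
    solver="Tsit5",
    n_segments=10,
)
print(sorted(resolved))  # ['atol', 'extra_kwargs', 'num_substeps', 'rtol', 'solver_name']
```

LinearizeDiscretize._resolve_diffrax_kwargs, shown further down this page, splits keys the same way. It does not divide the tolerances by the number of segments, and it does not set a forward-mode adjoint.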
Source code in openscvx/discretization/discretize_linearize.py
class DiscretizeLinearizeVectorize(Discretizer):
    """Discretization via differentiating through the integrator for each segment individually.

    Propagates the nonlinear dynamics for each trajectory segment on its own, then directly computes
    Jacobians of the propagated solutions (dF/dx, dF/du) via JAX forward-mode autodiff to produce
    discrete-time Jacobians for each node.

    Supports ZOH (zero-order hold) and FOH (first-order hold) control interpolation between nodes,
    including independent per-control selection.

    Use this integration scheme when the nonlinear dynamics are challenging (e.g. stiff/sensitive,
    badly scaled, or over long time horizons) and require very tight tolerances. A prototypical
    example is atmospheric entry of a spacecraft.

    Args:
        dis_type: Control hold type. ``"FOH"`` (first-order hold) or
            ``"ZOH"`` (zero-order hold) applies the same hold to every
            control.  A per-control sequence (e.g. ``["FOH", "ZOH", "FOH"]``)
            sets the hold independently for each control, merged with any
            per-control ``parameterization`` (``"FOH"`` / ``"ZOH"``).
            Defaults to ``"FOH"``.
        ode_solver: Diffrax solver name. Any solver from
            `Diffrax <https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/>`_
            is valid. Defaults to ``"Tsit5"``.
        custom_integrator: Use the built-in fixed-step RK45 integrator instead of Diffrax.
            Faster but less robust. Defaults to ``False``.
        diffrax_kwargs: Preferred Diffrax keyword overrides. These map to
            :func:`openscvx.integrators.solve_ivp_diffrax` kwargs, and unknown
            keys are forwarded to :func:`diffrax.diffeqsolve` (e.g.
            ``stepsize_controller``). Defaults to ``{}``.
        args: Deprecated alias for ``diffrax_kwargs`` kept for backward compatibility.
    """

    def __init__(
        self,
        dis_type: Union[str, Sequence[str]] = "FOH",
        ode_solver: str = "Tsit5",
        custom_integrator: bool = False,
        diffrax_kwargs: Optional[dict[str, Any]] = None,
        args: Optional[dict] = None,
    ):
        self.dis_type = dis_type
        self.ode_solver = ode_solver
        self.custom_integrator = custom_integrator
        merged_kwargs: dict[str, Any] = {}
        if diffrax_kwargs is not None:
            merged_kwargs.update(dict(diffrax_kwargs))
        if args is not None:
            merged_kwargs.update(dict(args))
        self.diffrax_kwargs = merged_kwargs

    def _resolve_diffrax_kwargs(self, *, n_segments: int) -> dict[str, Any]:
        kwargs: dict[str, Any] = {
            "solver_name": self.ode_solver,
            # Error budget is distributed since one segment is integrated at a time.
            "rtol": DEFAULT_DIFFRAX_RTOL / n_segments,
            "atol": DEFAULT_DIFFRAX_ATOL / n_segments,
        }
        user_kwargs = dict(self.diffrax_kwargs)
        extra_kwargs: dict[str, Any] = {"adjoint": dfx.ForwardMode()}

        nested_extra = user_kwargs.pop("extra_kwargs", None)
        if nested_extra is not None:
            extra_kwargs.update(dict(nested_extra))

        direct_keys = {"tau_0", "num_substeps", "solver_name", "rtol", "atol"}
        for key, value in user_kwargs.items():
            if key in direct_keys:
                kwargs[key] = value
            else:
                extra_kwargs[key] = value

        if "rtol" in user_kwargs:
            kwargs["rtol"] = user_kwargs["rtol"] / n_segments
        if "atol" in user_kwargs:
            kwargs["atol"] = user_kwargs["atol"] / n_segments

        kwargs["extra_kwargs"] = extra_kwargs
        return kwargs

    def _resolve_rk45_kwargs(self, *, is_not_compiled: bool) -> dict[str, Any]:
        kwargs: dict[str, Any] = {"is_not_compiled": is_not_compiled}
        direct_keys = {"tau_0", "num_substeps", "is_not_compiled"}
        for key, value in self.diffrax_kwargs.items():
            if key in direct_keys:
                kwargs[key] = value
        return kwargs

    def get_solver(self, dynamics: "Dynamics", settings: "Config") -> callable:
        """Create a multiple-shooting discretize-then-linearize-then-vectorize solver.

        Integrates ``dynamics.f`` and computes Jacobians of the discretized function directly
        (i.e. without using the variational equations). Outputs are vmapped for batch evaluation
        across nodes.

        Args:
            dynamics: System dynamics. Only ``dynamics.f`` is used; Jacobians are computed
                internally via JAX autodiff.
            settings: Problem configuration.

        Returns:
            Callable ``(x, u, params) -> (A_d, B_d, C_d, x_prop, V)``.
        """

        N = settings.sim.n
        n_x = settings.sim.n_states
        n_u = settings.sim.n_controls
        u_foh_mask = getattr(settings.sim.u, "foh_mask", None)
        foh_mask = _resolve_foh_mask(self.dis_type, n_u, u_foh_mask)

        single_state_dot = dynamics.f

        def single_dxdt(
            tau: float,
            x: jnp.ndarray,
            u_cur: np.ndarray,
            u_next: np.ndarray,
            node: int,
            params: dict,
        ) -> jnp.ndarray:

            beta = tau * N * foh_mask

            u = u_cur + beta * (u_next - u_cur)
            F = single_state_dot(x, u, node, params)

            return F

        def single_shot(
            x: jnp.ndarray,
            u_cur: np.ndarray,
            u_next: np.ndarray,
            node: int,
            params: dict,
        ) -> jnp.ndarray:

            if self.custom_integrator:
                rk45_kwargs = self._resolve_rk45_kwargs(is_not_compiled=settings.dev.debug)
                sol = solve_ivp_rk45(
                    single_dxdt,
                    1.0 / (N - 1),
                    x,
                    args=(u_cur, u_next, node, params),
                    **rk45_kwargs,
                )
            else:
                diffrax_kwargs = self._resolve_diffrax_kwargs(n_segments=N - 1)
                sol = solve_ivp_diffrax(
                    single_dxdt,
                    1.0 / (N - 1),
                    x,
                    args=(u_cur, u_next, node, params),
                    **diffrax_kwargs,
                )
            return sol

        discretize_then_vectorize = jax.vmap(single_shot, in_axes=(0, 0, 0, 0, None), out_axes=1)
        discretize_then_linearize = jax.jacfwd(single_shot, argnums=(0, 1, 2))
        discretize_then_linearize_then_vectorize = jax.vmap(
            discretize_then_linearize, in_axes=(0, 0, 0, 0, None), out_axes=1
        )

        nodes = jnp.arange(0, N - 1)

        def solver(x, u, params):
            A_d, B_d, C_d = discretize_then_linearize_then_vectorize(
                x[:-1], u[:-1], u[1:], nodes, params
            )
            x_prop = discretize_then_vectorize(x[:-1], u[:-1], u[1:], nodes, params)

            # TODO: providing the histories of A, B, and C can lead to as much as a 20% slowdown.
            # If they aren't getting used, they shouldn't be here. V_multi should be replaced with
            # an output directly corresponding to the history of x_prop.
            V_multi = jnp.concatenate(
                [
                    x_prop,
                    A_d.reshape(-1, N - 1, n_x * n_x),
                    B_d.reshape(-1, N - 1, n_x * n_u),
                    C_d.reshape(-1, N - 1, n_x * n_u),
                ],
                axis=2,
            )
            i4 = V_multi.shape[2]
            V_multi = V_multi.reshape(-1, (N - 1) * i4).T

            return A_d[-1], B_d[-1], C_d[-1], x_prop[-1], V_multi

        return solver

    def citation(self) -> List[str]:
        """Return BibTeX citations for the discretize-then-linearize-then-vectorize method.

        Returns:
            List containing the BibTeX entries
        """
        return [
            r"""@phdthesis{kidger2021on,
  title={{O}n {N}eural {D}ifferential {E}quations},
  author={Patrick Kidger},
  year={2021},
  school={University of Oxford},
}""",
        ]
citation() -> List[str]

Return BibTeX citations for the discretize-then-linearize-then-vectorize method.

Returns:

  • List[str]: List containing the BibTeX entries.

get_solver(dynamics: Dynamics, settings: Config) -> callable

Create a multiple-shooting discretize-then-linearize-then-vectorize solver.

Integrates dynamics.f and computes Jacobians of the discretized function directly (i.e. without using the variational equations). Outputs are vmapped for batch evaluation across nodes.
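The vmap/jacfwd composition used in get_solver can be reproduced on a toy problem. The sketch rests on three assumptions:

  • the per-segment integrator returns a history of n_save time samples, as the [-1] indexing in solver implies;
  • the dynamics are a hypothetical damped double integrator;
  • explicit Euler replaces Diffrax.

With out_axes=1 the node axis comes second, so [-1] selects the end-of-segment values for every node:

```python
import jax
import jax.numpy as jnp

N, n_x, n_u, n_save = 5, 2, 1, 3  # nodes, states, controls, saved samples per segment (assumed)


def f(x, u, node, params):
    # Hypothetical damped double integrator with the f(x, u, node, params) signature
    return jnp.stack([x[1], u[0] - params["c"] * x[1]])


def single_shot(x, u_cur, u_next, node, params):
    """Explicit-Euler history of one segment of length 1/(N-1); shape (n_save, n_x)."""
    h = (1.0 / (N - 1)) / (n_save - 1)

    def step(x, s):
        u = u_cur + s * (u_next - u_cur)  # first-order hold on every control
        x_new = x + h * f(x, u, node, params)
        return x_new, x_new

    s_grid = jnp.linspace(0.0, 1.0, n_save)[:-1]  # segment positions where steps start
    _, xs = jax.lax.scan(step, x, s_grid)
    return jnp.concatenate([x[None, :], xs], axis=0)


batch = dict(in_axes=(0, 0, 0, 0, None), out_axes=1)  # node axis goes second in the output
discretize_then_vectorize = jax.vmap(single_shot, **batch)
discretize_then_linearize_then_vectorize = jax.vmap(
    jax.jacfwd(single_shot, argnums=(0, 1, 2)), **batch
)

x = jnp.ones((N, n_x))
u = jnp.zeros((N, n_u))
nodes = jnp.arange(N - 1)
params = {"c": 0.1}
A_hist, B_hist, C_hist = discretize_then_linearize_then_vectorize(
    x[:-1], u[:-1], u[1:], nodes, params
)
x_hist = discretize_then_vectorize(x[:-1], u[:-1], u[1:], nodes, params)
print(A_hist.shape, x_hist.shape)  # (3, 4, 2, 2) (3, 4, 2)

# Index -1 on the leading (time) axis picks the end-of-segment values for every node.
A_d, B_d, C_d, x_prop = A_hist[-1], B_hist[-1], C_hist[-1], x_hist[-1]
```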

Parameters:

  • dynamics (Dynamics, required): System dynamics. Only dynamics.f is used; Jacobians are computed internally via JAX autodiff.
  • settings (Config, required): Problem configuration.

Returns:

  • callable: Callable (x, u, params) -> (A_d, B_d, C_d, x_prop, V).
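V is implementation-specific. For this class its layout follows from the reshapes in solver, and the TODO in the source suggests it may change. The numpy sketch below reproduces the packing with random stand-in arrays, assuming T saved time samples on the leading axis:

```python
import numpy as np


def pack_history(x_hist, A_hist, B_hist, C_hist):
    """Pack (T, N-1, ...) histories the way solver builds V_multi."""
    T, n_seg, n_x = x_hist.shape
    n_u = B_hist.shape[-1]
    V = np.concatenate(
        [
            x_hist,
            A_hist.reshape(T, n_seg, n_x * n_x),
            B_hist.reshape(T, n_seg, n_x * n_u),
            C_hist.reshape(T, n_seg, n_x * n_u),
        ],
        axis=2,
    )
    width = V.shape[2]  # n_x + n_x**2 + 2 * n_x * n_u entries per node
    return V.reshape(T, n_seg * width).T  # rows: one block per node; columns: time samples


rng = np.random.default_rng(0)
T, n_seg, n_x, n_u = 3, 4, 2, 1
x_hist = rng.normal(size=(T, n_seg, n_x))
A_hist = rng.normal(size=(T, n_seg, n_x, n_x))
B_hist = rng.normal(size=(T, n_seg, n_x, n_u))
C_hist = rng.normal(size=(T, n_seg, n_x, n_u))
V = pack_history(x_hist, A_hist, B_hist, C_hist)
print(V.shape)  # (40, 3)

# Recover the final propagated states: last column, one row block per node, first n_x entries.
width = n_x + n_x**2 + 2 * n_x * n_u
x_prop = V[:, -1].reshape(n_seg, width)[:, :n_x]
```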


Discretizer

Bases: ABC

Abstract base class for dynamics linearization and discretization.

This class defines the interface for converting continuous-time nonlinear dynamics into discrete-time linear approximations suitable for convex subproblems in successive convexification.

The lifecycle mirrors other OpenSCvx ABCs:

Setup (called once):

  • get_solver: Build a callable that computes discrete-time matrices

Per-iteration (via the returned callable):

  • The callable is invoked with a reference trajectory and parameters, returning the discretized matrices (A_d, B_d, C_d), the propagated state x_prop, and implementation-specific integration data V

Discretization parameters (hold type, integrator, tolerances) live on each concrete subclass as instance attributes.

Subclasses must implement the get_solver and citation methods.

Example

Implementing a custom discretizer:

class EulerDiscretizer(Discretizer):
    def get_solver(self, dynamics, settings):
        def solver(x, u, params):
            # Euler discretization of dynamics
            ...
            return A_d, B_d, C_d, x_prop, V
        return solver

    def citation(self):
        return []
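The Euler skeleton can be made runnable as a stand-alone sketch. So that it runs without OpenSCvx installed, the class below does not subclass Discretizer, and SimpleNamespace objects stand in for Dynamics and Config. Only dynamics.f and settings.sim.n are read, matching how the discretizers in this module use them. Forward Euler with zero-order hold is a scheme swapped in for brevity, not one of the library's discretizers. The sketch also returns None for V, which the real framework may not accept:

```python
from types import SimpleNamespace

import jax
import jax.numpy as jnp


class EulerDiscretizer:
    """Stand-alone sketch of the get_solver contract (a real one would subclass Discretizer)."""

    def __init__(self):
        self.dis_type = "ZOH"      # an Euler step holds the control constant over a segment
        self.ode_solver = "Euler"  # informational only

    def get_solver(self, dynamics, settings):
        N = settings.sim.n
        dt = 1.0 / (N - 1)  # normalized segment length, as in the classes on this page
        in_axes = (0, 0, 0, None)  # batch x, u and node; share params
        f = jax.vmap(dynamics.f, in_axes=in_axes)
        A = jax.vmap(jax.jacfwd(dynamics.f, argnums=0), in_axes=in_axes)
        B = jax.vmap(jax.jacfwd(dynamics.f, argnums=1), in_axes=in_axes)
        nodes = jnp.arange(N - 1)

        def solver(x, u, params):
            xk, uk = x[:-1], u[:-1]
            n_x = x.shape[1]
            A_d = jnp.eye(n_x) + dt * A(xk, uk, nodes, params)
            B_d = dt * B(xk, uk, nodes, params)
            C_d = jnp.zeros_like(B_d)  # the next-node control does not enter a ZOH Euler step
            x_prop = xk + dt * f(xk, uk, nodes, params)
            return A_d, B_d, C_d, x_prop, None  # no extra integration data (V)

        return solver

    def citation(self):
        return []


dynamics = SimpleNamespace(f=lambda x, u, node, params: jnp.stack([x[1], u[0]]))
settings = SimpleNamespace(sim=SimpleNamespace(n=5))  # hypothetical minimal stand-in for Config
solver = EulerDiscretizer().get_solver(dynamics, settings)
A_d, B_d, C_d, x_prop, V = solver(jnp.zeros((5, 2)), jnp.ones((5, 1)), {})
print(A_d.shape, B_d.shape, x_prop.shape)  # (4, 2, 2) (4, 2, 1) (4, 2)
```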
Source code in openscvx/discretization/base.py
class Discretizer(ABC):
    """Abstract base class for dynamics linearization and discretization.

    This class defines the interface for converting continuous-time nonlinear
    dynamics into discrete-time linear approximations suitable for convex
    subproblems in successive convexification.

    The lifecycle mirrors other OpenSCvx ABCs:

    **Setup (called once):**

    - get_solver: Build a callable that computes discrete-time matrices

    **Per-iteration (via the returned callable):**

    - The callable is invoked with a reference trajectory and parameters,
      returning discretized matrices (A_d, B_d, C_d, x_prop)

    Discretization parameters (hold type, integrator, tolerances) live on each
    concrete subclass as instance attributes.

    Subclasses must implement the ``get_solver`` and ``citation`` methods.

    Example:
        Implementing a custom discretizer::

            class EulerDiscretizer(Discretizer):
                def get_solver(self, dynamics, settings):
                    def solver(x, u, params):
                        # Euler discretization of dynamics
                        ...
                        return A_d, B_d, C_d, x_prop, V
                    return solver

                def citation(self):
                    return []
    """

    #: Control hold type. A single ``"FOH"`` or ``"ZOH"`` string applies the
    #: same hold to every control.  A sequence (e.g.
    #: ``["FOH", "ZOH", "FOH"]``) sets the hold independently for each
    #: control, and is merged with any per-control ``Control.parameterization``
    #: (``"FOH"`` / ``"ZOH"``).
    #: Subclasses must set this in ``__init__``.
    dis_type: DisType

    #: ODE solver name used for integration (e.g., ``"Tsit5"``).  Subclasses
    #: must set this in ``__init__``.
    ode_solver: str

    @abstractmethod
    def get_solver(self, dynamics: "Dynamics", settings: "Config") -> callable:
        """Create a discretization solver callable.

        Called once during problem initialization. Returns a function that
        computes linearized discrete-time dynamics matrices around a reference
        trajectory. The returned callable will be JIT-compiled and cached by
        the framework.

        Implementations are responsible for computing any Jacobians they need.
        The ``dynamics`` object always provides ``dynamics.f`` (the continuous-
        time nonlinear dynamics). Implementations that linearize first may
        compute Jacobians via ``jax.jacfwd(dynamics.f, ...)``.

        Args:
            dynamics: System dynamics object. ``dynamics.f`` is the continuous-
                time nonlinear dynamics function with signature
                ``f(x, u, node, params) -> x_dot``.
            settings: Problem configuration (node count, scaling matrices, etc.).

        Returns:
            Callable with signature
            ``(x: ndarray, u: ndarray, params: dict) -> (A_d, B_d, C_d, x_prop, V)``
            where:

            - ``A_d``: (N-1, n_x, n_x) discretized state transition matrix
            - ``B_d``: (N-1, n_x, n_u) control influence matrix (current node)
            - ``C_d``: (N-1, n_x, n_u) control influence matrix (next node)
            - ``x_prop``: (N-1, n_x) propagated state
            - ``V``: raw integration data (implementation-specific, used for
                diagnostics and history tracking)
        """
        raise NotImplementedError

    @abstractmethod
    def citation(self) -> List[str]:
        """Return BibTeX citations for this discretization method.

        Implementations should return a list of BibTeX entry strings for the
        papers that should be cited when using this discretization scheme.

        Returns:
            List of BibTeX citation strings.
        """
        raise NotImplementedError
citation() -> List[str] abstractmethod

Return BibTeX citations for this discretization method.

Implementations should return a list of BibTeX entry strings for the papers that should be cited when using this discretization scheme.

Returns:

  • List[str]: List of BibTeX citation strings.

get_solver(dynamics: Dynamics, settings: Config) -> callable abstractmethod

Create a discretization solver callable.

Called once during problem initialization. Returns a function that computes linearized discrete-time dynamics matrices around a reference trajectory. The returned callable will be JIT-compiled and cached by the framework.

Implementations are responsible for computing any Jacobians they need. The dynamics object always provides dynamics.f (the continuous-time nonlinear dynamics). Implementations that linearize first may compute Jacobians via jax.jacfwd(dynamics.f, ...).

Parameters:

  • dynamics (Dynamics, required): System dynamics object. dynamics.f is the continuous-time nonlinear dynamics function with signature f(x, u, node, params) -> x_dot.
  • settings (Config, required): Problem configuration (node count, scaling matrices, etc.).

Returns:

  • callable: Callable with signature (x: ndarray, u: ndarray, params: dict) -> (A_d, B_d, C_d, x_prop, V), where:
      • A_d: (N-1, n_x, n_x) discretized state transition matrix
      • B_d: (N-1, n_x, n_u) control influence matrix (current node)
      • C_d: (N-1, n_x, n_u) control influence matrix (next node)
      • x_prop: (N-1, n_x) propagated state
      • V: raw integration data (implementation-specific, used for diagnostics and history tracking)
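A custom solver can be checked against the documented shapes with a small helper. check_discretizer_outputs is hypothetical and not part of OpenSCvx. It uses numpy only, and it skips V because V's contents are implementation-specific:

```python
import numpy as np


def check_discretizer_outputs(outputs, N, n_x, n_u):
    """Hypothetical helper: raise if the solver outputs do not have the documented shapes."""
    A_d, B_d, C_d, x_prop, _V = outputs  # V is implementation-specific, so it is not checked
    expected = {
        "A_d": (A_d, (N - 1, n_x, n_x)),
        "B_d": (B_d, (N - 1, n_x, n_u)),
        "C_d": (C_d, (N - 1, n_x, n_u)),
        "x_prop": (x_prop, (N - 1, n_x)),
    }
    for name, (array, shape) in expected.items():
        if np.shape(array) != shape:
            raise ValueError(f"{name} has shape {np.shape(array)}, expected {shape}")


N, n_x, n_u = 6, 3, 2
fake = (
    np.zeros((N - 1, n_x, n_x)),
    np.zeros((N - 1, n_x, n_u)),
    np.zeros((N - 1, n_x, n_u)),
    np.zeros((N - 1, n_x)),
    None,
)
check_discretizer_outputs(fake, N, n_x, n_u)  # passes silently
```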

DiscretizerSpec

Bases: BaseModel

Validates discretizer configuration from dict/YAML input.

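A single spec covers all discretizer types. The type field selects the concrete class; custom_integrator and args are only used by the two vectorized variants. build() forwards only the fields that were explicitly set (model_dump(..., exclude_unset=True)), so leaving them unset is harmless for the other types. Setting them for a class whose constructor does not accept them (such as LinearizeDiscretize, whose constructor takes only dis_type, ode_solver and diffrax_kwargs) raises a TypeError rather than being ignored.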
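The forwarding behaviour of build() can be imitated with stand-in classes. In the sketch below, the stubs and build_from_config are hypothetical and only copy the constructor parameters shown on this page. The real spec uses pydantic, which raises a ValidationError for unknown keys; this sketch raises a plain ValueError instead:

```python
from typing import Any, Optional


class LinearizeDiscretizeStub:
    """Hypothetical stand-in with the same constructor parameters as LinearizeDiscretize."""

    def __init__(self, dis_type="FOH", ode_solver="Tsit5", diffrax_kwargs: Optional[dict] = None):
        self.dis_type = dis_type
        self.ode_solver = ode_solver
        self.diffrax_kwargs = dict(diffrax_kwargs or {})


class DiscretizeLinearizeVectorizeStub(LinearizeDiscretizeStub):
    """Hypothetical stand-in that also accepts the vectorized-only options."""

    def __init__(self, dis_type="FOH", ode_solver="Tsit5", custom_integrator=False,
                 diffrax_kwargs: Optional[dict] = None, args: Optional[dict] = None):
        # Same merge order as DiscretizeLinearizeVectorize.__init__: args is applied last.
        super().__init__(dis_type, ode_solver, {**(diffrax_kwargs or {}), **(args or {})})
        self.custom_integrator = custom_integrator


STUB_MAP = {
    "LinearizeDiscretize": LinearizeDiscretizeStub,
    "DiscretizeLinearizeVectorize": DiscretizeLinearizeVectorizeStub,
}
SPEC_FIELDS = {"type", "dis_type", "ode_solver", "diffrax_kwargs", "custom_integrator", "args"}


def build_from_config(config: dict[str, Any]):
    """Mimic DiscretizerSpec.build: reject unknown keys, forward only the keys that were set."""
    unknown = set(config) - SPEC_FIELDS
    if unknown:
        raise ValueError(f"unknown discretizer fields: {sorted(unknown)}")
    fields = dict(config)
    cls = STUB_MAP[fields.pop("type")]  # this sketch requires an explicit type
    return cls(**fields)  # like model_dump(exclude={"type"}, exclude_unset=True)


dlv = build_from_config({"type": "DiscretizeLinearizeVectorize", "custom_integrator": True})
ld = build_from_config({"type": "LinearizeDiscretize", "dis_type": "ZOH"})  # unset fields never passed
# build_from_config({"type": "LinearizeDiscretize", "custom_integrator": True}) raises TypeError
```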

Source code in openscvx/discretization/base.py
class DiscretizerSpec(BaseModel):
    """Validates discretizer configuration from dict/YAML input.

    A single spec covers all discretizer types.  The ``type`` field selects
    the concrete class; ``custom_integrator`` and ``args`` are only used by
    the two vectorized variants.  ``build`` forwards only explicitly set
    fields, so setting them for a class whose constructor does not accept
    them raises ``TypeError``.
    """

    type: Literal[
        "VectorizeDiscretizeLinearize",
        "DiscretizeLinearizeVectorize",
        "LinearizeDiscretize",
        "LinearizeDiscretizeSparse",
    ] = "VectorizeDiscretizeLinearize"
    dis_type: Union[str, List[str]] = "FOH"
    ode_solver: str = "Tsit5"
    diffrax_kwargs: Optional[Dict[str, Any]] = None
    custom_integrator: bool = False
    args: Optional[Dict[str, Any]] = None

    model_config = ConfigDict(extra="forbid")

    def build(self) -> Discretizer:
        cls = _DISCRETIZER_MAP.get(self.type)
        if cls is None:
            raise ValueError(
                f"Unknown discretizer {self.type!r}; expected one of {sorted(_DISCRETIZER_MAP)}"
            )
        return cls(**self.model_dump(exclude={"type"}, exclude_unset=True))

LinearizeDiscretize

Bases: Discretizer

Linearize-then-discretize via augmented ODE integration.

Computes continuous-time Jacobians (df/dx, df/du) via JAX forward-mode autodiff, then integrates them alongside the nonlinear dynamics through an augmented state vector using a multi-shooting approach to produce discrete-time matrices.
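Linearize-then-discretize can be sketched on a toy problem. Jacobians of f are formed first, and then the state transition matrix and the control sensitivities are integrated alongside x. The sketch makes these assumptions:

  • the dynamics are a hypothetical double integrator;
  • a fixed-step RK4 loop replaces Diffrax;
  • the variational equations take the direct sensitivity form Φ' = AΦ, Ψ_B' = AΨ_B + B(1 - s), Ψ_C' = AΨ_C + B s, with Φ(0) = I and Ψ(0) = 0.

The augmented state built by _calculate_discretization is not shown here and may be organized differently:

```python
import jax
import jax.numpy as jnp

T = 0.1  # segment duration (assumed)


def f(x, u):
    # Hypothetical double integrator standing in for dynamics.f
    return jnp.stack([x[1], u[0]])


A_fn = jax.jacfwd(f, argnums=0)  # continuous-time df/dx
B_fn = jax.jacfwd(f, argnums=1)  # continuous-time df/du


def augmented_rhs(t, z, u_cur, u_next, n_x, n_u):
    """Time derivative of the augmented state [x, Phi, Psi_B, Psi_C] (FOH controls)."""
    cuts = [n_x, n_x + n_x * n_x, n_x + n_x * n_x + n_x * n_u]
    x, Phi, Psi_B, Psi_C = jnp.split(z, cuts)
    Phi = Phi.reshape(n_x, n_x)
    Psi_B, Psi_C = Psi_B.reshape(n_x, n_u), Psi_C.reshape(n_x, n_u)
    s = t / T  # normalized position in the segment
    u = u_cur + s * (u_next - u_cur)
    A, B = A_fn(x, u), B_fn(x, u)
    return jnp.concatenate([
        f(x, u),
        (A @ Phi).ravel(),                    # state transition matrix
        (A @ Psi_B + B * (1.0 - s)).ravel(),  # sensitivity to u_cur
        (A @ Psi_C + B * s).ravel(),          # sensitivity to u_next
    ])


def linearize_then_discretize(x0, u_cur, u_next, steps=20):
    """Integrate the augmented ODE with fixed-step RK4; return (A_d, B_d, C_d, x_prop)."""
    n_x, n_u = x0.shape[0], u_cur.shape[0]
    z = jnp.concatenate([x0, jnp.eye(n_x).ravel(), jnp.zeros(2 * n_x * n_u)])
    h = T / steps
    for i in range(steps):
        t = i * h
        k1 = augmented_rhs(t, z, u_cur, u_next, n_x, n_u)
        k2 = augmented_rhs(t + h / 2, z + h / 2 * k1, u_cur, u_next, n_x, n_u)
        k3 = augmented_rhs(t + h / 2, z + h / 2 * k2, u_cur, u_next, n_x, n_u)
        k4 = augmented_rhs(t + h, z + h * k3, u_cur, u_next, n_x, n_u)
        z = z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    x_prop, Phi, Psi_B, Psi_C = jnp.split(z, [n_x, n_x + n_x * n_x, n_x + n_x * n_x + n_x * n_u])
    return Phi.reshape(n_x, n_x), Psi_B.reshape(n_x, n_u), Psi_C.reshape(n_x, n_u), x_prop


A_d, B_d, C_d, x_prop = linearize_then_discretize(
    jnp.array([0.0, 1.0]), jnp.array([0.5]), jnp.array([-0.5])
)
```

For this linear system, RK4 integrates every component of the augmented state exactly. The matrices therefore equal A_d = [[1, T], [0, 1]], B_d = [T²/3, T/2]ᵀ and C_d = [T²/6, T/2]ᵀ.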

Supports ZOH (zero-order hold) and FOH (first-order hold) control interpolation between nodes.

This is the dense linearize-then-discretize scheme; Problem uses the sparse variant, LinearizeDiscretizeSparse, by default.

Parameters:

  • dis_type (Union[str, Sequence[str]], default "FOH"): Control hold type. "FOH" (first-order hold) or "ZOH" (zero-order hold) applies the same hold to every control. A per-control sequence (e.g. ["FOH", "ZOH", "FOH"]) sets the hold independently for each control and is merged with any per-control parameterization ("FOH" / "ZOH").
  • ode_solver (str, default "Tsit5"): Diffrax solver name. Any solver from Diffrax (https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/) is valid.
  • diffrax_kwargs (Optional[dict[str, Any]], default None, treated as {}): Preferred Diffrax keyword overrides. These map to openscvx.integrators.solve_ivp_diffrax kwargs, and unknown keys are forwarded to diffrax.diffeqsolve (e.g. stepsize_controller). Set rtol/atol here when using the default PID controller.
Source code in openscvx/discretization/linearize_discretize.py
class LinearizeDiscretize(Discretizer):
    """Linearize-then-discretize via augmented ODE integration.

    Computes continuous-time Jacobians (df/dx, df/du) via JAX forward-mode
    autodiff, then integrates them alongside the nonlinear dynamics through
    an augmented state vector using a multi-shooting approach to produce
    discrete-time matrices.

    Supports ZOH (zero-order hold) and FOH (first-order hold) control
    interpolation between nodes.

    This is the dense linearize-then-discretize scheme; ``Problem`` uses
    ``LinearizeDiscretizeSparse`` by default.

    Args:
        dis_type: Control hold type. ``"FOH"`` (first-order hold) or
            ``"ZOH"`` (zero-order hold) applies the same hold to every
            control.  A per-control sequence (e.g. ``["FOH", "ZOH", "FOH"]``)
            sets the hold independently for each control, merged with any
            per-control ``parameterization`` (``"FOH"`` / ``"ZOH"``).
            Defaults to ``"FOH"``.
        ode_solver: Diffrax solver name. Any solver from
            `Diffrax <https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/>`_
            is valid. Defaults to ``"Tsit5"``.
        diffrax_kwargs: Preferred Diffrax keyword overrides. These map to
            :func:`openscvx.integrators.solve_ivp_diffrax` kwargs, and unknown
            keys are forwarded to :func:`diffrax.diffeqsolve` (e.g.
            ``stepsize_controller``). Set ``rtol``/``atol`` here when using
            the default PID controller. Defaults to ``{}``.
    """

    def __init__(
        self,
        dis_type: Union[str, Sequence[str]] = "FOH",
        ode_solver: str = "Tsit5",
        diffrax_kwargs: Optional[dict[str, Any]] = None,
    ):
        self.dis_type = dis_type
        self.ode_solver = ode_solver
        self.diffrax_kwargs = dict(diffrax_kwargs) if diffrax_kwargs is not None else {}

    def _resolve_diffrax_kwargs(self) -> dict[str, Any]:
        """Build kwargs for :func:`solve_ivp_diffrax`.

        Unknown keys from ``self.diffrax_kwargs`` are forwarded to
        :func:`diffrax.diffeqsolve` via ``extra_kwargs`` so users can pass
        objects like ``stepsize_controller`` directly from discretizer settings.
        """
        kwargs: dict[str, Any] = {
            "solver_name": self.ode_solver,
            "rtol": DEFAULT_DIFFRAX_RTOL,
            "atol": DEFAULT_DIFFRAX_ATOL,
        }

        user_kwargs = dict(self.diffrax_kwargs)
        extra_kwargs: dict[str, Any] = {}

        nested_extra = user_kwargs.pop("extra_kwargs", None)
        if nested_extra is not None:
            extra_kwargs.update(dict(nested_extra))

        direct_keys = {"tau_0", "num_substeps", "solver_name", "rtol", "atol"}
        for key, value in user_kwargs.items():
            if key in direct_keys:
                kwargs[key] = value
            else:
                extra_kwargs[key] = value

        kwargs["extra_kwargs"] = extra_kwargs
        return kwargs

    def _resolve_rk45_kwargs(self, *, is_not_compiled: bool) -> dict[str, Any]:
        """Build kwargs for :func:`solve_ivp_rk45`."""
        kwargs: dict[str, Any] = {"is_not_compiled": is_not_compiled}
        direct_keys = {"tau_0", "num_substeps", "is_not_compiled"}
        for key, value in self.diffrax_kwargs.items():
            if key in direct_keys:
                kwargs[key] = value
        return kwargs

    def get_solver(self, dynamics: "Dynamics", settings: "Config") -> callable:
        """Create a multi-shoot discretization solver.

        Computes Jacobians of ``dynamics.f`` via ``jax.jacfwd``, vmaps all
        functions for batch evaluation across nodes, and returns a callable
        that integrates the augmented variational equations.

        Args:
            dynamics: System dynamics. Only ``dynamics.f`` is used; Jacobians
                are computed internally via JAX autodiff.
            settings: Problem configuration.

        Returns:
            Callable ``(x, u, params) -> (A_d, B_d, C_d, x_prop, V)``.
        """
        # Compute continuous-time Jacobians from dynamics.f
        A_fn = jax.jacfwd(dynamics.f, argnums=0)
        B_fn = jax.jacfwd(dynamics.f, argnums=1)

        # Vmap for batch evaluation across nodes
        f_vmapped = jax.vmap(dynamics.f, in_axes=(0, 0, 0, None))
        A_vmapped = jax.vmap(A_fn, in_axes=(0, 0, 0, None))
        B_vmapped = jax.vmap(B_fn, in_axes=(0, 0, 0, None))

        # Capture discretizer settings for the returned closure
        discretizer = self

        return lambda x, u, params: _calculate_discretization(
            x=x,
            u=u,
            state_dot=f_vmapped,
            A=A_vmapped,
            B=B_vmapped,
            settings=settings,
            discretizer=discretizer,
            params=params,
        )

    def citation(self) -> List[str]:
        """Return BibTeX citations for the linearize-then-discretize discretization method.

        Returns:
            List containing the BibTeX entries
        """
        return [
            r"""@article{kamath2023real,
  title={Real-time sequential conic optimization for multi-phase rocket landing guidance},
  author={Kamath, Abhinav G and Elango, Purnanand and Yu, Yue and Mceowen, Skye and Chari, Govind M
    and Carson III, John M and A{\c{c}}{\i}kme{\c{s}}e, Beh{\c{c}}et},
  journal={IFAC-PapersOnLine},
  volume={56},
  number={2},
  pages={3118--3125},
  year={2023},
  publisher={Elsevier}
}""",
        ]
citation() -> List[str]

Return BibTeX citations for the linearize-then-discretize discretization method.

Returns:

Type Description
List[str]

List containing the BibTeX entries

get_solver(dynamics: Dynamics, settings: Config) -> callable

Create a multi-shoot discretization solver.

Computes Jacobians of dynamics.f via jax.jacfwd, vmaps all functions for batch evaluation across nodes, and returns a callable that integrates the augmented variational equations.

Parameters:

Name Type Description Default
dynamics Dynamics

System dynamics. Only dynamics.f is used; Jacobians are computed internally via JAX autodiff.

required
settings Config

Problem configuration.

required

Returns:

Type Description
callable

Callable (x, u, params) -> (A_d, B_d, C_d, x_prop, V).
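The sketch below shows the `jacfwd` + `vmap` pattern that `get_solver` uses, applied to a hypothetical double integrator with linear drag. The `f(x, u, node, params)` signature follows the `in_axes=(0, 0, 0, None)` convention in the source. The toy dynamics, the `params` dict and its `"drag"` key are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def f(x, u, node, params):
    # Toy double integrator with drag (assumption): x = [position, velocity], u = [accel]
    return jnp.array([x[1], u[0] - params["drag"] * x[1]])

# Continuous-time Jacobians via forward-mode autodiff
A_fn = jax.jacfwd(f, argnums=0)  # df/dx, shape (n_x, n_x)
B_fn = jax.jacfwd(f, argnums=1)  # df/du, shape (n_x, n_u)

# Batch over nodes; params (in_axes=None) is shared by every node
A_vmapped = jax.vmap(A_fn, in_axes=(0, 0, 0, None))
B_vmapped = jax.vmap(B_fn, in_axes=(0, 0, 0, None))

x = jnp.array([[0.0, 1.0], [1.0, 2.0], [3.0, -1.0]])  # (n_nodes, n_x)
u = jnp.array([[0.5], [0.0], [-0.5]])                   # (n_nodes, n_u)
nodes = jnp.arange(3)
params = {"drag": 0.1}

A = A_vmapped(x, u, nodes, params)  # (3, 2, 2): one df/dx per node
B = B_vmapped(x, u, nodes, params)  # (3, 2, 1): one df/du per node
```

Because `params` is mapped with `None`, it can be an arbitrary pytree (here a dict) that is broadcast to every node rather than sliced.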

LinearizeDiscretizeSparse

Bases: LinearizeDiscretize

Sparse variant of linearize-then-discretize.

Uses graph-coloring-based sparse Jacobian computation and a compact augmented state vector that only integrates the structurally nonzero entries of Φ, B_d and C_d. This reduces the ODE system dimension from n_x + n_x² + 2·n_x·n_u to n_x + nnz_Ad + nnz_Bd + nnz_Cd per segment.

Requires A_c_sparsity and B_c_sparsity boolean arrays on the Dynamics object (set automatically when using the symbolic problem interface). Falls back to the dense LinearizeDiscretize path when either pattern is missing or when A_c_sparsity is fully dense.
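The sketch below makes the dimension count concrete for two decoupled double integrators. It rests on three assumptions, none taken from `discrete_sparsity` (whose rules are not shown here):

- The pattern of Φ is the reachability closure of the A_c graph plus the identity.
- The pattern of B_d is that closure multiplied by the B_c pattern.
- ZOH controls contribute no C_d entries.

```python
import numpy as np

def reachability(A_pat):
    # Assumed structural pattern of Phi: (i, j) may be nonzero iff x_j reaches
    # x_i through the graph of A_c, including i == j (Phi starts at identity).
    n = A_pat.shape[0]
    R = np.eye(n, dtype=bool) | A_pat
    for _ in range(n):
        R_next = R | ((R.astype(int) @ A_pat.astype(int)) > 0)
        if np.array_equal(R_next, R):
            break
        R = R_next
    return R

def compact_sizes(A_pat, B_pat, foh_mask):
    n_x, n_u = B_pat.shape
    Ad = reachability(A_pat)
    Bd = (Ad.astype(int) @ B_pat.astype(int)) > 0
    Cd = Bd & foh_mask[None, :]  # assumption: ZOH columns of C_d are zero
    dense = n_x + n_x**2 + 2 * n_x * n_u
    compact = n_x + Ad.sum() + Bd.sum() + Cd.sum()
    return dense, int(compact)

# Two decoupled double integrators: x = [p1, v1, p2, v2], u = [a1, a2]
A_pat = np.zeros((4, 4), dtype=bool)
A_pat[0, 1] = A_pat[2, 3] = True
B_pat = np.zeros((4, 2), dtype=bool)
B_pat[1, 0] = B_pat[3, 1] = True

sizes = compact_sizes(A_pat, B_pat, np.array([True, True]))  # (36, 18)
```

Under these assumptions, the compact augmented state has 18 entries per segment versus 36 for the dense layout.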

Parameters:

Name Type Description Default
dis_type Union[str, Sequence[str]]

Control hold type: "FOH" or "ZOH" applied to every control, or a per-control sequence of these values. Defaults to "FOH".

'FOH'
ode_solver str

Diffrax solver name. Defaults to "Tsit5".

'Tsit5'
diffrax_kwargs Optional[dict[str, Any]]

Diffrax keyword overrides inherited from LinearizeDiscretize. Unknown keys are forwarded to diffrax.diffeqsolve via extra_kwargs.

None
Source code in openscvx/discretization/linearize_discretize_sparse.py
class LinearizeDiscretizeSparse(LinearizeDiscretize):
    """Sparse variant of linearize-then-discretize.

    Uses graph-coloring-based sparse Jacobian computation and a compact
    augmented state vector that only integrates the structurally nonzero
    entries of Φ, B_d and C_d.  This reduces the ODE system dimension
    from ``n_x + n_x² + 2·n_x·n_u`` to ``n_x + nnz_Ad + nnz_Bd + nnz_Cd``
    per segment.

    Requires ``A_c_sparsity`` and ``B_c_sparsity`` boolean arrays on the
    :class:`Dynamics` object (set automatically when using the symbolic
    problem interface).  Falls back to the dense
    :class:`LinearizeDiscretize` path when sparsity patterns are
    unavailable or fully dense.

    Args:
        dis_type: Control hold type. ``"FOH"`` or ``"ZOH"``.
            Defaults to ``"FOH"``.
        ode_solver: Diffrax solver name. Defaults to ``"Tsit5"``.
        diffrax_kwargs: Diffrax keyword overrides inherited from
            :class:`LinearizeDiscretize`. Unknown keys are forwarded to
            :func:`diffrax.diffeqsolve` via ``extra_kwargs``.
    """

    def get_solver(self, dynamics: "Dynamics", settings: "Config") -> callable:
        """Create a sparse multi-shoot discretization solver.

        When ``dynamics.A_c_sparsity`` and ``dynamics.B_c_sparsity`` are
        available and the pattern is not fully dense, builds a compact-V
        integration path with graph-coloring Jacobians.  Otherwise
        delegates to the dense parent implementation.

        Args:
            dynamics: System dynamics with optional sparsity annotations.
            settings: Problem configuration.

        Returns:
            Callable ``(x, u, params) -> (A_d, B_d, C_d, x_prop, V)``.
        """
        from openscvx.symbolic.sparsity import discrete_sparsity

        from .sparse_utils import make_sparse_jacobian_fns

        A_c_pat = getattr(dynamics, "A_c_sparsity", None)
        B_c_pat = getattr(dynamics, "B_c_sparsity", None)
        has_sparsity = A_c_pat is not None and B_c_pat is not None

        if not has_sparsity or A_c_pat.all():
            return super().get_solver(dynamics, settings)

        f_vmapped = jax.vmap(dynamics.f, in_axes=(0, 0, 0, None))
        discretizer = self
        n_x = settings.sim.n_states
        n_u = settings.sim.n_controls

        A_vmapped, B_vmapped = make_sparse_jacobian_fns(
            dynamics.f,
            A_c_pat,
            B_c_pat,
            n_x,
            n_u,
        )

        u_foh_mask = getattr(settings.sim.u, "foh_mask", None)
        resolved_mask = _resolve_foh_mask(self.dis_type, n_u, u_foh_mask)
        Ad_pat, Bd_pat, Cd_pat = discrete_sparsity(
            A_c_pat,
            B_c_pat,
            resolved_mask,
        )
        Ad_r, Ad_c = np.where(Ad_pat)
        Bd_r, Bd_c = np.where(Bd_pat)
        Cd_r, Cd_c = np.where(Cd_pat)

        sparse_layout = (
            jnp.array(Ad_r),
            jnp.array(Ad_c),
            len(Ad_r),
            jnp.array(Bd_r),
            jnp.array(Bd_c),
            len(Bd_r),
            jnp.array(Cd_r),
            jnp.array(Cd_c),
            len(Cd_r),
        )

        return lambda x, u, params: _calculate_discretization_sparse(
            x=x,
            u=u,
            state_dot=f_vmapped,
            A=A_vmapped,
            B=B_vmapped,
            settings=settings,
            discretizer=discretizer,
            params=params,
            sparse_layout=sparse_layout,
        )
get_solver(dynamics: Dynamics, settings: Config) -> callable

Create a sparse multi-shoot discretization solver.

When dynamics.A_c_sparsity and dynamics.B_c_sparsity are both available and A_c_sparsity is not fully dense, builds a compact integration path with graph-coloring Jacobians. In this path the augmented state V holds only the structurally nonzero entries of Φ, B_d and C_d. Otherwise delegates to the dense parent implementation.

Parameters:

Name Type Description Default
dynamics Dynamics

System dynamics with optional sparsity annotations.

required
settings Config

Problem configuration.

required

Returns:

Type Description
callable

Callable (x, u, params) -> (A_d, B_d, C_d, x_prop, V).
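For each of A_d, B_d and C_d, get_solver stores the row indices, column indices and count of structural nonzeros (from `np.where`) in `sparse_layout`. Presumably these triples gather a dense matrix into the compact augmented state and scatter it back. The sketch below shows that round trip in NumPy with hypothetical `pack`/`unpack` helpers. In traced JAX code, the scatter would be written `M.at[rows, cols].set(v)`.

```python
import numpy as np

# Hypothetical discrete-time sparsity pattern of A_d (True = structural nonzero)
Ad_pat = np.array([[1, 1, 0],
                   [0, 1, 0],
                   [0, 0, 1]], dtype=bool)
Ad_r, Ad_c = np.where(Ad_pat)  # indices of the nonzeros, in row-major order
nnz_Ad = len(Ad_r)

def pack(M, rows, cols):
    # Dense matrix -> compact vector holding only the structural nonzeros
    return M[rows, cols]

def unpack(v, rows, cols, shape):
    # Compact vector -> dense matrix, zero everywhere outside the pattern
    M = np.zeros(shape, dtype=v.dtype)
    M[rows, cols] = v
    return M

A_d = np.array([[1.0, 0.1, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 0.9]])
v = pack(A_d, Ad_r, Ad_c)                  # [1.0, 0.1, 1.0, 0.9]
A_back = unpack(v, Ad_r, Ad_c, A_d.shape)  # equal to A_d
```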

VectorizeDiscretizeLinearize

Bases: Discretizer

Discretization via differentiating through the integrator for all segments simultaneously.

Propagates the nonlinear dynamics over all trajectory segments at once, then directly computes Jacobians of the propagated solutions (dF/dx, dF/du) via JAX forward-mode autodiff to produce discrete-time Jacobians.

Supports ZOH (zero-order hold) and FOH (first-order hold) control interpolation between nodes, including independent per-control selection.
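In the source, multiple_dxdt forms the in-segment control as `u_cur + beta * (u_next - u_cur)`, where `beta` is scaled by a per-control `foh_mask`. The NumPy sketch below shows that interpolation under two assumptions:

- The mask is True for FOH controls and False for ZOH controls; it is built directly here rather than through `_resolve_foh_mask`.
- `frac` is a normalized position in [0, 1] inside the segment, used in place of the integrator time.

```python
import numpy as np

def interp_controls(u_cur, u_next, frac, foh_mask):
    # FOH controls ramp linearly from u_cur to u_next; ZOH controls
    # (mask False) keep u_cur for the whole segment.
    beta = frac * np.asarray(foh_mask, dtype=float)
    return u_cur + beta * (u_next - u_cur)

# Per-control holds ["FOH", "ZOH", "FOH"] expressed as a boolean mask
foh_mask = np.array([h == "FOH" for h in ["FOH", "ZOH", "FOH"]])
u_cur = np.array([0.0, 0.0, 0.0])
u_next = np.array([1.0, 2.0, 3.0])
u_mid = interp_controls(u_cur, u_next, 0.5, foh_mask)  # [0.5, 0.0, 1.5]
```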

This integration scheme offers the best balance of speed and accuracy for most problems.
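The idea of differentiating through the integrator can be sketched on a scalar system. Several pieces are assumptions:

- The toy dynamics x' = -x + u.
- A fixed-step RK4 loop, standing in for Diffrax or `solve_ivp_rk45`.
- The helper names `propagate` and `discretize`.

The step that pushes each standard basis vector through `jax.jvp` mirrors the pushforward in `vectorize_then_discretize_then_linearize`. Because the system is linear, the exact answers are known: A_d = exp(-T), B_d + C_d = 1 - exp(-T).

```python
import jax
import jax.numpy as jnp

def f(x, u):
    # Toy scalar dynamics (assumption): x' = -x + u
    return -x + u

def propagate(z, T=0.5, n_steps=100):
    # z packs [x0, u_cur, u_next], like `primal` in the source.
    x, u_cur, u_next = z[0:1], z[1:2], z[2:3]
    h = T / n_steps

    def u_of(t):
        # FOH: linear ramp from u_cur (t = 0) to u_next (t = T)
        return u_cur + (t / T) * (u_next - u_cur)

    def step(i, x):
        # One fixed RK4 step (stand-in for the library integrators)
        t = i * h
        k1 = f(x, u_of(t))
        k2 = f(x + 0.5 * h * k1, u_of(t + 0.5 * h))
        k3 = f(x + 0.5 * h * k2, u_of(t + 0.5 * h))
        k4 = f(x + h * k3, u_of(t + h))
        return x + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

    return jax.lax.fori_loop(0, n_steps, step, x)

def discretize(z):
    # Push each standard basis vector through the integrator with jax.jvp;
    # column j of J is d(x_final)/d(z_j).
    basis = jnp.eye(z.shape[0])
    J = jax.vmap(lambda v: jax.jvp(propagate, (z,), (v,))[1], in_axes=0, out_axes=-1)(basis)
    return J[:, 0:1], J[:, 1:2], J[:, 2:3]  # A_d, B_d, C_d

A_d, B_d, C_d = discretize(jnp.array([1.0, 0.2, 0.8]))
```

No variational equations are integrated here: forward-mode AD differentiates the integrator's arithmetic directly. That is why the source describes this scheme as computing Jacobians "of the solution directly".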

Parameters:

Name Type Description Default
dis_type Union[str, Sequence[str]]

Control hold type. "FOH" (first-order hold) or "ZOH" (zero-order hold) applies the same hold to every control. A per-control sequence (e.g. ["FOH", "ZOH", "FOH"]) sets the hold independently for each control, merged with any per-control parameterization ("FOH" / "ZOH"). Defaults to "FOH".

'FOH'
ode_solver str

Diffrax solver name. Any solver from Diffrax (https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/) is valid. Defaults to "Tsit5".

'Tsit5'
custom_integrator bool

Use the built-in fixed-step RK45 integrator instead of Diffrax. Faster but less robust. Defaults to False.

False
diffrax_kwargs Optional[dict[str, Any]]

Preferred Diffrax keyword overrides. These map to openscvx.integrators.solve_ivp_diffrax kwargs, and unknown keys are forwarded to diffrax.diffeqsolve (e.g. stepsize_controller). Defaults to {}.

None
args Optional[dict]

Deprecated alias for diffrax_kwargs, kept for backward compatibility. When both are given, entries in args override matching keys in diffrax_kwargs.

None
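The keyword routing in `__init__` and `_resolve_diffrax_kwargs` can be summarized in plain Python. This sketch is a hypothetical restatement, not the library function:

- The default tolerances are placeholders. The real values come from `DEFAULT_DIFFRAX_RTOL` / `DEFAULT_DIFFRAX_ATOL`, which are not shown here.
- The adjoint is a string standing in for `dfx.ForwardMode()`.
- `max_steps` is a regular `diffrax.diffeqsolve` argument, so it lands in `extra_kwargs`.

```python
from typing import Any, Optional

# Keys consumed by solve_ivp_diffrax itself, per the source
DIRECT_KEYS = {"tau_0", "num_substeps", "solver_name", "rtol", "atol"}

def route_diffrax_kwargs(
    diffrax_kwargs: Optional[dict] = None,
    args: Optional[dict] = None,
    solver_name: str = "Tsit5",
    default_rtol: float = 1e-6,  # placeholder for DEFAULT_DIFFRAX_RTOL
    default_atol: float = 1e-6,  # placeholder for DEFAULT_DIFFRAX_ATOL
) -> dict[str, Any]:
    # Same merge order as __init__: diffrax_kwargs first, then the deprecated `args`
    user: dict[str, Any] = {}
    if diffrax_kwargs is not None:
        user.update(diffrax_kwargs)
    if args is not None:
        user.update(args)

    routed: dict[str, Any] = {"solver_name": solver_name, "rtol": default_rtol, "atol": default_atol}
    extra: dict[str, Any] = {"adjoint": "ForwardMode"}  # stands in for dfx.ForwardMode()
    extra.update(user.pop("extra_kwargs", None) or {})  # nested extra_kwargs are merged in
    for key, value in user.items():
        if key in DIRECT_KEYS:
            routed[key] = value  # goes to solve_ivp_diffrax
        else:
            extra[key] = value   # forwarded to diffrax.diffeqsolve
    routed["extra_kwargs"] = extra
    return routed

routed = route_diffrax_kwargs({"rtol": 1e-8, "max_steps": 4096}, args={"rtol": 1e-9})
```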
Source code in openscvx/discretization/discretize_linearize.py
class VectorizeDiscretizeLinearize(Discretizer):
    """Discretization via differentiating through the integrator for all segments simultaneously.

    Propagates the nonlinear dynamics over all trajectory segments at once, then directly computes
    Jacobians of the propagated solutions (dF/dx, dF/du) via JAX forward-mode autodiff to produce
    discrete-time Jacobians.

    Supports ZOH (zero-order hold) and FOH (first-order hold) control interpolation between nodes,
    including independent per-control selection.

    This integration scheme offers the best balance of speed and accuracy for most problems.

    Args:
        dis_type: Control hold type. ``"FOH"`` (first-order hold) or
            ``"ZOH"`` (zero-order hold) applies the same hold to every
            control.  A per-control sequence (e.g. ``["FOH", "ZOH", "FOH"]``)
            sets the hold independently for each control, merged with any
            per-control ``parameterization`` (``"FOH"`` / ``"ZOH"``).
            Defaults to ``"FOH"``.
        ode_solver: Diffrax solver name. Any solver from
            `Diffrax <https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/>`_
            is valid. Defaults to ``"Tsit5"``.
        custom_integrator: Use the built-in fixed-step RK45 integrator instead of Diffrax.
            Faster but less robust. Defaults to ``False``.
        diffrax_kwargs: Preferred Diffrax keyword overrides. These map to
            :func:`openscvx.integrators.solve_ivp_diffrax` kwargs, and unknown
            keys are forwarded to :func:`diffrax.diffeqsolve` (e.g.
            ``stepsize_controller``). Defaults to ``{}``.
        args: Deprecated alias for ``diffrax_kwargs`` kept for backward compatibility.
    """

    def __init__(
        self,
        dis_type: Union[str, Sequence[str]] = "FOH",
        ode_solver: str = "Tsit5",
        custom_integrator: bool = False,
        diffrax_kwargs: Optional[dict[str, Any]] = None,
        args: Optional[dict] = None,
    ):
        self.dis_type = dis_type
        self.ode_solver = ode_solver
        self.custom_integrator = custom_integrator
        merged_kwargs: dict[str, Any] = {}
        if diffrax_kwargs is not None:
            merged_kwargs.update(dict(diffrax_kwargs))
        if args is not None:
            merged_kwargs.update(dict(args))
        self.diffrax_kwargs = merged_kwargs

    def _resolve_diffrax_kwargs(self) -> dict[str, Any]:
        kwargs: dict[str, Any] = {
            "solver_name": self.ode_solver,
            "rtol": DEFAULT_DIFFRAX_RTOL,
            "atol": DEFAULT_DIFFRAX_ATOL,
        }
        user_kwargs = dict(self.diffrax_kwargs)
        extra_kwargs: dict[str, Any] = {"adjoint": dfx.ForwardMode()}

        nested_extra = user_kwargs.pop("extra_kwargs", None)
        if nested_extra is not None:
            extra_kwargs.update(dict(nested_extra))

        direct_keys = {"tau_0", "num_substeps", "solver_name", "rtol", "atol"}
        for key, value in user_kwargs.items():
            if key in direct_keys:
                kwargs[key] = value
            else:
                extra_kwargs[key] = value

        kwargs["extra_kwargs"] = extra_kwargs
        return kwargs

    def _resolve_rk45_kwargs(self, *, is_not_compiled: bool) -> dict[str, Any]:
        kwargs: dict[str, Any] = {"is_not_compiled": is_not_compiled}
        direct_keys = {"tau_0", "num_substeps", "is_not_compiled"}
        for key, value in self.diffrax_kwargs.items():
            if key in direct_keys:
                kwargs[key] = value
        return kwargs

    def get_solver(self, dynamics: "Dynamics", settings: "Config") -> callable:
        """Create a multiple-shooting vectorize-then-discretize-then-linearize solver.

        Batches ``dynamics.f`` across all nodes, integrates it, and computes Jacobians of the
        solution directly (i.e. without using the variational equations).

        Args:
            dynamics: System dynamics. Only ``dynamics.f`` is used; Jacobians are computed
                internally via JAX autodiff.
            settings: Problem configuration.

        Returns:
            Callable ``(x, u, params) -> (A_d, B_d, C_d, x_prop, V)``.
        """

        N = settings.sim.n
        n_x = settings.sim.n_states
        n_u = settings.sim.n_controls
        nodes = jnp.arange(0, N - 1)
        u_foh_mask = getattr(settings.sim.u, "foh_mask", None)
        foh_mask = _resolve_foh_mask(self.dis_type, n_u, u_foh_mask)

        multiple_state_dot = jax.vmap(dynamics.f, in_axes=(0, 0, 0, None))

        def multiple_dxdt(
            tau: float,
            x: jnp.ndarray,
            u_cur: np.ndarray,
            u_next: np.ndarray,
            params: dict,
        ) -> jnp.ndarray:

            beta = tau * N * foh_mask

            x = x.reshape(N - 1, n_x)
            u = u_cur + beta * (u_next - u_cur)
            F = multiple_state_dot(x, u, nodes, params)

            return F.flatten()

        def vectorize_then_discretize(
            x: jnp.ndarray,
            u_cur: np.ndarray,
            u_next: np.ndarray,
            params: dict,
        ) -> jnp.ndarray:

            if self.custom_integrator:
                rk45_kwargs = self._resolve_rk45_kwargs(is_not_compiled=settings.dev.debug)
                sol = solve_ivp_rk45(
                    multiple_dxdt,
                    1.0 / (N - 1),
                    x.flatten(),
                    args=(u_cur, u_next, params),
                    **rk45_kwargs,
                )
            else:
                diffrax_kwargs = self._resolve_diffrax_kwargs()
                sol = solve_ivp_diffrax(
                    multiple_dxdt,
                    1.0 / (N - 1),
                    x.flatten(),
                    args=(u_cur, u_next, params),
                    **diffrax_kwargs,
                )
            return sol.reshape(-1, N - 1, n_x)

        i0 = 0
        i1 = n_x
        i2 = n_x + n_u
        i3 = n_x + 2 * n_u
        standard_basis = jnp.repeat(jnp.eye(i3)[None], N - 1, axis=0)

        def vectorize_then_discretize_then_linearize(
            x: jnp.ndarray,
            u_cur: np.ndarray,
            u_next: np.ndarray,
            params: dict,
        ):
            def partial(z):
                x = z[:, i0:i1]
                u_cur = z[:, i1:i2]
                u_next = z[:, i2:i3]
                return vectorize_then_discretize(x, u_cur, u_next, params)

            primal = jnp.concatenate([x, u_cur, u_next], axis=-1)
            pushforward = jax.vmap(
                # Discard value of f (zeroth output)
                lambda tangent: jax.jvp(partial, (primal,), (tangent,))[1],
                in_axes=-1,
                out_axes=-1,
            )
            jacobians = pushforward(standard_basis)

            A_d = jacobians[:, :, :, i0:i1]
            B_d = jacobians[:, :, :, i1:i2]
            C_d = jacobians[:, :, :, i2:i3]

            return A_d, B_d, C_d

        def solver(x, u, params):
            A_d, B_d, C_d = vectorize_then_discretize_then_linearize(x[:-1], u[:-1], u[1:], params)
            x_prop = vectorize_then_discretize(x[:-1], u[:-1], u[1:], params)

            # TODO: providing the histories of A, B, and C can lead to as much as a 20% slowdown.
            # If they aren't getting used, they shouldn't be here. V_multi should be replaced with
            # an output directly corresponding to the history of x_prop.
            V_multi = jnp.concatenate(
                [
                    x_prop,
                    A_d.reshape(-1, N - 1, n_x * n_x),
                    B_d.reshape(-1, N - 1, n_x * n_u),
                    C_d.reshape(-1, N - 1, n_x * n_u),
                ],
                axis=2,
            )
            i4 = V_multi.shape[2]
            V_multi = V_multi.reshape(-1, (N - 1) * i4).T

            return A_d[-1], B_d[-1], C_d[-1], x_prop[-1], V_multi

        return solver

    def citation(self) -> List[str]:
        """Return BibTeX citations for the vectorize-then-discretize-then-linearize method.

        Returns:
            List containing the BibTeX entries
        """
        return [
            r"""@phdthesis{kidger2021on,
  title={{O}n {N}eural {D}ifferential {E}quations},
  author={Patrick Kidger},
  year={2021},
  school={University of Oxford},
}""",
        ]
citation() -> List[str]

Return BibTeX citations for the vectorize-then-discretize-then-linearize method.

Returns:

Type Description
List[str]

List containing the BibTeX entries

get_solver(dynamics: Dynamics, settings: Config) -> callable

Create a multiple-shooting vectorize-then-discretize-then-linearize solver.

Batches dynamics.f across all nodes, integrates it, and computes Jacobians of the solution directly (i.e. without using the variational equations).

Parameters:

Name Type Description Default
dynamics Dynamics

System dynamics. Only dynamics.f is used; Jacobians are computed internally via JAX autodiff.

required
settings Config

Problem configuration.

required

Returns:

Type Description
callable

Callable (x, u, params) -> (A_d, B_d, C_d, x_prop, V).
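multiple_dxdt treats all N - 1 segments as a single ODE: it reshapes a flat state to (N - 1, n_x), evaluates the vmapped dynamics, and flattens the result. The NumPy sketch below applies the same trick and checks that the batched solve matches per-segment solves. It makes these assumptions:

- Broadcasting replaces `jax.vmap`.
- A fixed-step RK4 replaces the library integrators.
- The dynamics are a toy pendulum with ZOH torque.

```python
import numpy as np

def f_batch(x, u):
    # Toy pendulum dynamics (assumption), evaluated for every segment at once.
    # x: (S, 2) rows of [angle, rate]; u: (S, 1) torque held constant (ZOH).
    return np.stack([x[:, 1], -np.sin(x[:, 0]) + u[:, 0]], axis=1)

def rk4(fun, y0, T, n_steps):
    # Fixed-step RK4 over [0, T]
    h = T / n_steps
    y = y0
    for _ in range(n_steps):
        k1 = fun(y)
        k2 = fun(y + 0.5 * h * k1)
        k3 = fun(y + 0.5 * h * k2)
        k4 = fun(y + h * k3)
        y = y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return y

def propagate_all(x_nodes, u_nodes, T, n_steps=50):
    # Flatten all S segments into one ODE state, as multiple_dxdt does
    S, n_x = x_nodes.shape

    def flat_dxdt(y):
        return f_batch(y.reshape(S, n_x), u_nodes).ravel()

    return rk4(flat_dxdt, x_nodes.ravel(), T, n_steps).reshape(S, n_x)

x_nodes = np.array([[0.1, 0.0], [1.0, 0.5], [-0.3, 0.2]])
u_nodes = np.array([[0.0], [0.1], [-0.2]])
batched = propagate_all(x_nodes, u_nodes, T=0.1)  # shape (3, 2)
```

With a fixed step, the segments stay fully independent. With an adaptive step-size controller, the error estimate covers the whole flattened vector, so the hardest segment sets the step for all of them. That is one plausible reason the per-segment DiscretizeLinearizeVectorize scheme is recommended for stiff or badly scaled dynamics.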

calculate_impulsive_discretization(x_nodes: np.ndarray, u_nodes: np.ndarray, state_dot_discrete: callable, A_discrete: callable, B_discrete: callable, params: dict) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]

Evaluate discrete/impulsive dynamics and Jacobians at all nodes.

Source code in openscvx/discretization/linearize_discretize.py
def calculate_impulsive_discretization(
    x_nodes: np.ndarray,
    u_nodes: np.ndarray,
    state_dot_discrete: callable,
    A_discrete: callable,
    B_discrete: callable,
    params: dict,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Evaluate discrete/impulsive dynamics and Jacobians at all nodes."""
    n_nodes = x_nodes.shape[0]
    nodes = jnp.arange(n_nodes)
    n_x = x_nodes.shape[1]
    n_u = u_nodes.shape[1]

    x_prop_plus = state_dot_discrete(x_nodes, u_nodes, nodes, params)
    D_d = A_discrete(x_nodes, u_nodes, nodes, params)
    E_d = B_discrete(x_nodes, u_nodes, nodes, params)

    W_end = jnp.concatenate(
        (
            x_prop_plus,
            D_d.reshape(n_nodes, n_x * n_x),
            E_d.reshape(n_nodes, n_x * n_u),
        ),
        axis=1,
    )
    W = W_end.reshape(-1, 1)
    return x_prop_plus, D_d, E_d, W
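The W returned by calculate_impulsive_discretization is node-major. Each node contributes a row [x_prop_plus, vec(D_d), vec(E_d)] of length n_x + n_x² + n_x·n_u, and the rows are then stacked into one column vector. The NumPy sketch below reproduces that packing and adds an inverse; `unpack_W` is hypothetical and not part of the library.

```python
import numpy as np

def pack_W(x_plus, D_d, E_d):
    # Mirrors calculate_impulsive_discretization: one row per node, then a column
    n_nodes, n_x = x_plus.shape
    n_u = E_d.shape[2]
    W_end = np.concatenate(
        (x_plus, D_d.reshape(n_nodes, n_x * n_x), E_d.reshape(n_nodes, n_x * n_u)),
        axis=1,
    )
    return W_end.reshape(-1, 1)

def unpack_W(W, n_nodes, n_x, n_u):
    # Hypothetical inverse: recover the per-node blocks from the flat column
    rows = W.reshape(n_nodes, n_x + n_x * n_x + n_x * n_u)
    x_plus = rows[:, :n_x]
    D_d = rows[:, n_x:n_x + n_x * n_x].reshape(n_nodes, n_x, n_x)
    E_d = rows[:, n_x + n_x * n_x:].reshape(n_nodes, n_x, n_u)
    return x_plus, D_d, E_d

rng = np.random.default_rng(0)
x_plus = rng.normal(size=(3, 2))
D_d = rng.normal(size=(3, 2, 2))
E_d = rng.normal(size=(3, 2, 1))
W = pack_W(x_plus, D_d, E_d)  # shape (3 * (2 + 4 + 2), 1) = (24, 1)
```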

color_columns(pattern: np.ndarray) -> np.ndarray

Greedy column coloring of a boolean sparsity pattern.

Two columns i and j can share a color when they have no row in common where both are nonzero. This function assigns colors with a greedy (first-fit) algorithm over columns ordered by decreasing number of nonzeros. This usually yields few colors but does not guarantee the minimum number, since optimal graph coloring is NP-hard in general.

Parameters:

Name Type Description Default
pattern ndarray

Boolean (m, n) array where True indicates a structural nonzero.

required

Returns:

Type Description
ndarray

Integer array of length n mapping each column to its color (0-indexed). The number of distinct colors is max(colors) + 1.

Source code in openscvx/discretization/sparse_utils/sparse_jacobian.py
def color_columns(pattern: np.ndarray) -> np.ndarray:
    """Greedy column coloring of a boolean sparsity pattern.

    Two columns ``i`` and ``j`` can share a color when they have no row
    in common where both are nonzero.  Colors are assigned with a greedy
    (first-fit) algorithm over columns ordered by decreasing number of
    nonzeros; this usually yields few colors but does not guarantee the
    minimum.

    Args:
        pattern: Boolean ``(m, n)`` array where ``True`` indicates a
            structural nonzero.

    Returns:
        Integer array of length ``n`` mapping each column to its color
        (0-indexed).  The number of distinct colors is ``max(colors) + 1``.
    """
    m, n = pattern.shape
    colors = -np.ones(n, dtype=np.intp)

    col_order = np.argsort(-pattern.sum(axis=0))

    for col in col_order:
        row_set = set(np.where(pattern[:, col])[0])
        forbidden = set()
        for prev_col in range(n):
            if colors[prev_col] < 0:
                continue
            prev_rows = set(np.where(pattern[:, prev_col])[0])
            if row_set & prev_rows:
                forbidden.add(colors[prev_col])
        c = 0
        while c in forbidden:
            c += 1
        colors[col] = c

    return colors
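Coloring matters because columns that share a color can be probed together, so the Jacobian costs one JVP per color instead of one per column. The sketch below re-implements the first-fit idea independently and applies it to a tridiagonal pattern. A tridiagonal pattern needs at least 3 colors, because columns i, i+1 and i+2 all meet in row i+1, and the greedy pass reaches that bound here. `greedy_color` and `is_valid_coloring` are illustrative, not the library functions.

```python
import numpy as np

def greedy_color(pattern):
    # First-fit column coloring over columns sorted by decreasing nonzeros
    n = pattern.shape[1]
    colors = -np.ones(n, dtype=int)
    for col in np.argsort(-pattern.sum(axis=0), kind="stable"):
        # Columns sharing at least one nonzero row with `col` (col itself is uncolored)
        neighbors = np.flatnonzero((pattern[:, [col]] & pattern).any(axis=0))
        used = {int(colors[j]) for j in neighbors if colors[j] >= 0}
        c = 0
        while c in used:
            c += 1
        colors[col] = c
    return colors

def is_valid_coloring(pattern, colors):
    # Same-colored columns must never be nonzero in the same row
    for c in np.unique(colors):
        if (pattern[:, colors == c].sum(axis=1) > 1).any():
            return False
    return True

n = 6
tridiag = np.eye(n, dtype=bool) | np.eye(n, k=1, dtype=bool) | np.eye(n, k=-1, dtype=bool)
colors = greedy_color(tridiag)  # [2, 0, 1, 2, 0, 1]
n_colors = colors.max() + 1     # 3 JVPs instead of 6
```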

get_impulsive_discretization_solver(dyn_discrete: Dynamics) -> callable

Create a solver for discrete/impulsive dynamics linearization.

Source code in openscvx/discretization/linearize_discretize.py
def get_impulsive_discretization_solver(dyn_discrete: "Dynamics") -> callable:
    """Create a solver for discrete/impulsive dynamics linearization."""
    A_fn = jax.jacfwd(dyn_discrete.f, argnums=0)
    B_fn = jax.jacfwd(dyn_discrete.f, argnums=1)

    f_vmapped = jax.vmap(dyn_discrete.f, in_axes=(0, 0, 0, None))
    A_vmapped = jax.vmap(A_fn, in_axes=(0, 0, 0, None))
    B_vmapped = jax.vmap(B_fn, in_axes=(0, 0, 0, None))

    return lambda x_nodes, u_nodes, params: calculate_impulsive_discretization(
        x_nodes=x_nodes,
        u_nodes=u_nodes,
        state_dot_discrete=f_vmapped,
        A_discrete=A_vmapped,
        B_discrete=B_vmapped,
        params=params,
    )

make_sparse_jacobian_fns(f: Callable, A_c_pattern: Optional[np.ndarray], B_c_pattern: Optional[np.ndarray], n_x: int, n_u: int) -> Tuple[Callable, Callable]

Create vmapped sparse Jacobian functions for df/dx and df/du.

If a sparsity pattern is fully dense or None, falls back to the standard jax.jacfwd path for that Jacobian.

Parameters:

Name Type Description Default
f Callable

Dynamics function f(x, u, node, params) -> x_dot.

required
A_c_pattern Optional[ndarray]

Boolean (n_x, n_x) sparsity of df/dx, or None.

required
B_c_pattern Optional[ndarray]

Boolean (n_x, n_u) sparsity of df/du, or None.

required
n_x int

Number of state dimensions.

required
n_u int

Number of control dimensions.

required

Returns:

Type Description
Tuple[Callable, Callable]

(A_vmapped, B_vmapped): vmapped Jacobian callables with signature (x_batch, u_batch, nodes, params) -> J_batch.

Source code in openscvx/discretization/sparse_utils/sparse_jacobian.py
def make_sparse_jacobian_fns(
    f: Callable,
    A_c_pattern: Optional[np.ndarray],
    B_c_pattern: Optional[np.ndarray],
    n_x: int,
    n_u: int,
) -> Tuple[Callable, Callable]:
    """Create vmapped sparse Jacobian functions for df/dx and df/du.

    If a sparsity pattern is fully dense or ``None``, falls back to the
    standard ``jax.jacfwd`` path for that Jacobian.

    Args:
        f: Dynamics function ``f(x, u, node, params) -> x_dot``.
        A_c_pattern: Boolean ``(n_x, n_x)`` sparsity of df/dx, or ``None``.
        B_c_pattern: Boolean ``(n_x, n_u)`` sparsity of df/du, or ``None``.
        n_x: Number of state dimensions.
        n_u: Number of control dimensions.

    Returns:
        ``(A_vmapped, B_vmapped)`` — vmapped Jacobian callables with
        signature ``(x_batch, u_batch, nodes, params) -> J_batch``.
    """
    # --- df/dx ---
    if A_c_pattern is not None and not A_c_pattern.all():
        seeds_A, nc_A, nz_r_A, nz_c_A = _build_coloring_data(A_c_pattern)
        A_fn = _sparse_jacobian_fn(f, 0, seeds_A, nc_A, nz_r_A, nz_c_A, n_x, n_x)
    else:
        A_fn = jax.jacfwd(f, argnums=0)
    A_vmapped = jax.vmap(A_fn, in_axes=(0, 0, 0, None))

    # --- df/du ---
    if B_c_pattern is not None and not B_c_pattern.all():
        seeds_B, nc_B, nz_r_B, nz_c_B = _build_coloring_data(B_c_pattern)
        B_fn = _sparse_jacobian_fn(f, 1, seeds_B, nc_B, nz_r_B, nz_c_B, n_x, n_u)
    else:
        B_fn = jax.jacfwd(f, argnums=1)
    B_vmapped = jax.vmap(B_fn, in_axes=(0, 0, 0, None))

    return A_vmapped, B_vmapped
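make_sparse_jacobian_fns builds its colored Jacobians through `_build_coloring_data` and `_sparse_jacobian_fn`, whose bodies are not shown here. The standard technique they presumably implement is compressed forward-mode differentiation:

1. Push one seed vector per color through `jax.jvp`.
2. Read each structural nonzero (r, c) from the compressed result at column `color[c]`.

The minimal JAX sketch below uses a hypothetical `sparse_jacfwd` helper and a toy function whose pattern needs only 2 colors.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    # Toy function (assumption) with a known block-sparse Jacobian
    return jnp.array([x[0] * x[1], x[1] ** 2, x[2] * x[3], jnp.sin(x[3])])

pattern = np.array([[1, 1, 0, 0],
                    [0, 1, 0, 0],
                    [0, 0, 1, 1],
                    [0, 0, 0, 1]], dtype=bool)
colors = np.array([0, 1, 0, 1])  # same-colored columns never share a nonzero row

def sparse_jacfwd(fun, x, pattern, colors):
    n_colors = int(colors.max()) + 1
    # Seed matrix (n, n_colors): column k sums the basis vectors of color k
    seeds = jnp.asarray(np.eye(n_colors)[colors])
    # One forward-mode pass per color instead of one per input column
    compressed = jax.vmap(lambda s: jax.jvp(fun, (x,), (s,))[1], in_axes=1, out_axes=1)(seeds)
    rows, cols = np.nonzero(pattern)
    J = jnp.zeros(pattern.shape, dtype=compressed.dtype)
    return J.at[rows, cols].set(compressed[rows, colors[cols]])

x = jnp.array([1.0, 2.0, 3.0, 0.5])
J_sparse = sparse_jacfwd(f, x, pattern, colors)  # matches jax.jacfwd(f)(x)
```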

resolve_discretizer_config(val: Any) -> DiscretizerSpec

Validate a dict/Spec into a DiscretizerSpec instance.

Injects the default type (VectorizeDiscretizeLinearize) when the input dict omits it, preserving backwards compatibility.

Source code in openscvx/discretization/__init__.py
def resolve_discretizer_config(val: Any) -> DiscretizerSpec:
    """Validate a dict/Spec into a :class:`DiscretizerSpec` instance.

    Injects the default ``type`` (``VectorizeDiscretizeLinearize``) when the
    input dict omits it, preserving backwards compatibility.
    """
    if isinstance(val, DiscretizerSpec):
        return val
    if isinstance(val, dict) and "type" not in val:
        val = {**val, "type": DEFAULT_DISCRETIZER_TYPE}
    return DiscretizerSpec.model_validate(val)