Skip to content

discretize_linearize

DiscretizeLinearizeVectorize

Bases: Discretizer

Discretization via differentiating through the integrator for each segment individually.

Propagates the nonlinear dynamics for each trajectory segment on its own, then directly computes Jacobians of the propagated solutions (dF/dx, dF/du) via JAX forward-mode autodiff to produce discrete-time Jacobians for each node.

Supports ZOH (zero-order hold) and FOH (first-order hold) control interpolation between nodes, including independent per-control selection.

Use this integration scheme when the nonlinear dynamics are challenging (e.g. stiff/sensitive, badly scaled, or over long time horizons) and require very tight tolerances. A prototypical example is atmospheric entry of a spacecraft.

Parameters:

Name Type Description Default
dis_type Union[str, Sequence[str]]

Control hold type. "FOH" (first-order hold) or "ZOH" (zero-order hold) applies the same hold to every control. A per-control sequence (e.g. ["FOH", "ZOH", "FOH"]) sets the hold independently for each control, merged with any per-control parameterization ("FOH" / "ZOH"). Defaults to "FOH".

'FOH'
ode_solver str

Diffrax solver name. Any solver from [Diffrax](https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/) is valid. Defaults to "Tsit5".

'Tsit5'
custom_integrator bool

Use the built-in fixed-step RK45 integrator instead of Diffrax. Faster but less robust. Defaults to False.

False
diffrax_kwargs Optional[dict[str, Any]]

Preferred Diffrax keyword overrides. These map to the keyword arguments of openscvx.integrators.solve_ivp_diffrax, and unknown keys are forwarded to diffrax.diffeqsolve (e.g. stepsize_controller). Defaults to {}.

None
args Optional[dict]

Deprecated alias for diffrax_kwargs kept for backward compatibility.

None
Source code in openscvx/discretization/discretize_linearize.py
class DiscretizeLinearizeVectorize(Discretizer):
    """Discretization via differentiating through the integrator for each segment individually.

    Propagates the nonlinear dynamics for each trajectory segment on its own, then directly computes
    Jacobians of the propagated solutions (dF/dx, dF/du) via JAX forward-mode autodiff to produce
    discrete-time Jacobians for each node.

    Supports ZOH (zero-order hold) and FOH (first-order hold) control interpolation between nodes,
    including independent per-control selection.

    Use this integration scheme when the nonlinear dynamics are challenging (e.g. stiff/sensitive,
    badly scaled, or over long time horizons) and require very tight tolerances. A prototypical
    example is atmospheric entry of a spacecraft.

    Args:
        dis_type: Control hold type. ``"FOH"`` (first-order hold) or
            ``"ZOH"`` (zero-order hold) applies the same hold to every
            control.  A per-control sequence (e.g. ``["FOH", "ZOH", "FOH"]``)
            sets the hold independently for each control, merged with any
            per-control ``parameterization`` (``"FOH"`` / ``"ZOH"``).
            Defaults to ``"FOH"``.
        ode_solver: Diffrax solver name. Any solver from
            `Diffrax <https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/>`_
            is valid. Defaults to ``"Tsit5"``.
        custom_integrator: Use the built-in fixed-step RK45 integrator instead of Diffrax.
            Faster but less robust. Defaults to ``False``.
        diffrax_kwargs: Preferred Diffrax keyword overrides. These map to
            :func:`openscvx.integrators.solve_ivp_diffrax` kwargs, and unknown
            keys are forwarded to :func:`diffrax.diffeqsolve` (e.g.
            ``stepsize_controller``). Defaults to ``{}``.
        args: Deprecated alias for ``diffrax_kwargs`` kept for backward compatibility.
            If both are given, keys in ``args`` override the same keys in ``diffrax_kwargs``.
    """

    def __init__(
        self,
        dis_type: Union[str, Sequence[str]] = "FOH",
        ode_solver: str = "Tsit5",
        custom_integrator: bool = False,
        diffrax_kwargs: Optional[dict[str, Any]] = None,
        args: Optional[dict] = None,
    ):
        self.dis_type = dis_type
        self.ode_solver = ode_solver
        self.custom_integrator = custom_integrator
        # ``args`` is merged last, so the deprecated alias wins on key collisions
        # (kept as-is for backward compatibility).
        merged_kwargs: dict[str, Any] = {}
        if diffrax_kwargs is not None:
            merged_kwargs.update(dict(diffrax_kwargs))
        if args is not None:
            merged_kwargs.update(dict(args))
        self.diffrax_kwargs = merged_kwargs

    def _resolve_diffrax_kwargs(self, *, n_segments: int) -> dict[str, Any]:
        """Build the keyword arguments passed to ``solve_ivp_diffrax`` for one segment.

        Keys ``tau_0``, ``num_substeps``, ``solver_name``, ``rtol`` and ``atol`` in
        ``self.diffrax_kwargs`` are passed directly. Every other key is forwarded through
        ``extra_kwargs``, as are the contents of an optional nested ``"extra_kwargs"`` dict.
        Top-level keys win over nested ones. ``adjoint`` defaults to ``dfx.ForwardMode()``
        and may be overridden by the user.

        Args:
            n_segments: Number of segments integrated one at a time. Both the default and any
                user-supplied ``rtol``/``atol`` are divided by this.

        Returns:
            Keyword arguments for ``solve_ivp_diffrax``.
        """
        user_kwargs = dict(self.diffrax_kwargs)
        extra_kwargs: dict[str, Any] = {"adjoint": dfx.ForwardMode()}

        nested_extra = user_kwargs.pop("extra_kwargs", None)
        if nested_extra is not None:
            extra_kwargs.update(dict(nested_extra))

        kwargs: dict[str, Any] = {"solver_name": self.ode_solver}
        direct_keys = {"tau_0", "num_substeps", "solver_name", "rtol", "atol"}
        for key, value in user_kwargs.items():
            if key in direct_keys:
                kwargs[key] = value
            else:
                extra_kwargs[key] = value

        # Error budget is distributed since one segment is integrated at a time; this applies
        # to the defaults and to user-supplied tolerances alike.
        kwargs["rtol"] = kwargs.get("rtol", DEFAULT_DIFFRAX_RTOL) / n_segments
        kwargs["atol"] = kwargs.get("atol", DEFAULT_DIFFRAX_ATOL) / n_segments

        kwargs["extra_kwargs"] = extra_kwargs
        return kwargs

    def _resolve_rk45_kwargs(self, *, is_not_compiled: bool) -> dict[str, Any]:
        """Build the keyword arguments passed to ``solve_ivp_rk45``.

        Only ``tau_0``, ``num_substeps`` and ``is_not_compiled`` are taken from
        ``self.diffrax_kwargs``. All other keys are ignored for the custom integrator.

        Args:
            is_not_compiled: Default for the ``is_not_compiled`` flag. A value supplied in
                ``self.diffrax_kwargs`` takes precedence.

        Returns:
            Keyword arguments for ``solve_ivp_rk45``.
        """
        direct_keys = {"tau_0", "num_substeps", "is_not_compiled"}
        kwargs: dict[str, Any] = {"is_not_compiled": is_not_compiled}
        kwargs.update(
            {key: value for key, value in self.diffrax_kwargs.items() if key in direct_keys}
        )
        return kwargs

    def get_solver(self, dynamics: "Dynamics", settings: "Config") -> callable:
        """Create a multiple-shooting discretize-then-linearize-then-vectorize solver.

        Integrates ``dynamics.f`` and computes Jacobians of the discretized function directly
        (i.e. without using the variational equations). Outputs are vmapped for batch evaluation
        across nodes. The propagated states are returned as auxiliary data from the same
        forward-mode pass that produces the Jacobians. No separate integration is performed.

        Args:
            dynamics: System dynamics. Only ``dynamics.f`` is used; Jacobians are computed
                internally via JAX autodiff.
            settings: Problem configuration.

        Returns:
            Callable ``(x, u, params) -> (A_d, B_d, C_d, x_prop, V)``.

            ``A_d``, ``B_d``, ``C_d`` and ``x_prop`` are taken at the last entry of the
            integrator's output history, stacked over the ``N - 1`` segments. This assumes the
            integrator returns a history whose leading axis is the substep (TODO confirm).

            ``V`` is the flattened, transposed history of all four quantities.
        """

        N = settings.sim.n
        n_x = settings.sim.n_states
        n_u = settings.sim.n_controls
        u_foh_mask = getattr(settings.sim.u, "foh_mask", None)
        foh_mask = _resolve_foh_mask(self.dis_type, n_u, u_foh_mask)

        single_state_dot = dynamics.f
        # Each segment is integrated over a normalized-time span of this length.
        tau_final = 1.0 / (N - 1)

        # The integrator choice and its options depend only on this instance and ``settings``.
        # Resolve them once here instead of every time ``single_shot`` is traced.
        if self.custom_integrator:
            integrate = solve_ivp_rk45
            integrator_kwargs = self._resolve_rk45_kwargs(is_not_compiled=settings.dev.debug)
        else:
            integrate = solve_ivp_diffrax
            integrator_kwargs = self._resolve_diffrax_kwargs(n_segments=N - 1)

        def single_dxdt(
            tau: float,
            x: jnp.ndarray,
            u_cur: np.ndarray,
            u_next: np.ndarray,
            node: int,
            params: dict,
        ) -> jnp.ndarray:
            # Blend factor between u_cur and u_next, masked per control by ``foh_mask``.
            # NOTE(review): the segment is integrated over tau in [0, 1/(N-1)], assuming tau_0
            # is left at 0. So ``tau * N`` reaches N/(N-1), not 1, at the segment end, and
            # FOH slightly overshoots ``u_next``. Both discretizers in this module use this
            # expression, so it is kept for consistency. TODO confirm whether
            # ``tau * (N - 1)`` is intended.
            beta = tau * N * foh_mask

            u = u_cur + beta * (u_next - u_cur)
            return single_state_dot(x, u, node, params)

        def single_shot(
            x: jnp.ndarray,
            u_cur: np.ndarray,
            u_next: np.ndarray,
            node: int,
            params: dict,
        ):
            sol = integrate(
                single_dxdt,
                tau_final,
                x,
                args=(u_cur, u_next, node, params),
                **integrator_kwargs,
            )
            # First output is differentiated by jacfwd. The second is passed through unchanged
            # as auxiliary data, so the propagated state comes from the same integration.
            return sol, sol

        # Jacobians w.r.t. (x, u_cur, u_next) plus the primal solution as aux. The result is
        # then batched over segments, with outputs stacked along axis 1.
        discretize_then_linearize = jax.jacfwd(single_shot, argnums=(0, 1, 2), has_aux=True)
        discretize_then_linearize_then_vectorize = jax.vmap(
            discretize_then_linearize, in_axes=(0, 0, 0, 0, None), out_axes=1
        )

        nodes = jnp.arange(0, N - 1)

        def solver(x, u, params):
            (A_d, B_d, C_d), x_prop = discretize_then_linearize_then_vectorize(
                x[:-1], u[:-1], u[1:], nodes, params
            )

            # TODO: providing the histories of A, B, and C can lead to as much as a 20% slowdown.
            # If they aren't getting used, they shouldn't be here. V_multi should be replaced with
            # an output directly corresponding to the history of x_prop.
            V_multi = jnp.concatenate(
                [
                    x_prop,
                    A_d.reshape(-1, N - 1, n_x * n_x),
                    B_d.reshape(-1, N - 1, n_x * n_u),
                    C_d.reshape(-1, N - 1, n_x * n_u),
                ],
                axis=2,
            )
            i4 = V_multi.shape[2]
            V_multi = V_multi.reshape(-1, (N - 1) * i4).T

            return A_d[-1], B_d[-1], C_d[-1], x_prop[-1], V_multi

        return solver

    def citation(self) -> List[str]:
        """Return BibTeX citations for the discretize-then-linearize-then-vectorize method.

        Returns:
            List containing the BibTeX entries
        """
        return [
            r"""@phdthesis{kidger2021on,
  title={{O}n {N}eural {D}ifferential {E}quations},
  author={Patrick Kidger},
  year={2021},
  school={University of Oxford},
}""",
        ]
citation() -> List[str]

Return BibTeX citations for the discretize-then-linearize-then-vectorize method.

Returns:

Type Description
List[str]

List containing the BibTeX entries

Source code in openscvx/discretization/discretize_linearize.py
    def citation(self) -> List[str]:
        """Return BibTeX citations for the discretize-then-linearize-then-vectorize method.

        Returns:
            List containing the BibTeX entries
        """
        return [
            r"""@phdthesis{kidger2021on,
  title={{O}n {N}eural {D}ifferential {E}quations},
  author={Patrick Kidger},
  year={2021},
  school={University of Oxford},
}""",
        ]
get_solver(dynamics: Dynamics, settings: Config) -> callable

Create a multiple-shooting discretize-then-linearize-then-vectorize solver.

Integrates dynamics.f and computes Jacobians of the discretized function directly (i.e. without using the variational equations). Outputs are vmapped for batch evaluation across nodes.

Parameters:

Name Type Description Default
dynamics Dynamics

System dynamics. Only dynamics.f is used; Jacobians are computed internally via JAX autodiff.

required
settings Config

Problem configuration.

required

Returns:

Type Description
callable

Callable (x, u, params) -> (A_d, B_d, C_d, x_prop, V).

Source code in openscvx/discretization/discretize_linearize.py
def get_solver(self, dynamics: "Dynamics", settings: "Config") -> callable:
    """Create a multiple-shooting discretize-then-linearize-then-vectorize solver.

    Integrates ``dynamics.f`` and computes Jacobians of the discretized function directly
    (i.e. without using the variational equations). Outputs are vmapped for batch evaluation
    across nodes. The propagated states are returned as auxiliary data from the same
    forward-mode pass that produces the Jacobians. No separate integration is performed.

    Args:
        dynamics: System dynamics. Only ``dynamics.f`` is used; Jacobians are computed
            internally via JAX autodiff.
        settings: Problem configuration.

    Returns:
        Callable ``(x, u, params) -> (A_d, B_d, C_d, x_prop, V)``.

        ``A_d``, ``B_d``, ``C_d`` and ``x_prop`` are taken at the last entry of the
        integrator's output history, stacked over the ``N - 1`` segments. This assumes the
        integrator returns a history whose leading axis is the substep (TODO confirm).

        ``V`` is the flattened, transposed history of all four quantities.
    """

    N = settings.sim.n
    n_x = settings.sim.n_states
    n_u = settings.sim.n_controls
    u_foh_mask = getattr(settings.sim.u, "foh_mask", None)
    foh_mask = _resolve_foh_mask(self.dis_type, n_u, u_foh_mask)

    single_state_dot = dynamics.f
    # Each segment is integrated over a normalized-time span of this length.
    tau_final = 1.0 / (N - 1)

    # The integrator choice and its options depend only on this instance and ``settings``.
    # Resolve them once here instead of every time ``single_shot`` is traced.
    if self.custom_integrator:
        integrate = solve_ivp_rk45
        integrator_kwargs = self._resolve_rk45_kwargs(is_not_compiled=settings.dev.debug)
    else:
        integrate = solve_ivp_diffrax
        integrator_kwargs = self._resolve_diffrax_kwargs(n_segments=N - 1)

    def single_dxdt(
        tau: float,
        x: jnp.ndarray,
        u_cur: np.ndarray,
        u_next: np.ndarray,
        node: int,
        params: dict,
    ) -> jnp.ndarray:
        # Blend factor between u_cur and u_next, masked per control by ``foh_mask``.
        # NOTE(review): the segment is integrated over tau in [0, 1/(N-1)], assuming tau_0 is
        # left at 0. So ``tau * N`` reaches N/(N-1), not 1, at the segment end, and FOH
        # slightly overshoots ``u_next``. Kept for consistency with the other discretizer.
        # TODO confirm whether ``tau * (N - 1)`` is intended.
        beta = tau * N * foh_mask

        u = u_cur + beta * (u_next - u_cur)
        return single_state_dot(x, u, node, params)

    def single_shot(
        x: jnp.ndarray,
        u_cur: np.ndarray,
        u_next: np.ndarray,
        node: int,
        params: dict,
    ):
        sol = integrate(
            single_dxdt,
            tau_final,
            x,
            args=(u_cur, u_next, node, params),
            **integrator_kwargs,
        )
        # First output is differentiated by jacfwd. The second is passed through unchanged as
        # auxiliary data, so the propagated state comes from the same integration.
        return sol, sol

    # Jacobians w.r.t. (x, u_cur, u_next) plus the primal solution as aux. The result is then
    # batched over segments, with outputs stacked along axis 1.
    discretize_then_linearize = jax.jacfwd(single_shot, argnums=(0, 1, 2), has_aux=True)
    discretize_then_linearize_then_vectorize = jax.vmap(
        discretize_then_linearize, in_axes=(0, 0, 0, 0, None), out_axes=1
    )

    nodes = jnp.arange(0, N - 1)

    def solver(x, u, params):
        (A_d, B_d, C_d), x_prop = discretize_then_linearize_then_vectorize(
            x[:-1], u[:-1], u[1:], nodes, params
        )

        # TODO: providing the histories of A, B, and C can lead to as much as a 20% slowdown.
        # If they aren't getting used, they shouldn't be here. V_multi should be replaced with
        # an output directly corresponding to the history of x_prop.
        V_multi = jnp.concatenate(
            [
                x_prop,
                A_d.reshape(-1, N - 1, n_x * n_x),
                B_d.reshape(-1, N - 1, n_x * n_u),
                C_d.reshape(-1, N - 1, n_x * n_u),
            ],
            axis=2,
        )
        i4 = V_multi.shape[2]
        V_multi = V_multi.reshape(-1, (N - 1) * i4).T

        return A_d[-1], B_d[-1], C_d[-1], x_prop[-1], V_multi

    return solver

VectorizeDiscretizeLinearize

Bases: Discretizer

Discretization via differentiating through the integrator for all segments simultaneously.

Propagates the nonlinear dynamics over all trajectory segments at once, then directly computes Jacobians of the propagated solutions (dF/dx, dF/du) via JAX forward-mode autodiff to produce discrete-time Jacobians.

Supports ZOH (zero-order hold) and FOH (first-order hold) control interpolation between nodes, including independent per-control selection.

This integration scheme offers the best balance of speed and accuracy for most problems.

Parameters:

Name Type Description Default
dis_type Union[str, Sequence[str]]

Control hold type. "FOH" (first-order hold) or "ZOH" (zero-order hold) applies the same hold to every control. A per-control sequence (e.g. ["FOH", "ZOH", "FOH"]) sets the hold independently for each control, merged with any per-control parameterization ("FOH" / "ZOH"). Defaults to "FOH".

'FOH'
ode_solver str

Diffrax solver name. Any solver from [Diffrax](https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/) is valid. Defaults to "Tsit5".

'Tsit5'
custom_integrator bool

Use the built-in fixed-step RK45 integrator instead of Diffrax. Faster but less robust. Defaults to False.

False
diffrax_kwargs Optional[dict[str, Any]]

Preferred Diffrax keyword overrides. These map to the keyword arguments of openscvx.integrators.solve_ivp_diffrax, and unknown keys are forwarded to diffrax.diffeqsolve (e.g. stepsize_controller). Defaults to {}.

None
args Optional[dict]

Deprecated alias for diffrax_kwargs kept for backward compatibility.

None
Source code in openscvx/discretization/discretize_linearize.py
class VectorizeDiscretizeLinearize(Discretizer):
    """Discretization via differentiating through the integrator for all segments simultaneously.

    Propagates the nonlinear dynamics over all trajectory segments at once, then directly computes
    Jacobians of the propagated solutions (dF/dx, dF/du) via JAX forward-mode autodiff to produce
    discrete-time Jacobians.

    Supports ZOH (zero-order hold) and FOH (first-order hold) control interpolation between nodes,
    including independent per-control selection.

    This integration scheme offers the best balance of speed and accuracy for most problems.

    Args:
        dis_type: Control hold type. ``"FOH"`` (first-order hold) or
            ``"ZOH"`` (zero-order hold) applies the same hold to every
            control.  A per-control sequence (e.g. ``["FOH", "ZOH", "FOH"]``)
            sets the hold independently for each control, merged with any
            per-control ``parameterization`` (``"FOH"`` / ``"ZOH"``).
            Defaults to ``"FOH"``.
        ode_solver: Diffrax solver name. Any solver from
            `Diffrax <https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/>`_
            is valid. Defaults to ``"Tsit5"``.
        custom_integrator: Use the built-in fixed-step RK45 integrator instead of Diffrax.
            Faster but less robust. Defaults to ``False``.
        diffrax_kwargs: Preferred Diffrax keyword overrides. These map to
            :func:`openscvx.integrators.solve_ivp_diffrax` kwargs, and unknown
            keys are forwarded to :func:`diffrax.diffeqsolve` (e.g.
            ``stepsize_controller``). Defaults to ``{}``.
        args: Deprecated alias for ``diffrax_kwargs`` kept for backward compatibility.
            If both are given, keys in ``args`` override the same keys in ``diffrax_kwargs``.
    """

    def __init__(
        self,
        dis_type: Union[str, Sequence[str]] = "FOH",
        ode_solver: str = "Tsit5",
        custom_integrator: bool = False,
        diffrax_kwargs: Optional[dict[str, Any]] = None,
        args: Optional[dict] = None,
    ):
        self.dis_type = dis_type
        self.ode_solver = ode_solver
        self.custom_integrator = custom_integrator
        # ``args`` is merged last, so the deprecated alias wins on key collisions
        # (kept as-is for backward compatibility).
        merged_kwargs: dict[str, Any] = {}
        if diffrax_kwargs is not None:
            merged_kwargs.update(dict(diffrax_kwargs))
        if args is not None:
            merged_kwargs.update(dict(args))
        self.diffrax_kwargs = merged_kwargs

    def _resolve_diffrax_kwargs(self) -> dict[str, Any]:
        """Build the keyword arguments passed to ``solve_ivp_diffrax``.

        Keys ``tau_0``, ``num_substeps``, ``solver_name``, ``rtol`` and ``atol`` in
        ``self.diffrax_kwargs`` are passed directly. Every other key is forwarded through
        ``extra_kwargs``, as are the contents of an optional nested ``"extra_kwargs"`` dict.
        Top-level keys win over nested ones. ``adjoint`` defaults to ``dfx.ForwardMode()``.

        Unlike the per-segment discretizer, tolerances are not divided by the number of
        segments, because all segments are integrated together as one stacked system.

        Returns:
            Keyword arguments for ``solve_ivp_diffrax``.
        """
        kwargs: dict[str, Any] = {
            "solver_name": self.ode_solver,
            "rtol": DEFAULT_DIFFRAX_RTOL,
            "atol": DEFAULT_DIFFRAX_ATOL,
        }
        user_kwargs = dict(self.diffrax_kwargs)
        extra_kwargs: dict[str, Any] = {"adjoint": dfx.ForwardMode()}

        nested_extra = user_kwargs.pop("extra_kwargs", None)
        if nested_extra is not None:
            extra_kwargs.update(dict(nested_extra))

        direct_keys = {"tau_0", "num_substeps", "solver_name", "rtol", "atol"}
        for key, value in user_kwargs.items():
            if key in direct_keys:
                kwargs[key] = value
            else:
                extra_kwargs[key] = value

        kwargs["extra_kwargs"] = extra_kwargs
        return kwargs

    def _resolve_rk45_kwargs(self, *, is_not_compiled: bool) -> dict[str, Any]:
        """Build the keyword arguments passed to ``solve_ivp_rk45``.

        Only ``tau_0``, ``num_substeps`` and ``is_not_compiled`` are taken from
        ``self.diffrax_kwargs``. All other keys are ignored for the custom integrator.

        Args:
            is_not_compiled: Default for the ``is_not_compiled`` flag. A value supplied in
                ``self.diffrax_kwargs`` takes precedence.

        Returns:
            Keyword arguments for ``solve_ivp_rk45``.
        """
        direct_keys = {"tau_0", "num_substeps", "is_not_compiled"}
        kwargs: dict[str, Any] = {"is_not_compiled": is_not_compiled}
        kwargs.update(
            {key: value for key, value in self.diffrax_kwargs.items() if key in direct_keys}
        )
        return kwargs

    def get_solver(self, dynamics: "Dynamics", settings: "Config") -> callable:
        """Create a multiple-shooting vectorize-then-discretize-then-linearize solver.

        Batches ``dynamics.f`` across all nodes, integrates it, and computes Jacobians of the
        solution directly (i.e. without using the variational equations). The propagated states
        are the primal output of the same JVP pass that produces the Jacobians. No separate
        integration is performed.

        Args:
            dynamics: System dynamics. Only ``dynamics.f`` is used; Jacobians are computed
                internally via JAX autodiff.
            settings: Problem configuration.

        Returns:
            Callable ``(x, u, params) -> (A_d, B_d, C_d, x_prop, V)``.

            ``A_d``, ``B_d``, ``C_d`` and ``x_prop`` are taken at the last entry of the
            integrator's output history for each of the ``N - 1`` segments.

            ``V`` is the flattened, transposed history of all four quantities.
        """

        N = settings.sim.n
        n_x = settings.sim.n_states
        n_u = settings.sim.n_controls
        nodes = jnp.arange(0, N - 1)
        u_foh_mask = getattr(settings.sim.u, "foh_mask", None)
        foh_mask = _resolve_foh_mask(self.dis_type, n_u, u_foh_mask)

        multiple_state_dot = jax.vmap(dynamics.f, in_axes=(0, 0, 0, None))
        # Each segment is integrated over a normalized-time span of this length.
        tau_final = 1.0 / (N - 1)

        # The integrator choice and its options depend only on this instance and ``settings``.
        # Resolve them once here instead of every time ``vectorize_then_discretize`` is traced.
        if self.custom_integrator:
            integrate = solve_ivp_rk45
            integrator_kwargs = self._resolve_rk45_kwargs(is_not_compiled=settings.dev.debug)
        else:
            integrate = solve_ivp_diffrax
            integrator_kwargs = self._resolve_diffrax_kwargs()

        def multiple_dxdt(
            tau: float,
            x: jnp.ndarray,
            u_cur: np.ndarray,
            u_next: np.ndarray,
            params: dict,
        ) -> jnp.ndarray:
            # Blend factor between u_cur and u_next, masked per control by ``foh_mask``.
            # NOTE(review): the segment is integrated over tau in [0, 1/(N-1)], assuming tau_0
            # is left at 0. So ``tau * N`` reaches N/(N-1), not 1, at the segment end, and
            # FOH slightly overshoots ``u_next``. Both discretizers in this module use this
            # expression, so it is kept for consistency. TODO confirm whether
            # ``tau * (N - 1)`` is intended.
            beta = tau * N * foh_mask

            # The integrator works on a flat vector; unstack it into one state per segment.
            x = x.reshape(N - 1, n_x)
            u = u_cur + beta * (u_next - u_cur)
            F = multiple_state_dot(x, u, nodes, params)

            return F.flatten()

        def vectorize_then_discretize(
            x: jnp.ndarray,
            u_cur: np.ndarray,
            u_next: np.ndarray,
            params: dict,
        ) -> jnp.ndarray:
            sol = integrate(
                multiple_dxdt,
                tau_final,
                x.flatten(),
                args=(u_cur, u_next, params),
                **integrator_kwargs,
            )
            # (history, segment, state)
            return sol.reshape(-1, N - 1, n_x)

        i0 = 0
        i1 = n_x
        i2 = n_x + n_u
        i3 = n_x + 2 * n_u
        # One identity basis per segment. Tangent k perturbs entry k of every segment's stacked
        # [x, u_cur, u_next] simultaneously. ``dynamics.f`` is evaluated per node, so this is
        # meant to yield column k of every segment's Jacobian in one JVP.
        standard_basis = jnp.repeat(jnp.eye(i3)[None], N - 1, axis=0)

        def vectorize_then_discretize_then_linearize(
            x: jnp.ndarray,
            u_cur: np.ndarray,
            u_next: np.ndarray,
            params: dict,
        ):
            def propagate_stacked(z):
                # Split the stacked per-segment vector back into (x, u_cur, u_next).
                return vectorize_then_discretize(
                    z[:, i0:i1], z[:, i1:i2], z[:, i2:i3], params
                )

            primal = jnp.concatenate([x, u_cur, u_next], axis=-1)
            # The primal output does not depend on the tangent, so it is returned unbatched
            # (out_axes=None) and reused as ``x_prop``. Tangent outputs stack on the last axis.
            pushforward = jax.vmap(
                lambda tangent: jax.jvp(propagate_stacked, (primal,), (tangent,)),
                in_axes=-1,
                out_axes=(None, -1),
            )
            x_prop, jacobians = pushforward(standard_basis)

            A_d = jacobians[:, :, :, i0:i1]
            B_d = jacobians[:, :, :, i1:i2]
            C_d = jacobians[:, :, :, i2:i3]

            return A_d, B_d, C_d, x_prop

        def solver(x, u, params):
            A_d, B_d, C_d, x_prop = vectorize_then_discretize_then_linearize(
                x[:-1], u[:-1], u[1:], params
            )

            # TODO: providing the histories of A, B, and C can lead to as much as a 20% slowdown.
            # If they aren't getting used, they shouldn't be here. V_multi should be replaced with
            # an output directly corresponding to the history of x_prop.
            V_multi = jnp.concatenate(
                [
                    x_prop,
                    A_d.reshape(-1, N - 1, n_x * n_x),
                    B_d.reshape(-1, N - 1, n_x * n_u),
                    C_d.reshape(-1, N - 1, n_x * n_u),
                ],
                axis=2,
            )
            i4 = V_multi.shape[2]
            V_multi = V_multi.reshape(-1, (N - 1) * i4).T

            return A_d[-1], B_d[-1], C_d[-1], x_prop[-1], V_multi

        return solver

    def citation(self) -> List[str]:
        """Return BibTeX citations for the vectorize-then-discretize-then-linearize method.

        Returns:
            List containing the BibTeX entries
        """
        return [
            r"""@phdthesis{kidger2021on,
  title={{O}n {N}eural {D}ifferential {E}quations},
  author={Patrick Kidger},
  year={2021},
  school={University of Oxford},
}""",
        ]
citation() -> List[str]

Return BibTeX citations for the vectorize-then-discretize-then-linearize method.

Returns:

Type Description
List[str]

List containing the BibTeX entries

Source code in openscvx/discretization/discretize_linearize.py
    def citation(self) -> List[str]:
        """Return BibTeX citations for the vectorize-then-discretize-then-linearize method.

        Returns:
            List containing the BibTeX entries
        """
        return [
            r"""@phdthesis{kidger2021on,
  title={{O}n {N}eural {D}ifferential {E}quations},
  author={Patrick Kidger},
  year={2021},
  school={University of Oxford},
}""",
        ]
get_solver(dynamics: Dynamics, settings: Config) -> callable

Create a multiple-shooting vectorize-then-discretize-then-linearize solver.

Batches dynamics.f across all nodes, integrates it, and computes Jacobians of the solution directly (i.e. without using the variational equations).

Parameters:

Name Type Description Default
dynamics Dynamics

System dynamics. Only dynamics.f is used; Jacobians are computed internally via JAX autodiff.

required
settings Config

Problem configuration.

required

Returns:

Type Description
callable

Callable (x, u, params) -> (A_d, B_d, C_d, x_prop, V).

Source code in openscvx/discretization/discretize_linearize.py
def get_solver(self, dynamics: "Dynamics", settings: "Config") -> callable:
    """Create a multiple-shooting vectorize-then-discretize-then-linearize solver.

    Batches ``dynamics.f`` across all nodes, integrates it, and computes Jacobians of the
    solution directly (i.e. without using the variational equations). The propagated states
    are the primal output of the same JVP pass that produces the Jacobians. No separate
    integration is performed.

    Args:
        dynamics: System dynamics. Only ``dynamics.f`` is used; Jacobians are computed
            internally via JAX autodiff.
        settings: Problem configuration.

    Returns:
        Callable ``(x, u, params) -> (A_d, B_d, C_d, x_prop, V)``.

        ``A_d``, ``B_d``, ``C_d`` and ``x_prop`` are taken at the last entry of the
        integrator's output history for each of the ``N - 1`` segments.

        ``V`` is the flattened, transposed history of all four quantities.
    """

    N = settings.sim.n
    n_x = settings.sim.n_states
    n_u = settings.sim.n_controls
    nodes = jnp.arange(0, N - 1)
    u_foh_mask = getattr(settings.sim.u, "foh_mask", None)
    foh_mask = _resolve_foh_mask(self.dis_type, n_u, u_foh_mask)

    multiple_state_dot = jax.vmap(dynamics.f, in_axes=(0, 0, 0, None))
    # Each segment is integrated over a normalized-time span of this length.
    tau_final = 1.0 / (N - 1)

    # The integrator choice and its options depend only on this instance and ``settings``.
    # Resolve them once here instead of every time ``vectorize_then_discretize`` is traced.
    if self.custom_integrator:
        integrate = solve_ivp_rk45
        integrator_kwargs = self._resolve_rk45_kwargs(is_not_compiled=settings.dev.debug)
    else:
        integrate = solve_ivp_diffrax
        integrator_kwargs = self._resolve_diffrax_kwargs()

    def multiple_dxdt(
        tau: float,
        x: jnp.ndarray,
        u_cur: np.ndarray,
        u_next: np.ndarray,
        params: dict,
    ) -> jnp.ndarray:
        # Blend factor between u_cur and u_next, masked per control by ``foh_mask``.
        # NOTE(review): the segment is integrated over tau in [0, 1/(N-1)], assuming tau_0 is
        # left at 0. So ``tau * N`` reaches N/(N-1), not 1, at the segment end, and FOH
        # slightly overshoots ``u_next``. Kept for consistency with the other discretizer.
        # TODO confirm whether ``tau * (N - 1)`` is intended.
        beta = tau * N * foh_mask

        # The integrator works on a flat vector; unstack it into one state per segment.
        x = x.reshape(N - 1, n_x)
        u = u_cur + beta * (u_next - u_cur)
        F = multiple_state_dot(x, u, nodes, params)

        return F.flatten()

    def vectorize_then_discretize(
        x: jnp.ndarray,
        u_cur: np.ndarray,
        u_next: np.ndarray,
        params: dict,
    ) -> jnp.ndarray:
        sol = integrate(
            multiple_dxdt,
            tau_final,
            x.flatten(),
            args=(u_cur, u_next, params),
            **integrator_kwargs,
        )
        # (history, segment, state)
        return sol.reshape(-1, N - 1, n_x)

    i0 = 0
    i1 = n_x
    i2 = n_x + n_u
    i3 = n_x + 2 * n_u
    # One identity basis per segment. Tangent k perturbs entry k of every segment's stacked
    # [x, u_cur, u_next] simultaneously. ``dynamics.f`` is evaluated per node, so this is meant
    # to yield column k of every segment's Jacobian in one JVP.
    standard_basis = jnp.repeat(jnp.eye(i3)[None], N - 1, axis=0)

    def vectorize_then_discretize_then_linearize(
        x: jnp.ndarray,
        u_cur: np.ndarray,
        u_next: np.ndarray,
        params: dict,
    ):
        def propagate_stacked(z):
            # Split the stacked per-segment vector back into (x, u_cur, u_next).
            return vectorize_then_discretize(z[:, i0:i1], z[:, i1:i2], z[:, i2:i3], params)

        primal = jnp.concatenate([x, u_cur, u_next], axis=-1)
        # The primal output does not depend on the tangent, so it is returned unbatched
        # (out_axes=None) and reused as ``x_prop``. Tangent outputs stack on the last axis.
        pushforward = jax.vmap(
            lambda tangent: jax.jvp(propagate_stacked, (primal,), (tangent,)),
            in_axes=-1,
            out_axes=(None, -1),
        )
        x_prop, jacobians = pushforward(standard_basis)

        A_d = jacobians[:, :, :, i0:i1]
        B_d = jacobians[:, :, :, i1:i2]
        C_d = jacobians[:, :, :, i2:i3]

        return A_d, B_d, C_d, x_prop

    def solver(x, u, params):
        A_d, B_d, C_d, x_prop = vectorize_then_discretize_then_linearize(
            x[:-1], u[:-1], u[1:], params
        )

        # TODO: providing the histories of A, B, and C can lead to as much as a 20% slowdown.
        # If they aren't getting used, they shouldn't be here. V_multi should be replaced with
        # an output directly corresponding to the history of x_prop.
        V_multi = jnp.concatenate(
            [
                x_prop,
                A_d.reshape(-1, N - 1, n_x * n_x),
                B_d.reshape(-1, N - 1, n_x * n_u),
                C_d.reshape(-1, N - 1, n_x * n_u),
            ],
            axis=2,
        )
        i4 = V_multi.shape[2]
        V_multi = V_multi.reshape(-1, (N - 1) * i4).T

        return A_d[-1], B_d[-1], C_d[-1], x_prop[-1], V_multi

    return solver