Skip to content

linearize_discretize_sparse

LinearizeDiscretizeSparse

Bases: LinearizeDiscretize

Sparse variant of linearize-then-discretize.

Uses graph-coloring-based sparse Jacobian computation and a compact augmented state vector that integrates only the structurally nonzero entries of Φ (the discrete state matrix A_d), B_d, and C_d. This reduces the per-segment ODE system dimension from n_x + n_x² + 2·n_x·n_u to n_x + nnz_Ad + nnz_Bd + nnz_Cd, where nnz_M denotes the number of structurally nonzero entries of M.

Requires A_c_sparsity and B_c_sparsity boolean arrays on the Dynamics object (set automatically when using the symbolic problem interface). Falls back to the dense LinearizeDiscretize path when either sparsity pattern is unavailable or A_c_sparsity is fully dense.

Parameters:

Name Type Description Default
dis_type Union[str, Sequence[str]]

Control hold type: "FOH" (first-order hold) or "ZOH" (zero-order hold). Defaults to "FOH".

'FOH'
ode_solver str

Diffrax solver name. Defaults to "Tsit5".

'Tsit5'
diffrax_kwargs Optional[dict[str, Any]]

Diffrax keyword overrides inherited from LinearizeDiscretize. Unknown keys are forwarded to diffrax.diffeqsolve via extra_kwargs.

None
Source code in openscvx/discretization/linearize_discretize_sparse.py
class LinearizeDiscretizeSparse(LinearizeDiscretize):
    """Linearize-then-discretize variant that exploits Jacobian sparsity.

    Continuous-time Jacobians are evaluated with graph-coloring-based
    sparse differentiation, and the augmented ODE state carries only the
    structurally nonzero entries of Φ, B_d and C_d.  Per segment, the ODE
    dimension therefore drops from ``n_x + n_x² + 2·n_x·n_u`` to
    ``n_x + nnz_Ad + nnz_Bd + nnz_Cd``.

    The :class:`Dynamics` object must expose boolean ``A_c_sparsity`` and
    ``B_c_sparsity`` arrays (the symbolic problem interface sets them
    automatically).  If either one is missing, or ``A_c_sparsity`` is
    entirely ``True``, the dense :class:`LinearizeDiscretize` path is used.

    Args:
        dis_type: Control hold type, ``"FOH"`` or ``"ZOH"``.
            Defaults to ``"FOH"``.
        ode_solver: Name of the Diffrax solver. Defaults to ``"Tsit5"``.
        diffrax_kwargs: Diffrax keyword overrides inherited from
            :class:`LinearizeDiscretize`. Unrecognized keys are passed on to
            :func:`diffrax.diffeqsolve` through ``extra_kwargs``.
    """

    def get_solver(self, dynamics: "Dynamics", settings: "Config") -> callable:
        """Build the multi-shoot discretization solver, sparse when possible.

        The compact augmented-state path with graph-coloring Jacobians is
        chosen only when ``dynamics`` carries both ``A_c_sparsity`` and
        ``B_c_sparsity`` and ``A_c_sparsity`` is not all ``True``.  In every
        other case, the dense parent implementation is returned unchanged.

        Args:
            dynamics: System dynamics, optionally annotated with sparsity
                patterns.
            settings: Problem configuration.

        Returns:
            Callable ``(x, u, params) -> (A_d, B_d, C_d, x_prop, V)``.
        """
        # These local imports execute before the fallback check below, so
        # both the sparse and the dense path trigger them.
        from openscvx.symbolic.sparsity import discrete_sparsity

        from .sparse_utils import make_sparse_jacobian_fns

        a_pattern = getattr(dynamics, "A_c_sparsity", None)
        b_pattern = getattr(dynamics, "B_c_sparsity", None)

        # Defer to the dense implementation unless both patterns exist and
        # A_c has at least one structural zero.
        if a_pattern is None or b_pattern is None or a_pattern.all():
            return super().get_solver(dynamics, settings)

        # Batch dynamics over the first three arguments; the fourth is
        # broadcast unchanged to every batch element.
        state_dot = jax.vmap(dynamics.f, in_axes=(0, 0, 0, None))
        n_states = settings.sim.n_states
        n_controls = settings.sim.n_controls

        jac_x, jac_u = make_sparse_jacobian_fns(
            dynamics.f, a_pattern, b_pattern, n_states, n_controls
        )

        per_control_foh = getattr(settings.sim.u, "foh_mask", None)
        foh_mask = _resolve_foh_mask(self.dis_type, n_controls, per_control_foh)
        ad_pattern, bd_pattern, cd_pattern = discrete_sparsity(
            a_pattern, b_pattern, foh_mask
        )

        # Flatten into consecutive (rows, cols, nnz) triples, ordered
        # A_d, B_d, C_d; nnz stays a static Python int.
        layout = []
        for pattern in (ad_pattern, bd_pattern, cd_pattern):
            rows, cols = np.where(pattern)
            layout.extend((jnp.array(rows), jnp.array(cols), len(rows)))
        sparse_layout = tuple(layout)

        return lambda x, u, params: _calculate_discretization_sparse(
            x=x,
            u=u,
            params=params,
            state_dot=state_dot,
            A=jac_x,
            B=jac_u,
            settings=settings,
            discretizer=self,
            sparse_layout=sparse_layout,
        )
get_solver(dynamics: Dynamics, settings: Config) -> callable

Create a sparse multi-shoot discretization solver.

When dynamics.A_c_sparsity and dynamics.B_c_sparsity are both available and A_c_sparsity is not fully dense, builds an integration path over the compact augmented state vector V using graph-coloring Jacobians. Otherwise, delegates to the dense parent implementation.

Parameters:

Name Type Description Default
dynamics Dynamics

System dynamics with optional sparsity annotations.

required
settings Config

Problem configuration.

required

Returns:

Type Description
callable

Callable (x, u, params) -> (A_d, B_d, C_d, x_prop, V).

Source code in openscvx/discretization/linearize_discretize_sparse.py
def get_solver(self, dynamics: "Dynamics", settings: "Config") -> callable:
    """Build the multi-shoot discretization solver, sparse when possible.

    The compact augmented-state path with graph-coloring Jacobians is
    chosen only when ``dynamics`` carries both ``A_c_sparsity`` and
    ``B_c_sparsity`` and ``A_c_sparsity`` is not all ``True``.  In every
    other case, the dense parent implementation is returned unchanged.

    Args:
        dynamics: System dynamics, optionally annotated with sparsity
            patterns.
        settings: Problem configuration.

    Returns:
        Callable ``(x, u, params) -> (A_d, B_d, C_d, x_prop, V)``.
    """
    # These local imports execute before the fallback check below, so
    # both the sparse and the dense path trigger them.
    from openscvx.symbolic.sparsity import discrete_sparsity

    from .sparse_utils import make_sparse_jacobian_fns

    a_pattern = getattr(dynamics, "A_c_sparsity", None)
    b_pattern = getattr(dynamics, "B_c_sparsity", None)

    # Defer to the dense implementation unless both patterns exist and
    # A_c has at least one structural zero.
    if a_pattern is None or b_pattern is None or a_pattern.all():
        return super().get_solver(dynamics, settings)

    # Batch dynamics over the first three arguments; the fourth is
    # broadcast unchanged to every batch element.
    state_dot = jax.vmap(dynamics.f, in_axes=(0, 0, 0, None))
    n_states = settings.sim.n_states
    n_controls = settings.sim.n_controls

    jac_x, jac_u = make_sparse_jacobian_fns(
        dynamics.f, a_pattern, b_pattern, n_states, n_controls
    )

    per_control_foh = getattr(settings.sim.u, "foh_mask", None)
    foh_mask = _resolve_foh_mask(self.dis_type, n_controls, per_control_foh)
    ad_pattern, bd_pattern, cd_pattern = discrete_sparsity(
        a_pattern, b_pattern, foh_mask
    )

    # Flatten into consecutive (rows, cols, nnz) triples, ordered
    # A_d, B_d, C_d; nnz stays a static Python int.
    layout = []
    for pattern in (ad_pattern, bd_pattern, cd_pattern):
        rows, cols = np.where(pattern)
        layout.extend((jnp.array(rows), jnp.array(cols), len(rows)))
    sparse_layout = tuple(layout)

    return lambda x, u, params: _calculate_discretization_sparse(
        x=x,
        u=u,
        params=params,
        state_dot=state_dot,
        A=jac_x,
        B=jac_u,
        settings=settings,
        discretizer=self,
        sparse_layout=sparse_layout,
    )