sparse_jacobian

Sparse Jacobian computation via graph coloring and jax.experimental.sparse.

This module provides utilities for computing Jacobians efficiently when the sparsity pattern is known at compile time. Instead of computing all n_x (or n_u) columns via jax.jacfwd, it uses column graph coloring to group columns that never share a nonzero row, so that one directional derivative (JVP) per color is enough, and then reconstructs the sparse Jacobian from the compressed results.

The key entry point is make_sparse_jacobian_fns, which returns vmapped Jacobian callables (matching the dense jax.vmap(jax.jacfwd(...)) signature) that internally use the sparse path.
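The compress-and-reconstruct idea can be sketched with plain NumPy. This is an illustration under stated assumptions, not the module's implementation: the Jacobian is a made-up tridiagonal matrix, the coloring is hard-coded, and the matrix product `J @ seeds` stands in for the JVPs that the module would evaluate with JAX.

```python
import numpy as np

n = 6
rng = np.random.default_rng(0)

# Made-up tridiagonal Jacobian with nonzero values (stand-in for df/dx).
J = np.zeros((n, n))
for i in range(n):
    for j in range(max(0, i - 1), min(n, i + 2)):
        J[i, j] = rng.uniform(1.0, 2.0)
pattern = J != 0

# In a tridiagonal pattern, columns j and j+3 never share a nonzero row,
# so assigning colors cyclically with period 3 is a valid coloring.
colors = np.arange(n) % 3
n_colors = int(colors.max()) + 1

# Seed matrix: column k is the sum of the unit vectors of all columns with color k.
seeds = np.zeros((n, n_colors))
seeds[np.arange(n), colors] = 1.0

# Each column of `compressed` is one directional derivative J @ seeds[:, k].
compressed = J @ seeds  # shape (n, n_colors) instead of (n, n)

# Reconstruct: the nonzero (r, c) is found in row r of the column for colors[c].
rows, cols = np.nonzero(pattern)
J_rec = np.zeros_like(J)
J_rec[rows, cols] = compressed[rows, colors[cols]]

print(np.allclose(J_rec, J))
```

Here three directional derivatives replace six. The saving grows with the number of columns as long as the number of colors stays fixed.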

color_columns(pattern: np.ndarray) -> np.ndarray

Greedy column coloring of a boolean sparsity pattern.

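Two columns i and j can share a color when they have no row in common where both are nonzero. This function assigns colors with a greedy (first-fit) algorithm over columns ordered by decreasing number of nonzeros: each column receives the lowest color not already used by a conflicting column. The result is always a valid coloring, but it is not guaranteed to use the fewest possible colors.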

Parameters:

pattern (ndarray, required): Boolean (m, n) array where True indicates a structural nonzero.

Returns:

ndarray: Integer array of length n mapping each column to its color (0-indexed). The number of distinct colors is max(colors) + 1.

Source code in openscvx/discretization/sparse_utils/sparse_jacobian.py
def color_columns(pattern: np.ndarray) -> np.ndarray:
    """Greedy column coloring of a boolean sparsity pattern.

    Two columns ``i`` and ``j`` can share a color when they have no row
    in common where both are nonzero.  This function assigns colors with
    a greedy (first-fit) algorithm over columns ordered by decreasing
    number of nonzeros.  The coloring is valid but not guaranteed to use
    the fewest possible colors.

    Args:
        pattern: Boolean ``(m, n)`` array where ``True`` indicates a
            structural nonzero.

    Returns:
        Integer array of length ``n`` mapping each column to its color
        (0-indexed).  The number of distinct colors is ``max(colors) + 1``.
    """
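    n = pattern.shape[1]
    # -1 marks a column that has not been colored yet.
    colors = -np.ones(n, dtype=np.intp)

    # Visit columns in order of decreasing nonzero count.
    col_order = np.argsort(-pattern.sum(axis=0))

    for col in col_order:
        row_set = set(np.where(pattern[:, col])[0])
        # Collect the colors of already-colored columns that share a row with `col`.
        forbidden = set()
        for prev_col in range(n):
            if colors[prev_col] < 0:
                continue
            prev_rows = set(np.where(pattern[:, prev_col])[0])
            if row_set & prev_rows:
                forbidden.add(colors[prev_col])
        # First fit: take the smallest color that is not forbidden.
        c = 0
        while c in forbidden:
            c += 1
        colors[col] = c

    return colors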
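The sharing rule can be checked directly: in a valid coloring, no row has two nonzeros in columns of the same color. The sketch below is self-contained, so it uses a compact first-fit variant rather than importing the library. The names `first_fit_coloring` and `is_valid_coloring` are hypothetical helpers, and the conflict matrix built from `pattern.T @ pattern` is an assumption of this example rather than what the listing above does.

```python
import numpy as np


def first_fit_coloring(pattern):
    """Hypothetical compact greedy coloring (illustration only)."""
    p = pattern.astype(np.int64)
    # conflict[i, j] is True when columns i and j share a nonzero row.
    conflict = (p.T @ p) > 0
    n = pattern.shape[1]
    colors = -np.ones(n, dtype=np.intp)
    order = np.argsort(-pattern.sum(axis=0), kind="stable")
    for col in order:
        # Only already-colored conflicting columns restrict the choice.
        neighbours = conflict[col] & (colors >= 0)
        forbidden = set(colors[neighbours].tolist())
        c = 0
        while c in forbidden:
            c += 1
        colors[col] = c
    return colors


def is_valid_coloring(pattern, colors):
    """True when no row has two nonzeros in columns of the same color."""
    for k in range(int(colors.max()) + 1):
        group = pattern[:, colors == k]
        if (group.sum(axis=1) > 1).any():
            return False
    return True


n = 6
# Tridiagonal pattern: compresses well.
tri = np.abs(np.subtract.outer(np.arange(n), np.arange(n))) <= 1
tri_colors = first_fit_coloring(tri)

# Diagonal plus one fully dense row: every pair of columns conflicts.
dense_row = np.eye(n, dtype=bool)
dense_row[0, :] = True
dense_row_colors = first_fit_coloring(dense_row)

print(tri_colors.max() + 1, is_valid_coloring(tri, tri_colors))
print(dense_row_colors.max() + 1, is_valid_coloring(dense_row, dense_row_colors))
```

The second pattern shows the limit of column coloring: a single dense row makes every pair of columns conflict, so each column needs its own color and the compressed computation saves nothing over the dense one.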

make_sparse_jacobian_fns(f: Callable, A_c_pattern: Optional[np.ndarray], B_c_pattern: Optional[np.ndarray], n_x: int, n_u: int) -> Tuple[Callable, Callable]

Create vmapped sparse Jacobian functions for df/dx and df/du.

If a sparsity pattern is fully dense or None, the function falls back to the standard jax.jacfwd path for that Jacobian.

Parameters:

f (Callable, required): Dynamics function f(x, u, node, params) -> x_dot.

A_c_pattern (Optional[ndarray], required): Boolean (n_x, n_x) sparsity of df/dx, or None.

B_c_pattern (Optional[ndarray], required): Boolean (n_x, n_u) sparsity of df/du, or None.

n_x (int, required): Number of state dimensions.

n_u (int, required): Number of control dimensions.

Returns:

Tuple[Callable, Callable]: (A_vmapped, B_vmapped), vmapped Jacobian callables with signature (x_batch, u_batch, nodes, params) -> J_batch.

Source code in openscvx/discretization/sparse_utils/sparse_jacobian.py
def make_sparse_jacobian_fns(
    f: Callable,
    A_c_pattern: Optional[np.ndarray],
    B_c_pattern: Optional[np.ndarray],
    n_x: int,
    n_u: int,
) -> Tuple[Callable, Callable]:
    """Create vmapped sparse Jacobian functions for df/dx and df/du.

    If a sparsity pattern is fully dense or ``None``, falls back to the
    standard ``jax.jacfwd`` path for that Jacobian.

    Args:
        f: Dynamics function ``f(x, u, node, params) -> x_dot``.
        A_c_pattern: Boolean ``(n_x, n_x)`` sparsity of df/dx, or ``None``.
        B_c_pattern: Boolean ``(n_x, n_u)`` sparsity of df/du, or ``None``.
        n_x: Number of state dimensions.
        n_u: Number of control dimensions.

    Returns:
        ``(A_vmapped, B_vmapped)`` — vmapped Jacobian callables with
        signature ``(x_batch, u_batch, nodes, params) -> J_batch``.
    """
    # --- df/dx ---
    if A_c_pattern is not None and not A_c_pattern.all():
        seeds_A, nc_A, nz_r_A, nz_c_A = _build_coloring_data(A_c_pattern)
        A_fn = _sparse_jacobian_fn(f, 0, seeds_A, nc_A, nz_r_A, nz_c_A, n_x, n_x)
    else:
        A_fn = jax.jacfwd(f, argnums=0)
    A_vmapped = jax.vmap(A_fn, in_axes=(0, 0, 0, None))

    # --- df/du ---
    if B_c_pattern is not None and not B_c_pattern.all():
        seeds_B, nc_B, nz_r_B, nz_c_B = _build_coloring_data(B_c_pattern)
        B_fn = _sparse_jacobian_fn(f, 1, seeds_B, nc_B, nz_r_B, nz_c_B, n_x, n_u)
    else:
        B_fn = jax.jacfwd(f, argnums=1)
    B_vmapped = jax.vmap(B_fn, in_axes=(0, 0, 0, None))

    return A_vmapped, B_vmapped
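The private helpers _build_coloring_data and _sparse_jacobian_fn are not shown on this page. As a rough sketch of what a colored-JVP Jacobian with the same call signature could look like, the example below builds one by hand and compares it with the dense fallback path from the listing. Everything here is an assumption for illustration: the toy dynamics `f`, the hard-coded two-color pattern, and the helper name `sparse_jac_x`. It is not the library's implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np


def f(x, u, node, params):
    """Toy dynamics (hypothetical): x_dot[i] depends on x[i], x[i-1] and u[0]."""
    shifted = jnp.concatenate([jnp.zeros(1, dtype=x.dtype), x[:-1]])
    return params["k"] * jnp.sin(x) + shifted**2 + u[0]


n_x = 4
# df/dx is lower bidiagonal: diagonal plus first subdiagonal.
pattern = np.eye(n_x, dtype=bool) | np.eye(n_x, k=-1, dtype=bool)
# Columns j and j+2 never share a row here, so two colors are enough.
colors = np.arange(n_x) % 2
n_colors = int(colors.max()) + 1
seeds = np.zeros((n_colors, n_x))
seeds[colors, np.arange(n_x)] = 1.0
rows, cols = np.nonzero(pattern)


def sparse_jac_x(x, u, node, params):
    """Hypothetical single-sample df/dx from one JVP per color."""
    fx = lambda x_: f(x_, u, node, params)
    jvp_one = lambda s: jax.jvp(fx, (x,), (s,))[1]
    # compressed[k] is J @ seeds[k]; shape (n_colors, n_x).
    compressed = jax.vmap(jvp_one)(jnp.asarray(seeds, dtype=x.dtype))
    vals = compressed[colors[cols], rows]
    return jnp.zeros((n_x, n_x), dtype=x.dtype).at[rows, cols].set(vals)


# Same batching convention as the listing: params is shared across the batch.
A_sparse = jax.vmap(sparse_jac_x, in_axes=(0, 0, 0, None))
A_dense = jax.vmap(jax.jacfwd(f, argnums=0), in_axes=(0, 0, 0, None))

N = 3
x_batch = jnp.linspace(0.1, 1.2, N * n_x).reshape(N, n_x)
u_batch = jnp.ones((N, 1))
nodes = jnp.arange(N)
params = {"k": 2.0}

J_sparse = A_sparse(x_batch, u_batch, nodes, params)
J_dense = A_dense(x_batch, u_batch, nodes, params)
print(J_sparse.shape, bool(jnp.allclose(J_sparse, J_dense, atol=1e-6)))
```

Both callables accept (x_batch, u_batch, nodes, params) and return a batch of (n_x, n_x) matrices, so a sparse version of this kind can be swapped in for the dense one without changing the calling code.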