
sparse_utils

Sparse linear algebra helpers built on jax.experimental.sparse.

color_columns(pattern: np.ndarray) -> np.ndarray

Greedy column coloring of a boolean sparsity pattern.

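Two columns i and j can share a color when there is no row in which both are nonzero. Colors are assigned by a greedy (first-fit) heuristic: columns are visited in order of decreasing nonzero count, and each column gets the smallest color not already used by a conflicting column. The heuristic usually needs few colors, but it does not guarantee the minimum number.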
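The standard compression argument below shows why the coloring rule matters. In this sketch, P is the 0/1 pattern, J is a Jacobian with that pattern, and c(j) is the color of column j. The symbols S and B are notation introduced only for this sketch.

$$
\text{columns } i \ne j \text{ conflict} \iff \exists\, r:\ P_{ri} = P_{rj} = 1 \iff (P^\top P)_{ij} > 0
$$

$$
S_{jk} = \begin{cases} 1 & c(j) = k \\ 0 & \text{otherwise} \end{cases}, \qquad
B = J S, \qquad
B_{ik} = \sum_{j:\, c(j) = k} J_{ij}
$$

Take a structural nonzero (i, j). Every other column j' with c(j') = c(j) has P_{ij'} = 0 and therefore J_{ij'} = 0, so

$$
J_{ij} = B_{i,\,c(j)} .
$$

B has p = max(colors) + 1 columns. Obtaining it therefore costs p Jacobian-vector products instead of n.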

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| pattern | ndarray | Boolean (m, n) array where True indicates a structural nonzero. | required |

Returns:

| Type | Description |
|---|---|
| ndarray | Integer array of length n mapping each column to its color (0-indexed). The number of distinct colors is max(colors) + 1. |

Source code in openscvx/discretization/sparse_utils/sparse_jacobian.py
def color_columns(pattern: np.ndarray) -> np.ndarray:
    """Greedy column coloring of a boolean sparsity pattern.

    Two columns ``i`` and ``j`` can share a color when they have no row
    in common where both are nonzero.  Colors are assigned by a greedy
    (first-fit) heuristic over columns ordered by decreasing number of
    nonzeros; it usually needs few colors but does not guarantee the
    minimum.

    Args:
        pattern: Boolean ``(m, n)`` array where ``True`` indicates a
            structural nonzero.

    Returns:
        Integer array of length ``n`` mapping each column to its color
        (0-indexed).  The number of distinct colors is ``max(colors) + 1``.
    """
    m, n = pattern.shape
    colors = -np.ones(n, dtype=np.intp)

    col_order = np.argsort(-pattern.sum(axis=0))

    for col in col_order:
        row_set = set(np.where(pattern[:, col])[0])
        forbidden = set()
        for prev_col in range(n):
            if colors[prev_col] < 0:
                continue
            prev_rows = set(np.where(pattern[:, prev_col])[0])
            if row_set & prev_rows:
                forbidden.add(colors[prev_col])
        c = 0
        while c in forbidden:
            c += 1
        colors[col] = c

    return colors
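The following minimal NumPy sketch shows how such a coloring compresses a Jacobian, assuming a seed matrix with one column per color. It relies on two hypothetical helpers, neither of which is part of sparse_utils:

- `greedy_color_columns` restates the same first-fit rule in vectorized form, reading conflicts from PᵀP.
- `compress_and_recover` rebuilds the Jacobian from the compressed product.

The module's own seed construction (in the private `_build_coloring_data`) is not shown on this page and may differ.

```python
import numpy as np


def greedy_color_columns(pattern: np.ndarray) -> np.ndarray:
    """First-fit column coloring (hypothetical vectorized variant)."""
    pattern = np.asarray(pattern, dtype=bool)
    n = pattern.shape[1]
    # conflict[i, j] is True when columns i and j share a nonzero row.
    p_int = pattern.astype(np.int64)
    conflict = (p_int.T @ p_int) > 0
    colors = -np.ones(n, dtype=np.intp)
    # Densest columns first, matching the ordering used by color_columns.
    for col in np.argsort(-pattern.sum(axis=0)):
        neighbors = conflict[col] & (colors >= 0)  # already-colored conflicts
        forbidden = set(colors[neighbors].tolist())
        c = 0
        while c in forbidden:
            c += 1
        colors[col] = c
    return colors


def compress_and_recover(J: np.ndarray, pattern: np.ndarray, colors: np.ndarray) -> np.ndarray:
    """Rebuild J from the compressed product J @ S (hypothetical helper)."""
    n = pattern.shape[1]
    n_colors = int(colors.max()) + 1
    # Seed matrix: S[j, c] = 1 exactly when column j has color c.
    S = np.zeros((n, n_colors))
    S[np.arange(n), colors] = 1.0
    compressed = J @ S  # (m, n_colors): one product per color, not per column
    rows, cols = np.where(pattern)
    recovered = np.zeros_like(J)
    # Same-colored columns never share a row, so each entry is read back exactly.
    recovered[rows, cols] = compressed[rows, colors[cols]]
    return recovered


# Tridiagonal 5x5 pattern: columns up to two apart share a row.
pattern = np.eye(5, dtype=bool) | np.eye(5, k=1, dtype=bool) | np.eye(5, k=-1, dtype=bool)
colors = greedy_color_columns(pattern)
rng = np.random.default_rng(0)
J = np.where(pattern, rng.standard_normal(pattern.shape), 0.0)
recovered = compress_and_recover(J, pattern, colors)
print(int(colors.max()) + 1)  # 3 colors instead of 5 columns
print(np.allclose(recovered, J))  # True
```

Precomputing the boolean conflict matrix avoids rebuilding a row set for every column pair, as the loop in color_columns does. The cost is O(n²) memory, which is modest for the state and control sizes typical of dynamics Jacobians.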

make_sparse_jacobian_fns(f: Callable, A_c_pattern: Optional[np.ndarray], B_c_pattern: Optional[np.ndarray], n_x: int, n_u: int) -> Tuple[Callable, Callable]

Create vmapped sparse Jacobian functions for df/dx and df/du.

If a sparsity pattern is None or fully dense, that Jacobian falls back to the standard jax.jacfwd path instead of using column coloring.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| f | Callable | Dynamics function f(x, u, node, params) -> x_dot. | required |
| A_c_pattern | Optional[ndarray] | Boolean (n_x, n_x) sparsity of df/dx, or None. | required |
| B_c_pattern | Optional[ndarray] | Boolean (n_x, n_u) sparsity of df/du, or None. | required |
| n_x | int | Number of state dimensions. | required |
| n_u | int | Number of control dimensions. | required |

Returns:

| Type | Description |
|---|---|
| Tuple[Callable, Callable] | (A_vmapped, B_vmapped): vmapped Jacobian callables with signature (x_batch, u_batch, nodes, params) -> J_batch. |

Source code in openscvx/discretization/sparse_utils/sparse_jacobian.py
def make_sparse_jacobian_fns(
    f: Callable,
    A_c_pattern: Optional[np.ndarray],
    B_c_pattern: Optional[np.ndarray],
    n_x: int,
    n_u: int,
) -> Tuple[Callable, Callable]:
    """Create vmapped sparse Jacobian functions for df/dx and df/du.

    If a sparsity pattern is ``None`` or fully dense, that Jacobian
    falls back to the standard ``jax.jacfwd`` path.

    Args:
        f: Dynamics function ``f(x, u, node, params) -> x_dot``.
        A_c_pattern: Boolean ``(n_x, n_x)`` sparsity of df/dx, or ``None``.
        B_c_pattern: Boolean ``(n_x, n_u)`` sparsity of df/du, or ``None``.
        n_x: Number of state dimensions.
        n_u: Number of control dimensions.

    Returns:
        ``(A_vmapped, B_vmapped)`` — vmapped Jacobian callables with
        signature ``(x_batch, u_batch, nodes, params) -> J_batch``.
    """
    # --- df/dx ---
    if A_c_pattern is not None and not A_c_pattern.all():
        seeds_A, nc_A, nz_r_A, nz_c_A = _build_coloring_data(A_c_pattern)
        A_fn = _sparse_jacobian_fn(f, 0, seeds_A, nc_A, nz_r_A, nz_c_A, n_x, n_x)
    else:
        A_fn = jax.jacfwd(f, argnums=0)
    A_vmapped = jax.vmap(A_fn, in_axes=(0, 0, 0, None))

    # --- df/du ---
    if B_c_pattern is not None and not B_c_pattern.all():
        seeds_B, nc_B, nz_r_B, nz_c_B = _build_coloring_data(B_c_pattern)
        B_fn = _sparse_jacobian_fn(f, 1, seeds_B, nc_B, nz_r_B, nz_c_B, n_x, n_u)
    else:
        B_fn = jax.jacfwd(f, argnums=1)
    B_vmapped = jax.vmap(B_fn, in_axes=(0, 0, 0, None))

    return A_vmapped, B_vmapped
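The private helpers `_build_coloring_data` and `_sparse_jacobian_fn` are not shown on this page. The JAX sketch below therefore illustrates the general technique of compressed forward-mode differentiation, not their actual code. It computes one `jax.jvp` per color instead of one per column, scatters the structural nonzeros back into a dense Jacobian, and batches over nodes with `jax.vmap` as `make_sparse_jacobian_fns` does. The function `compressed_jacfwd`, the toy dynamics `f`, and the hand-chosen coloring are all hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def compressed_jacfwd(f, x, colors, rows, cols, m):
    """Jacobian of f at x from one JVP per color (hypothetical sketch)."""
    n = x.shape[0]
    n_colors = int(colors.max()) + 1
    # One seed per color: seeds[c, j] = 1 when column j has color c.
    seeds = jnp.zeros((n_colors, n), dtype=x.dtype).at[colors, jnp.arange(n)].set(1.0)

    def push(seed):
        # Directional derivative J @ seed via forward mode.
        return jax.jvp(f, (x,), (seed,))[1]

    compressed = jax.vmap(push)(seeds).T  # (m, n_colors)
    # Read each structural nonzero (i, j) from the column of color(j).
    values = compressed[rows, colors[cols]]
    return jnp.zeros((m, n), dtype=x.dtype).at[rows, cols].set(values)


def f(x):
    # Tridiagonal coupling: f_i depends only on x_{i-1}, x_i and x_{i+1}.
    left = jnp.pad(x[:-1] ** 2, (1, 0))   # x_{i-1}^2, zero at i = 0
    right = jnp.pad(x[1:] ** 2, (0, 1))   # x_{i+1}^2, zero at i = n - 1
    return jnp.sin(x) + 0.5 * (left + right)


n = 5
pattern = np.eye(n, dtype=bool) | np.eye(n, k=1, dtype=bool) | np.eye(n, k=-1, dtype=bool)
rows, cols = np.where(pattern)
colors = np.arange(n) % 3  # valid: same-colored columns are at least 3 apart
x_batch = jnp.stack([jnp.linspace(0.1, 0.5, n), jnp.linspace(-1.0, 1.0, n)])

# Batch over nodes with jax.vmap, then compare against the dense jacfwd path.
J_batch = jax.vmap(lambda xi: compressed_jacfwd(f, xi, colors, rows, cols, n))(x_batch)
J_dense = jax.vmap(jax.jacfwd(f))(x_batch)
print(J_batch.shape)  # (2, 5, 5)
print(jnp.allclose(J_batch, J_dense, atol=1e-6))  # True
```

To apply this to the four-argument dynamics f(x, u, node, params), differentiate a closure such as `lambda x_: f(x_, u, node, params)` for df/dx, or the analogous closure over u for df/du.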

precompute_sparse_indices(pattern: np.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, int]

Pre-compute BCOO index arrays from a boolean sparsity pattern.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| pattern | ndarray | Boolean (m, n) array of structural nonzeros. | required |

Returns:

| Type | Description |
|---|---|
| Tuple[ndarray, ndarray, int] | (nz_rows, nz_cols, nnz): JAX integer arrays and the nonzero count. |

Source code in openscvx/discretization/sparse_utils/bcoo_helpers.py
def precompute_sparse_indices(
    pattern: np.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, int]:
    """Pre-compute BCOO index arrays from a boolean sparsity pattern.

    Args:
        pattern: Boolean ``(m, n)`` array of structural nonzeros.

    Returns:
        ``(nz_rows, nz_cols, nnz)`` — JAX integer arrays and the nonzero count.
    """
    rows, cols = np.where(pattern)
    return jnp.array(rows), jnp.array(cols), len(rows)
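A small usage sketch follows. It repeats the two lines of precompute_sparse_indices inline so the snippet runs standalone. np.where on a boolean array returns indices in row-major (C) order. Stacking the two index arrays gives the (nnz, 2) layout that jax.experimental.sparse.BCOO expects. The pattern and dense values are illustrative only.

```python
import jax.numpy as jnp
import numpy as np
from jax.experimental.sparse import BCOO

pattern = np.array([[True, False, True],
                    [False, True, False]])
# Same computation as precompute_sparse_indices.
rows, cols = np.where(pattern)
nz_rows, nz_cols, nnz = jnp.array(rows), jnp.array(cols), len(rows)
print(nz_rows.tolist(), nz_cols.tolist(), nnz)  # [0, 0, 1] [0, 2, 1] 3

dense = jnp.arange(6.0).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
# Gather values at the structural nonzeros and pair them with (nnz, 2) indices.
indices = jnp.stack([nz_rows, nz_cols], axis=-1)
sp = BCOO((dense[nz_rows, nz_cols], indices), shape=(2, 3))
# Only entries inside the pattern survive: [[0, 0, 2], [0, 4, 0]].
print(sp.todense())
```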

sparse_matmul_batched(dense_jac: jnp.ndarray, rhs: jnp.ndarray, nz_rows: jnp.ndarray, nz_cols: jnp.ndarray, m: int, n: int) -> jnp.ndarray

Sparse-dense batched matmul: sparse(dense_jac) @ rhs.

Extracts nonzero values from the dense batch using pre-computed indices, constructs a BCOO matrix per batch item, and multiplies.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dense_jac | ndarray | (batch, m, n) dense arrays (only positions at nz_rows, nz_cols are used). | required |
| rhs | ndarray | (batch, n, k) right-hand-side matrix. | required |
| nz_rows | ndarray | Row indices of structural nonzeros. | required |
| nz_cols | ndarray | Column indices of structural nonzeros. | required |
| m | int | Number of rows of the sparse matrix. | required |
| n | int | Number of columns of the sparse matrix. | required |

Returns:

| Type | Description |
|---|---|
| ndarray | (batch, m, k) result of the sparse matmul. |

Source code in openscvx/discretization/sparse_utils/bcoo_helpers.py
def sparse_matmul_batched(
    dense_jac: jnp.ndarray,
    rhs: jnp.ndarray,
    nz_rows: jnp.ndarray,
    nz_cols: jnp.ndarray,
    m: int,
    n: int,
) -> jnp.ndarray:
    """Sparse-dense batched matmul: ``sparse(dense_jac) @ rhs``.

    Extracts nonzero values from the dense batch using pre-computed indices,
    constructs a BCOO matrix per batch item, and multiplies.

    Args:
        dense_jac: ``(batch, m, n)`` dense arrays (only positions at
            ``nz_rows, nz_cols`` are used).
        rhs: ``(batch, n, k)`` right-hand-side matrix.
        nz_rows: Row indices of structural nonzeros.
        nz_cols: Column indices of structural nonzeros.
        m: Number of rows of the sparse matrix.
        n: Number of columns of the sparse matrix.

    Returns:
        ``(batch, m, k)`` result of the sparse matmul.
    """
    data = dense_jac[:, nz_rows, nz_cols]  # (batch, nnz)
    indices = jnp.stack([nz_rows, nz_cols], axis=-1)  # (nnz, 2)

    def _single_matmul(data_i, rhs_i):
        sp = BCOO((data_i, indices), shape=(m, n))
        return sp @ rhs_i

    return jax.vmap(_single_matmul)(data, rhs)
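The check below shows that entries of dense_jac outside the pattern are ignored. It compares the batched sparse product against an ordinary batched matmul on the masked dense batch. The function body is repeated so the snippet runs standalone. The bidiagonal pattern and random inputs are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.sparse import BCOO


def sparse_matmul_batched(dense_jac, rhs, nz_rows, nz_cols, m, n):
    # Same logic as bcoo_helpers.py, repeated so the snippet runs standalone.
    data = dense_jac[:, nz_rows, nz_cols]  # (batch, nnz)
    indices = jnp.stack([nz_rows, nz_cols], axis=-1)  # (nnz, 2)

    def _single_matmul(data_i, rhs_i):
        return BCOO((data_i, indices), shape=(m, n)) @ rhs_i

    return jax.vmap(_single_matmul)(data, rhs)


pattern = np.eye(4, dtype=bool) | np.eye(4, k=1, dtype=bool)  # upper bidiagonal
rows, cols = np.where(pattern)
nz_rows, nz_cols = jnp.array(rows), jnp.array(cols)

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
dense_jac = jax.random.normal(key_a, (3, 4, 4))  # nonzero everywhere, on purpose
rhs = jax.random.normal(key_b, (3, 4, 2))

out = sparse_matmul_batched(dense_jac, rhs, nz_rows, nz_cols, 4, 4)
# Reference: mask the dense batch to the pattern, then do a plain batched matmul.
expected = jnp.einsum("bij,bjk->bik", jnp.where(pattern, dense_jac, 0.0), rhs)
print(out.shape)  # (3, 4, 2)
print(jnp.allclose(out, expected, atol=1e-5))  # True
```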