
expert

Expert-mode features for advanced users.

This module contains features for expert users who need fine-grained control and are willing to bypass higher-level abstractions.

ByofSpec

Bases: TypedDict

Bring-Your-Own-Functions specification for expert users.

Allows bypassing the symbolic layer and directly providing raw JAX functions. All fields are optional - you can mix symbolic and byof as needed.

Warning

You are responsible for:

  • Correct indexing into unified state/control vectors
  • Ensuring functions are JAX-compatible (use jax.numpy, no side effects)
  • Ensuring functions are differentiable
  • Following g(x,u) <= 0 convention for constraints
Tip

Use the .slice property on State/Control objects for cleaner, more maintainable indexing instead of hardcoded indices. For example, use x[velocity.slice] instead of x[2:3]. The slice property is set after preprocessing and provides the correct indices into the unified state/control vectors.
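The examples below index twice, as in `x[velocity.slice][0]`. The following minimal sketch in plain JAX shows why. It assumes a hypothetical unified state layout where `position` occupies indices 0-1 and `velocity` occupies index 2, and it uses ordinary Python `slice` objects as stand-ins for the `.slice` property, whose exact type in openscvx is not shown on this page.

```python
import jax.numpy as jnp

# Hypothetical layout of the unified state vector: [px, py, v]
position_slice = slice(0, 2)  # stand-in for position.slice
velocity_slice = slice(2, 3)  # stand-in for velocity.slice

x = jnp.array([1.0, 4.0, 3.5])

v_block = x[velocity_slice]   # shape (1,): slicing keeps the dimension
v = x[velocity_slice][0]      # 0-d scalar: index again to drop it
same_as_hardcoded = bool(jnp.array_equal(v_block, x[2:3]))
```

Slicing always returns a 1-D block, so scalar residuals need the extra `[0]`. Vector-valued states such as `position` can be used as a block directly.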

Attributes:

Name Type Description
dynamics dict[str, DynamicsFunction]

Raw JAX functions for state derivatives. Maps state names to functions with signature (x, u, node, params) -> xdot_component. States here should NOT appear in symbolic dynamics dict. You can mix: some states symbolic, some in byof.

nodal_constraints List[NodalConstraintSpec]

Point-wise constraints applied at specific nodes. Each item is a NodalConstraintSpec dict with:

  • constraint_fn: Constraint function (x, u, node, params) -> residual (required)
  • nodes: List of node indices (optional, defaults to all nodes)

Follows g(x,u) <= 0 convention.

cross_nodal_constraints List[CrossNodalConstraintFunction]

Constraints coupling multiple nodes (smoothness, rate limits). Signature: (X, U, params) -> residual where X is (N, n_x) and U is (N, n_u). N is the number of trajectory nodes, n_x is state dimension, n_u is control dimension. Follows g(X,U) <= 0 convention.

ctcs_constraints List[CtcsConstraintSpec]

Continuous-time constraint satisfaction via dynamics augmentation. Each adds an augmented state accumulating violation penalties. See CtcsConstraintSpec for details.
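To illustrate the cross_nodal_constraints signature, the sketch below writes a rate limit |v[k+1] - v[k]| <= max_dv as two stacked residual vectors that follow the g(X,U) <= 0 convention. `velocity_slice`, `max_dv`, and the dimensions are hypothetical. The `jax.jacfwd(..., argnums=0)` call mirrors how `apply_byof` differentiates cross-nodal constraints with respect to X.

```python
import jax
import jax.numpy as jnp

N, n_x, n_u = 5, 3, 1
velocity_slice = slice(2, 3)   # hypothetical stand-in for velocity.slice
max_dv = 0.5                   # hypothetical rate limit per node step

def rate_limit(X, U, params):
    # X: (N, n_x), U: (N, n_u); take the velocity column as shape (N,)
    v = X[:, velocity_slice][:, 0]
    dv = jnp.diff(v)  # (N - 1,) successive differences
    # |dv| <= max_dv  <=>  dv - max_dv <= 0  and  -dv - max_dv <= 0
    return jnp.concatenate([dv - max_dv, -dv - max_dv])

X = jnp.zeros((N, n_x)).at[:, 2].set(jnp.array([0.0, 0.2, 0.4, 1.4, 1.5]))
U = jnp.zeros((N, n_u))

g = rate_limit(X, U, {})
violated = bool(jnp.any(g > 0))  # True: the step 0.4 -> 1.4 exceeds 0.5
grad_X = jax.jacfwd(rate_limit, argnums=0)(X, U, {})  # shape (2 * (N - 1), N, n_x)
```

Returning a vector of residuals, rather than a single aggregated scalar, keeps each violated step visible to the solver individually.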

Example

Custom dynamics and constraints:

import jax.numpy as jnp
import openscvx as ox
from openscvx import ByofSpec

# Define states and controls
position = ox.State("position", shape=(2,))
velocity = ox.State("velocity", shape=(1,))
theta = ox.Control("theta", shape=(1,))

# Custom dynamics for one state using .slice property
def custom_velocity_dynamics(x, u, node, params):
    # Use .slice property for clean indexing
    return params["g"] * jnp.cos(u[theta.slice][0])

byof: ByofSpec = {
    "dynamics": {
        "velocity": custom_velocity_dynamics,
    },
    "nodal_constraints": [
        # Applied to all nodes (no "nodes" field)
        {
            "constraint_fn": lambda x, u, node, params: x[velocity.slice][0] - 10.0,
        },
        {
            "constraint_fn": lambda x, u, node, params: -x[velocity.slice][0],
        },
        # Specify nodes for selective enforcement
        {
            "constraint_fn": lambda x, u, node, params: x[velocity.slice][0],
            "nodes": [0],  # v <= 0 here; with -v <= 0 above, pins v to 0 at start
        },
    ],
    "cross_nodal_constraints": [
        # Constrain total velocity across trajectory: sum(velocities) >= 5
        # X.shape = (N, n_x), extract velocity column using slice
        lambda X, U, params: 5.0 - jnp.sum(X[:, velocity.slice]),
    ],
    "ctcs_constraints": [
        {
            "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 5.0,
            "penalty": "square",
        }
    ],
}
Source code in openscvx/expert/byof.py
class ByofSpec(TypedDict, total=False):
    """Bring-Your-Own-Functions specification for expert users.

    Allows bypassing the symbolic layer and directly providing raw JAX functions.
    All fields are optional - you can mix symbolic and byof as needed.

    Warning:
        You are responsible for:

        - Correct indexing into unified state/control vectors
        - Ensuring functions are JAX-compatible (use jax.numpy, no side effects)
        - Ensuring functions are differentiable
        - Following g(x,u) <= 0 convention for constraints

    Tip:
        Use the ``.slice`` property on State/Control objects for cleaner, more
        maintainable indexing instead of hardcoded indices. For example, use
        ``x[velocity.slice]`` instead of ``x[2:3]``. The slice property is set
        after preprocessing and provides the correct indices into the unified
        state/control vectors.

    Attributes:
        dynamics: Raw JAX functions for state derivatives. Maps state names to functions
            with signature ``(x, u, node, params) -> xdot_component``. States here should
            NOT appear in symbolic dynamics dict. You can mix: some states symbolic,
            some in byof.
        nodal_constraints: Point-wise constraints applied at specific nodes.
            Each item is a :class:`NodalConstraintSpec` dict with:

            - ``constraint_fn``: Constraint function ``(x, u, node, params) -> residual`` (required)
            - ``nodes``: List of node indices (optional, defaults to all nodes)

            Follows g(x,u) <= 0 convention.
        cross_nodal_constraints: Constraints coupling multiple nodes (smoothness, rate limits).
            Signature: ``(X, U, params) -> residual`` where X is (N, n_x) and U is (N, n_u).
            N is the number of trajectory nodes, n_x is state dimension, n_u is control dimension.
            Follows g(X,U) <= 0 convention.
        ctcs_constraints: Continuous-time constraint satisfaction via dynamics augmentation.
            Each adds an augmented state accumulating violation penalties.
            See :class:`CtcsConstraintSpec` for details.

    Example:
        Custom dynamics and constraints::

            import jax.numpy as jnp
            import openscvx as ox
            from openscvx import ByofSpec

            # Define states and controls
            position = ox.State("position", shape=(2,))
            velocity = ox.State("velocity", shape=(1,))
            theta = ox.Control("theta", shape=(1,))

            # Custom dynamics for one state using .slice property
            def custom_velocity_dynamics(x, u, node, params):
                # Use .slice property for clean indexing
                return params["g"] * jnp.cos(u[theta.slice][0])

            byof: ByofSpec = {
                "dynamics": {
                    "velocity": custom_velocity_dynamics,
                },
                "nodal_constraints": [
                    # Applied to all nodes (no "nodes" field)
                    {
                        "constraint_fn": lambda x, u, node, params: x[velocity.slice][0] - 10.0,
                    },
                    {
                        "constraint_fn": lambda x, u, node, params: -x[velocity.slice][0],
                    },
                    # Specify nodes for selective enforcement
                    {
                        "constraint_fn": lambda x, u, node, params: x[velocity.slice][0],
                        "nodes": [0],  # v <= 0 here; with -v <= 0 above, pins v to 0 at start
                    },
                ],
                "cross_nodal_constraints": [
                    # Constrain total velocity across trajectory: sum(velocities) >= 5
                    # X.shape = (N, n_x), extract velocity column using slice
                    lambda X, U, params: 5.0 - jnp.sum(X[:, velocity.slice]),
                ],
                "ctcs_constraints": [
                    {
                        "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 5.0,
                        "penalty": "square",
                    }
                ],
            }
    """

    dynamics: dict[str, DynamicsFunction]
    nodal_constraints: List[NodalConstraintSpec]
    cross_nodal_constraints: List[CrossNodalConstraintFunction]
    ctcs_constraints: List[CtcsConstraintSpec]

CtcsConstraintSpec

Bases: TypedDict

Specification for CTCS (Continuous-Time Constraint Satisfaction) constraint.

CTCS constraints are enforced by augmenting the dynamics with a penalty term that accumulates violations over time. Useful for path constraints that must be satisfied continuously, not just at discrete nodes.
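The accumulation idea can be sketched on a 1-D toy path. This assumes explicit Euler integration, which is not necessarily the integrator openscvx uses. The augmented state y obeys ydot = penalty(g(x)), so y(T) is the integrated violation. It stays at its initial value 0 only if g(x(t)) <= 0 along the whole path, including the integration substeps between nodes that nodal constraints never see. Here the samples play the role of those substeps.

```python
import jax.numpy as jnp

def square_penalty(r):
    # Same form as the built-in "square" option: max(r, 0)^2
    return jnp.maximum(r, 0.0) ** 2

def accumulated_violation(xs, dt, limit):
    """Euler-integrate ydot = penalty(x - limit) along a sampled path."""
    y = 0.0  # augmented state, starting from bounds[0] = 0.0
    for x in xs:
        y = y + dt * square_penalty(x - limit)
    return y

dt = 0.1
path_ok = jnp.linspace(0.0, 9.0, 11)                     # stays below 10
path_bad = jnp.array([8.0, 9.0, 10.5, 11.0, 10.5, 9.0])  # briefly exceeds 10

y_ok = accumulated_violation(path_ok, dt, limit=10.0)    # 0.0
y_bad = accumulated_violation(path_bad, dt, limit=10.0)  # 0.1 * (0.25 + 1 + 0.25)
```

Because y_bad (about 0.15) is far above the default upper bound of 1e-4, that bound effectively acts as a tolerance on the total violation.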

Attributes:

Name Type Description
constraint_fn CtcsConstraintFunction

Function computing constraint residual with signature (x, u, node, params) -> scalar. Must return scalar. Follows g(x,u) <= 0 convention (negative = satisfied). Required field.

penalty PenaltyFunction

Penalty function for positive residuals (violations). Built-in options: "square" (max(r,0)^2, default), "l1" (max(r,0)), "huber" (Huber loss on max(r,0) with delta = 1.0). Custom: Callable (r) -> penalty (non-negative, differentiable).

bounds Tuple[float, float]

(min, max) bounds for the augmented state that accumulates penalties. Default: (0.0, 1e-4). The max bound acts as a soft limit on the total accumulated violation.

initial float

Initial value for augmented state. Default: bounds[0] (usually 0.0).

over Tuple[int, int]

Node interval (start, end) where constraint is active. The constraint is enforced for nodes in [start, end). If omitted, constraint is active over all nodes. Matches symbolic .over() method behavior.

idx int

Constraint group index for sharing augmented states (default: 0). All CTCS constraints (symbolic and byof) with the same idx share a single augmented state. Their penalties are summed together. Use different idx values to track different types of violations separately.
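The penalty and over options can be checked in isolation. The three built-in penalties below are written in the same form as in `apply_byof`. `penalty_cubic` is a hypothetical custom penalty: non-negative and differentiable, as a custom callable must be. The `over` gating uses `jnp.where` as a simpler stand-in for the `jax.lax.cond` call in the source, but it applies the same half-open [start, end) test.

```python
import jax
import jax.numpy as jnp

# Built-in penalties, same form as in apply_byof
def penalty_square(r):
    return jnp.maximum(r, 0.0) ** 2

def penalty_l1(r):
    return jnp.maximum(r, 0.0)

def penalty_huber(r, delta=1.0):
    r_pos = jnp.maximum(r, 0.0)
    return jnp.where(r_pos <= delta, 0.5 * r_pos**2, delta * (r_pos - 0.5 * delta))

# Hypothetical custom penalty: zero when satisfied, smooth at r = 0
def penalty_cubic(r):
    return jnp.maximum(r, 0.0) ** 3

def gate_by_interval(penalty_value, node, over):
    """Zero the penalty outside the half-open node interval [start, end)."""
    start, end = over
    active = (start <= node) & (node < end)
    return jnp.where(active, penalty_value, 0.0)

r = jnp.array([-1.0, 0.5, 3.0])
sq = penalty_square(r)    # [0.0, 0.25, 9.0]
l1 = penalty_l1(r)        # [0.0, 0.5, 3.0]
hub = penalty_huber(r)    # [0.0, 0.125, 2.5]
d_cubic_at_zero = jax.grad(penalty_cubic)(0.0)  # 0.0

gated = [float(gate_by_interval(1.0, n, (10, 50))) for n in (9, 10, 49, 50)]
# -> [0.0, 1.0, 1.0, 0.0]: node 50 is excluded
```

"l1" grows linearly but has a kink at r = 0. "square" is smooth but weak for small violations. "huber" is quadratic up to delta and linear beyond it.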

Warning

If symbolic CTCS constraints exist with idx values [0, 1, 2], then byof idx must either:

  • Match an existing idx (e.g., 0, 1, or 2) to add to that augmented state
  • Be sequential after them (e.g., 3, 4, 5) to create new augmented states

You cannot use idx values that create gaps (e.g., if symbolic has [0, 1], you cannot use byof idx=3 without also using idx=2).
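The gap rule amounts to a simple set check. The standalone sketch below mirrors the contiguity test performed in `apply_byof`. It assumes, as the source comments state, that symbolic CTCS augmented states always use idx 0 .. n_symbolic - 1. The helper name `check_ctcs_idx` is hypothetical.

```python
def check_ctcs_idx(n_symbolic, byof_idx):
    """Return the merged idx list, or raise ValueError if it has gaps.

    n_symbolic: number of symbolic CTCS augmented states (idx 0 .. n_symbolic - 1)
    byof_idx: idx values used by byof CTCS specs
    """
    all_idx = sorted(set(range(n_symbolic)) | set(byof_idx))
    expected = list(range(len(all_idx)))
    if all_idx != expected:
        raise ValueError(f"non-contiguous CTCS idx {all_idx}, expected {expected}")
    return all_idx

shared_and_new = check_ctcs_idx(2, [0, 1, 2])  # share 0 and 1, add 2

try:
    check_ctcs_idx(2, [3])  # idx=2 is missing, so this leaves a gap
    gap_rejected = False
except ValueError:
    gap_rejected = True
```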

Example

Enforce position[0] <= 10.0 continuously:

# Assuming position = ox.State("position", shape=(2,))
ctcs_spec: CtcsConstraintSpec = {
    "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
    "penalty": "square",
    "bounds": (0.0, 1e-4),
    "initial": 0.0,
    "idx": 0,  # Groups with other constraints having idx=0
}

Enforce constraint only over specific node range:

ctcs_spec: CtcsConstraintSpec = {
    "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
    "over": (10, 50),  # Active only for nodes 10-49
    "penalty": "square",
}

Multiple constraints sharing an augmented state:

# Assuming pos and vel are ox.State objects and symbolic CTCS already has idx=[0, 1]:

byof = {
    "ctcs_constraints": [
        # Add to existing symbolic idx=0 augmented state
        {
            "constraint_fn": lambda x, u, node, params: x[pos.slice][0] - 10.0,
            "idx": 0,  # Shares with symbolic idx=0
        },
        # Add to existing symbolic idx=1 augmented state
        {
            "constraint_fn": lambda x, u, node, params: x[vel.slice][0] - 5.0,
            "idx": 1,  # Shares with symbolic idx=1
        },
        # Create NEW augmented state (sequential after symbolic)
        {
            "constraint_fn": lambda x, u, node, params: x[pos.slice][1] - 8.0,
            "idx": 2,  # New state (symbolic has 0,1, so next is 2)
        },
    ]
}
Source code in openscvx/expert/byof.py
class CtcsConstraintSpec(TypedDict, total=False):
    """Specification for CTCS (Continuous-Time Constraint Satisfaction) constraint.

    CTCS constraints are enforced by augmenting the dynamics with a penalty term that
    accumulates violations over time. Useful for path constraints that must be satisfied
    continuously, not just at discrete nodes.

    Attributes:
        constraint_fn: Function computing constraint residual with signature
            ``(x, u, node, params) -> scalar``. Must return scalar.
            Follows g(x,u) <= 0 convention (negative = satisfied). Required field.
        penalty: Penalty function for positive residuals (violations).
            Built-in options: "square" (max(r,0)^2, default), "l1" (max(r,0)),
            "huber" (Huber loss). Custom: Callable ``(r) -> penalty`` (non-negative,
            differentiable).
        bounds: (min, max) bounds for augmented state accumulating penalties.
            Default: (0.0, 1e-4). Max acts as soft constraint on total violation.
        initial: Initial value for augmented state. Default: bounds[0] (usually 0.0).
        over: Node interval (start, end) where constraint is active. The constraint
            is enforced for nodes in [start, end). If omitted, constraint is active
            over all nodes. Matches symbolic `.over()` method behavior.
        idx: Constraint group index for sharing augmented states (default: 0).
            All CTCS constraints (symbolic and byof) with the same idx share a single
            augmented state. Their penalties are summed together. Use different idx values
            to track different types of violations separately.

    Warning:
        If symbolic CTCS constraints exist with idx values [0, 1, 2], then byof idx **must** either:

        - Match an existing idx (e.g., 0, 1, or 2) to add to that augmented state
        - Be sequential after them (e.g., 3, 4, 5) to create new augmented states

        You cannot use idx values that create gaps (e.g., if symbolic has [0, 1],
        you cannot use byof idx=3 without also using idx=2).

    Example:
        Enforce position[0] <= 10.0 continuously::

            # Assuming position = ox.State("position", shape=(2,))
            ctcs_spec: CtcsConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
                "penalty": "square",
                "bounds": (0.0, 1e-4),
                "initial": 0.0,
                "idx": 0,  # Groups with other constraints having idx=0
            }

        Enforce constraint only over specific node range::

            ctcs_spec: CtcsConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
                "over": (10, 50),  # Active only for nodes 10-49
                "penalty": "square",
            }

        Multiple constraints sharing an augmented state::

            # If symbolic CTCS already has idx=[0, 1], then:

            byof = {
                "ctcs_constraints": [
                    # Add to existing symbolic idx=0 augmented state
                    {
                        "constraint_fn": lambda x, u, node, params: x[pos.slice][0] - 10.0,
                        "idx": 0,  # Shares with symbolic idx=0
                    },
                    # Add to existing symbolic idx=1 augmented state
                    {
                        "constraint_fn": lambda x, u, node, params: x[vel.slice][0] - 5.0,
                        "idx": 1,  # Shares with symbolic idx=1
                    },
                    # Create NEW augmented state (sequential after symbolic)
                    {
                        "constraint_fn": lambda x, u, node, params: x[pos.slice][1] - 8.0,
                        "idx": 2,  # New state (symbolic has 0,1, so next is 2)
                    },
                ]
            }
    """

    constraint_fn: CtcsConstraintFunction  # Required
    penalty: PenaltyFunction
    bounds: Tuple[float, float]
    initial: float
    over: Tuple[int, int]
    idx: int

NodalConstraintSpec

Bases: TypedDict

Specification for nodal constraint with optional node selection.

Nodal constraints are point-wise constraints evaluated at specific trajectory nodes. By default, constraints apply to all nodes, but you can restrict enforcement to specific nodes for boundary conditions, waypoints, or computational efficiency.

Attributes:

Name Type Description
constraint_fn NodalConstraintFunction

Constraint function with signature (x, u, node, params) -> residual. Follows g(x,u) <= 0 convention (negative = satisfied). Required field.

nodes List[int]

List of integer node indices where constraint is enforced. If omitted, applies to all nodes. Negative indices supported (e.g., -1 for last). Optional field.
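The sketch below shows how negative node indices resolve and how a per-node constraint is evaluated across the whole trajectory. It uses the same normalization expression and the same `jax.vmap(fn, in_axes=(0, 0, None, None))` batching as `apply_byof`. How the solver then restricts enforcement to the listed nodes is not shown in this excerpt, so the final indexing step is only an illustrative assumption. `velocity_slice` and the trajectory data are hypothetical.

```python
import jax
import jax.numpy as jnp

N = 6
velocity_slice = slice(2, 3)  # hypothetical stand-in for velocity.slice

def v_upper_zero(x, u, node, params):
    # g(x) = v <= 0 at the selected nodes
    return x[velocity_slice][0]

nodes = [0, -1]
normalized = [n if n >= 0 else N + n for n in nodes]  # [0, 5]

# x and u are batched over nodes; node and params are shared (axis None)
g_all = jax.vmap(v_upper_zero, in_axes=(0, 0, None, None))

X = jnp.zeros((N, 3)).at[:, 2].set(jnp.arange(N, dtype=jnp.float32))
U = jnp.zeros((N, 1))
residuals = g_all(X, U, None, None)[jnp.array(normalized)]  # [0.0, 5.0]
```

Here the residual at the last node is positive (5.0), so that constraint is violated, while node 0 sits exactly on the boundary.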

Example

Boundary constraint only at first and last nodes:

nodal_spec: NodalConstraintSpec = {
    "constraint_fn": lambda x, u, node, params: x[velocity.slice][0],
    "nodes": [0, -1],  # Only at start and end
}

Waypoint constraint at middle of trajectory:

nodal_spec: NodalConstraintSpec = {
    "constraint_fn": lambda x, u, node, params: jnp.linalg.norm(
        x[position.slice] - jnp.array([5.0, 7.5])
    ) - 0.1,
    "nodes": [N // 2],  # N is the number of trajectory nodes
}
Source code in openscvx/expert/byof.py
class NodalConstraintSpec(TypedDict, total=False):
    """Specification for nodal constraint with optional node selection.

    Nodal constraints are point-wise constraints evaluated at specific trajectory nodes.
    By default, constraints apply to all nodes, but you can restrict enforcement to
    specific nodes for boundary conditions, waypoints, or computational efficiency.

    Attributes:
        constraint_fn: Constraint function with signature ``(x, u, node, params) -> residual``.
            Follows g(x,u) <= 0 convention (negative = satisfied). Required field.
        nodes: List of integer node indices where constraint is enforced.
            If omitted, applies to all nodes. Negative indices supported (e.g., -1 for last).
            Optional field.

    Example:
        Boundary constraint only at first and last nodes::

            nodal_spec: NodalConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: x[velocity.slice][0],
                "nodes": [0, -1],  # Only at start and end
            }

        Waypoint constraint at middle of trajectory::

            nodal_spec: NodalConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: jnp.linalg.norm(
                    x[position.slice] - jnp.array([5.0, 7.5])
                ) - 0.1,
                "nodes": [N // 2],
            }
    """

    constraint_fn: NodalConstraintFunction  # Required
    nodes: List[int]

apply_byof(byof: dict, dynamics: Dynamics, dynamics_prop: Dynamics, jax_constraints: LoweredJaxConstraints, x_unified: UnifiedState, x_prop_unified: UnifiedState, u_unified: UnifiedState, states: List[State], states_prop: List[State], N: int) -> Tuple[Dynamics, Dynamics, LoweredJaxConstraints, UnifiedState, UnifiedState]

Apply bring-your-own-functions (byof) to augment lowered problem.

Handles raw JAX functions provided by expert users, including:

  • dynamics: Raw JAX functions for specific state derivatives
  • nodal_constraints: Point-wise constraints at each node
  • cross_nodal_constraints: Constraints coupling multiple nodes
  • ctcs_constraints: Continuous-time constraint satisfaction via dynamics augmentation
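The dynamics splicing step is the least obvious part of this function. The standalone sketch below reproduces the same pattern on a hypothetical two-state system: evaluate the original dynamics, overwrite one state's block with `xdot.at[sl].set(...)`, then rebuild A and B with `jax.jacfwd`. The state layout, `symbolic_f`, and `byof_velocity` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical unified state [p, v] and control [theta]
velocity_slice = slice(1, 2)

def symbolic_f(x, u, node, params):
    # Stand-in for the lowered symbolic dynamics: pdot = v, vdot = 0 (placeholder)
    return jnp.stack([x[1], jnp.zeros_like(x[1])])

def byof_velocity(x, u, node, params):
    # Raw JAX derivative for the velocity block only
    return params["g"] * jnp.cos(u[0])

def make_composite(orig_f, byof_fns, slices):
    def composite_f(x, u, node, params):
        xdot = orig_f(x, u, node, params)
        for name, fn in byof_fns.items():
            # Overwrite that state's block with the user-supplied derivative
            xdot = xdot.at[slices[name]].set(fn(x, u, node, params))
        return xdot
    return composite_f

f = make_composite(symbolic_f, {"velocity": byof_velocity}, {"velocity": velocity_slice})
A = jax.jacfwd(f, argnums=0)  # df/dx
B = jax.jacfwd(f, argnums=1)  # df/du

x = jnp.array([0.0, 2.0])
u = jnp.array([0.0])
params = {"g": 9.81}
xdot = f(x, u, 0, params)  # [2.0, 9.81]
```

Because the overwritten block no longer depends on the symbolic output, the Jacobians automatically drop the replaced derivative terms. The second row of A is all zeros here.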

Parameters:

Name Type Description
byof dict (required)

Dict with keys "dynamics", "nodal_constraints", "cross_nodal_constraints", "ctcs_constraints"

dynamics Dynamics (required)

Lowered optimization dynamics to potentially augment

dynamics_prop Dynamics (required)

Lowered propagation dynamics to potentially augment

jax_constraints LoweredJaxConstraints (required)

Lowered JAX constraints to append to

x_unified UnifiedState (required)

Unified optimization state interface to potentially augment

x_prop_unified UnifiedState (required)

Unified propagation state interface to potentially augment

u_unified UnifiedState (required)

Unified control interface for validation

states List[State] (required)

List of State objects for optimization (with _slice attributes)

states_prop List[State] (required)

List of State objects for propagation (with _slice attributes)

N int (required)

Number of nodes in the trajectory

Returns:

Type Description
Tuple[Dynamics, Dynamics, LoweredJaxConstraints, UnifiedState, UnifiedState]

Tuple of (dynamics, dynamics_prop, jax_constraints, x_unified, x_prop_unified)

Example

>>> dynamics, dynamics_prop, constraints, x_unified, x_prop_unified = apply_byof(
...     byof, dynamics, dynamics_prop, jax_constraints,
...     x_unified, x_prop_unified, u_unified, states, states_prop, N
... )

Source code in openscvx/expert/lowering.py
def apply_byof(
    byof: dict,
    dynamics: Dynamics,
    dynamics_prop: Dynamics,
    jax_constraints: LoweredJaxConstraints,
    x_unified: "UnifiedState",
    x_prop_unified: "UnifiedState",
    u_unified: "UnifiedState",
    states: List["State"],
    states_prop: List["State"],
    N: int,
) -> Tuple[Dynamics, Dynamics, LoweredJaxConstraints, "UnifiedState", "UnifiedState"]:
    """Apply bring-your-own-functions (byof) to augment lowered problem.

    Handles raw JAX functions provided by expert users, including:
    - dynamics: Raw JAX functions for specific state derivatives
    - nodal_constraints: Point-wise constraints at each node
    - cross_nodal_constraints: Constraints coupling multiple nodes
    - ctcs_constraints: Continuous-time constraint satisfaction via dynamics augmentation

    Args:
        byof: Dict with keys "dynamics", "nodal_constraints", "cross_nodal_constraints",
            "ctcs_constraints"
        dynamics: Lowered optimization dynamics to potentially augment
        dynamics_prop: Lowered propagation dynamics to potentially augment
        jax_constraints: Lowered JAX constraints to append to
        x_unified: Unified optimization state interface to potentially augment
        x_prop_unified: Unified propagation state interface to potentially augment
        u_unified: Unified control interface for validation
        states: List of State objects for optimization (with _slice attributes)
        states_prop: List of State objects for propagation (with _slice attributes)
        N: Number of nodes in the trajectory

    Returns:
        Tuple of (dynamics, dynamics_prop, jax_constraints, x_unified, x_prop_unified)

    Example:
        >>> dynamics, dynamics_prop, constraints, x_unified, x_prop_unified = apply_byof(
        ...     byof, dynamics, dynamics_prop, jax_constraints,
        ...     x_unified, x_prop_unified, u_unified, states, states_prop, N
        ... )
    """

    # Note: byof validation happens earlier in Problem.__init__ to fail fast
    # Handle byof dynamics by splicing in raw JAX functions at the correct slices
    byof_dynamics = byof.get("dynamics", {})
    if byof_dynamics:
        # Build mapping from state name to slice for optimization states
        state_slices = {state.name: state._slice for state in states}
        state_slices_prop = {state.name: state._slice for state in states_prop}

        def _make_composite_dynamics(orig_f, byof_fns, slices_map):
            """Create composite dynamics combining symbolic and byof state derivatives.

            This factory splices user-provided byof dynamics into the unified dynamics
            function at the appropriate slice indices, replacing the symbolic dynamics
            for specific states while preserving the rest.

            Args:
                orig_f: Original unified dynamics (x, u, node, params) -> xdot
                byof_fns: Dict mapping state names to byof dynamics functions
                slices_map: Dict mapping state names to slice objects for indexing

            Returns:
                Composite dynamics function with byof derivatives spliced in
            """

            def composite_f(x, u, node, params):
                # Start with symbolic/default dynamics for all states
                xdot = orig_f(x, u, node, params)

                # Splice in byof dynamics for specific states
                for state_name, byof_fn in byof_fns.items():
                    sl = slices_map[state_name]
                    # Replace the derivative for this state with the byof result
                    xdot = xdot.at[sl].set(byof_fn(x, u, node, params))

                return xdot

            return composite_f

        # Create composite optimization dynamics
        composite_f = _make_composite_dynamics(dynamics.f, byof_dynamics, state_slices)
        dynamics = Dynamics(
            f=composite_f,
            A=jacfwd(composite_f, argnums=0),
            B=jacfwd(composite_f, argnums=1),
        )

        # Create composite propagation dynamics
        composite_f_prop = _make_composite_dynamics(
            dynamics_prop.f, byof_dynamics, state_slices_prop
        )
        dynamics_prop = Dynamics(
            f=composite_f_prop,
            A=jacfwd(composite_f_prop, argnums=0),
            B=jacfwd(composite_f_prop, argnums=1),
        )

    # Handle nodal constraints
    # Note: Validation happens earlier in Problem.__init__ via validate_byof
    for constraint_spec in byof.get("nodal_constraints", []):
        fn = constraint_spec["constraint_fn"]
        nodes = constraint_spec.get("nodes", list(range(N)))  # Default: all nodes

        # Normalize negative node indices (validation already done in validate_byof)
        normalized_nodes = [node if node >= 0 else N + node for node in nodes]

        constraint = LoweredNodalConstraint(
            func=jax.vmap(fn, in_axes=(0, 0, None, None)),
            grad_g_x=jax.vmap(jacfwd(fn, argnums=0), in_axes=(0, 0, None, None)),
            grad_g_u=jax.vmap(jacfwd(fn, argnums=1), in_axes=(0, 0, None, None)),
            nodes=normalized_nodes,
        )
        jax_constraints.nodal.append(constraint)

    # Handle cross-nodal constraints
    for fn in byof.get("cross_nodal_constraints", []):
        constraint = LoweredCrossNodeConstraint(
            func=fn,
            grad_g_X=jacfwd(fn, argnums=0),
            grad_g_U=jacfwd(fn, argnums=1),
        )
        jax_constraints.cross_node.append(constraint)

    # Handle CTCS constraints by augmenting dynamics
    # Built-in penalty functions
    def _penalty_square(r):
        return jnp.maximum(r, 0.0) ** 2

    def _penalty_l1(r):
        return jnp.maximum(r, 0.0)

    def _penalty_huber(r, delta=1.0):
        abs_r = jnp.maximum(r, 0.0)
        return jnp.where(abs_r <= delta, 0.5 * abs_r**2, delta * (abs_r - 0.5 * delta))

    _PENALTY_FUNCTIONS = {
        "square": _penalty_square,
        "l1": _penalty_l1,
        "huber": _penalty_huber,
    }

    # Determine which symbolic CTCS idx values already exist
    # Symbolic augmented states are named "_ctcs_aug_{i}" where i is sequential
    # and corresponds to sorted symbolic idx values (0, 1, 2, ...)
    symbolic_ctcs_idx = []
    for state in states:
        if state.name.startswith("_ctcs_aug_"):
            try:
                aug_idx = int(state.name.split("_")[-1])
                symbolic_ctcs_idx.append(aug_idx)
            except (ValueError, IndexError):
                pass

    # Symbolic CTCS creates augmented states with sequential idx: 0, 1, 2, ...
    # so max_symbolic_idx = len(symbolic_ctcs_idx) - 1 (or -1 if none exist)
    max_symbolic_idx = len(symbolic_ctcs_idx) - 1 if symbolic_ctcs_idx else -1

    # Build idx -> augmented_state_slice mapping for existing symbolic CTCS
    # Augmented states appear after regular states in the unified vector
    # We'll determine the slice by finding the state in the states list
    idx_to_aug_slice = {}
    for state in states:
        if state.name.startswith("_ctcs_aug_"):
            try:
                aug_idx = int(state.name.split("_")[-1])
                # The actual idx value IS the sequential index for symbolic CTCS
                # (they're created with idx 0, 1, 2, ... in sorted order)
                idx_to_aug_slice[aug_idx] = state._slice
            except (ValueError, IndexError, AttributeError):
                pass

    # Group BYOF CTCS constraints by idx (default to 0)
    byof_ctcs_groups = {}
    for ctcs_spec in byof.get("ctcs_constraints", []):
        idx = ctcs_spec.get("idx", 0)
        if idx not in byof_ctcs_groups:
            byof_ctcs_groups[idx] = []
        byof_ctcs_groups[idx].append(ctcs_spec)

    # Validate that byof idx values don't create gaps
    # All idx must form contiguous sequence: [0, 1, 2, ..., max_idx]
    if byof_ctcs_groups:
        all_idx = sorted(set(range(max_symbolic_idx + 1)) | set(byof_ctcs_groups.keys()))
        expected_idx = list(range(len(all_idx)))
        if all_idx != expected_idx:
            raise ValueError(
                f"BYOF CTCS idx values create non-contiguous sequence. "
                f"Symbolic CTCS has idx=[{', '.join(map(str, range(max_symbolic_idx + 1)))}], "
                f"combined with byof idx={sorted(byof_ctcs_groups.keys())} gives {all_idx}. "
                f"Expected contiguous sequence {expected_idx}. "
                f"Byof idx must either match existing symbolic idx or be sequential after them."
            )

    # Process each idx group
    for idx in sorted(byof_ctcs_groups.keys()):
        specs = byof_ctcs_groups[idx]

        # Collect all penalty functions for this idx
        penalty_fns = []
        for spec in specs:
            constraint_fn = spec["constraint_fn"]
            penalty_spec = spec.get("penalty", "square")
            over_interval = spec.get("over", None)  # Node interval (start, end) or None

            if callable(penalty_spec):
                penalty_func = penalty_spec
            else:
                penalty_func = _PENALTY_FUNCTIONS[penalty_spec]

            # Create a combined constraint+penalty function
            def _make_penalty_fn(cons_fn, pen_func, over):
                """Factory to capture constraint, penalty functions, and node interval.

                Args:
                    cons_fn: Constraint function (x, u, node, params) -> scalar residual
                    pen_func: Penalty function (residual) -> penalty value
                    over: Optional (start, end) tuple for conditional activation

                Returns:
                    Penalty function that conditionally activates based on node interval
                """

                def penalty_fn(x, u, node, params):
                    # Compute penalty for the constraint violation
                    residual = cons_fn(x, u, node, params)
                    penalty_value = pen_func(residual)

                    # Apply conditional logic if over interval is specified
                    if over is not None:
                        start_node, end_node = over
                        # Extract scalar from node (which may be array or scalar)
                        # Keep as JAX array for tracing compatibility
                        node_scalar = jnp.atleast_1d(node)[0]
                        is_active = (start_node <= node_scalar) & (node_scalar < end_node)

                        # Use jax.lax.cond for JAX-traceable conditional evaluation
                        # Penalty is active only when node is in [start, end)
                        return cond(
                            is_active,
                            lambda _: penalty_value,
                            lambda _: 0.0,
                            operand=None,
                        )
                    else:
                        # Always active if no interval specified
                        return penalty_value

                return penalty_fn

            penalty_fns.append(_make_penalty_fn(constraint_fn, penalty_func, over_interval))

        if idx in idx_to_aug_slice:
            # This idx already exists from symbolic CTCS - add penalties to existing state
            aug_slice = idx_to_aug_slice[idx]

            def _make_ctcs_addition(orig_f, pen_fns, aug_sl):
                """Create dynamics that adds penalties to existing augmented state.

                Args:
                    orig_f: Original dynamics function
                    pen_fns: List of penalty functions to add
                    aug_sl: Slice of the augmented state to modify

                Returns:
                    Modified dynamics function
                """

                def modified_f(x, u, node, params):
                    xdot = orig_f(x, u, node, params)

                    # Sum all penalties for this idx
                    total_penalty = sum(pen_fn(x, u, node, params) for pen_fn in pen_fns)

                    # Add to existing augmented state derivative
                    xdot = xdot.at[aug_sl].add(total_penalty)

                    return xdot

                return modified_f

            # Modify both optimization and propagation dynamics
            dynamics.f = _make_ctcs_addition(dynamics.f, penalty_fns, aug_slice)
            dynamics.A = jacfwd(dynamics.f, argnums=0)
            dynamics.B = jacfwd(dynamics.f, argnums=1)

            dynamics_prop.f = _make_ctcs_addition(dynamics_prop.f, penalty_fns, aug_slice)
            dynamics_prop.A = jacfwd(dynamics_prop.f, argnums=0)
            dynamics_prop.B = jacfwd(dynamics_prop.f, argnums=1)

        else:
            # New idx - create new augmented state
            # Use bounds/initial from first spec in this group
            first_spec = specs[0]
            bounds = first_spec.get("bounds", (0.0, 1e-4))
            initial = first_spec.get("initial", bounds[0])

            def _make_ctcs_new_state(orig_f, pen_fns):
                """Create dynamics augmented with new CTCS state.

                Args:
                    orig_f: Original dynamics function
                    pen_fns: List of penalty functions to sum

                Returns:
                    Augmented dynamics function
                """

                def augmented_f(x, u, node, params):
                    xdot = orig_f(x, u, node, params)

                    # Sum all penalties for this new idx
                    total_penalty = sum(pen_fn(x, u, node, params) for pen_fn in pen_fns)

                    # Append as new augmented state derivative
                    return jnp.concatenate([xdot, jnp.atleast_1d(total_penalty)])

                return augmented_f

            # Augment optimization dynamics
            aug_f = _make_ctcs_new_state(dynamics.f, penalty_fns)
            dynamics = Dynamics(
                f=aug_f,
                A=jacfwd(aug_f, argnums=0),
                B=jacfwd(aug_f, argnums=1),
            )

            # Augment propagation dynamics
            aug_f_prop = _make_ctcs_new_state(dynamics_prop.f, penalty_fns)
            dynamics_prop = Dynamics(
                f=aug_f_prop,
                A=jacfwd(aug_f_prop, argnums=0),
                B=jacfwd(aug_f_prop, argnums=1),
            )

            # Create State objects for the new augmented states
            # This is necessary for CVXPY variable creation and other bookkeeping
            from openscvx.symbolic.expr.state import State

            # Create augmented state for optimization
            aug_state = State(f"_ctcs_aug_{idx}", shape=(1,))
            aug_state.min = np.array([bounds[0]])
            aug_state.max = np.array([bounds[1]])
            aug_state.initial = np.array([initial])
            aug_state.final = [("free", 0.0)]
            aug_state.guess = np.full((N, 1), initial)

            # Set _slice attribute for the new state
            current_dim = x_unified.shape[0]
            aug_state._slice = slice(current_dim, current_dim + 1)

            # Append to states list (in-place modification visible to caller)
            states.append(aug_state)

            # Create augmented state for propagation
            aug_state_prop = State(f"_ctcs_aug_{idx}", shape=(1,))
            aug_state_prop.min = np.array([bounds[0]])
            aug_state_prop.max = np.array([bounds[1]])
            aug_state_prop.initial = np.array([initial])
            aug_state_prop.final = [("free", 0.0)]
            aug_state_prop.guess = np.full((N, 1), initial)

            # Set _slice attribute for the propagation state
            current_dim_prop = x_prop_unified.shape[0]
            aug_state_prop._slice = slice(current_dim_prop, current_dim_prop + 1)

            # Append to states_prop list
            states_prop.append(aug_state_prop)

            # Add new augmented states to both unified state interfaces
            x_unified.append(
                min=bounds[0],
                max=bounds[1],
                guess=initial,
                initial=initial,
                final=0.0,
                augmented=True,
            )
            x_prop_unified.append(
                min=bounds[0],
                max=bounds[1],
                guess=initial,
                initial=initial,
                final=0.0,
                augmented=True,
            )

    return dynamics, dynamics_prop, jax_constraints, x_unified, x_prop_unified
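The augmentation above can be illustrated in isolation. The sketch below is a simplified stand-in and does not use openscvx's `Dynamics` or `State` classes: `make_gated_penalty` and `augment_with_ctcs` are hypothetical helpers that append one state whose derivative is the summed penalty, gated to nodes in `[start, end)`. It assumes a squared-hinge penalty `max(g, 0)**2`, which may differ from the library's own `"square"` penalty.

```python
import jax
import jax.numpy as jnp


def make_gated_penalty(constraint_fn, over=None):
    """Hypothetical helper: squared-hinge penalty, active only for nodes in [start, end)."""

    def penalty_fn(x, u, node, params):
        residual = constraint_fn(x, u, node, params)
        value = jnp.maximum(residual, 0.0) ** 2  # assumed penalty: zero when g <= 0
        if over is None:
            return value
        start, end = over
        node_scalar = jnp.atleast_1d(node)[0]
        is_active = (start <= node_scalar) & (node_scalar < end)
        # jnp.where is a traceable alternative to lax.cond for a scalar select
        return jnp.where(is_active, value, jnp.zeros_like(value))

    return penalty_fn


def augment_with_ctcs(f, penalty_fns):
    """Hypothetical helper: append one state whose derivative is the summed penalty."""

    def augmented_f(x, u, node, params):
        xdot = f(x, u, node, params)
        total = sum(p(x, u, node, params) for p in penalty_fns)
        return jnp.concatenate([xdot, jnp.atleast_1d(total)])

    return augmented_f


# Double integrator; x = [position, velocity, ctcs_aug], u = [acceleration]
def f(x, u, node, params):
    return jnp.array([x[1], u[0]])  # ignores the augmented entry x[2]


# Speed limit |v| <= 1 written as g(x, u) = v**2 - 1 <= 0, enforced on nodes [0, 5)
speed_limit = make_gated_penalty(lambda x, u, node, params: x[1] ** 2 - 1.0, over=(0, 5))
aug_f = augment_with_ctcs(f, [speed_limit])

# Jacobians of the augmented dynamics, mirroring dynamics.A and dynamics.B
A_fn = jax.jacfwd(aug_f, argnums=0)
B_fn = jax.jacfwd(aug_f, argnums=1)

x = jnp.array([0.0, 2.0, 0.0])
u = jnp.array([0.5])
inside = aug_f(x, u, 2, {})   # node 2 is in [0, 5): last entry is (2**2 - 1)**2 = 9
outside = aug_f(x, u, 7, {})  # node 7 is outside: last entry is 0
A = A_fn(x, u, 2, {})  # shape (3, 3)
B = B_fn(x, u, 2, {})  # shape (3, 1)
```

Because the penalty is non-negative and the augmented state is bounded near zero, driving its accumulated value to zero forces the constraint to hold between nodes, not only at them.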

validate_byof(byof: dict, states: List[State], n_x: int, n_u: int, N: int = None) -> None

Validate byof function signatures and shapes.

Checks that user-provided functions have the correct signatures and return appropriate shapes. Performs validation before functions are used to provide clear error messages.
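The probing strategy (check the parameter count, call on zero inputs, then differentiate) can be sketched for a single dynamics entry. `probe_dynamics` is a hypothetical helper written for illustration; it mirrors the checks in `validate_byof` but is not part of openscvx, and the state layout is assumed.

```python
import inspect

import jax
import jax.numpy as jnp


def probe_dynamics(name, fn, expected_shape, n_x, n_u):
    """Hypothetical helper: signature, shape, and differentiability checks for one function."""
    # 1. Arity: dynamics must accept (x, u, node, params)
    n_params = len(inspect.signature(fn).parameters)
    if n_params != 4:
        raise ValueError(f"{name}: expected 4 parameters, got {n_params}")

    # 2. Test call on zero inputs and compare the output shape
    x, u = jnp.zeros(n_x), jnp.zeros(n_u)
    result = jnp.asarray(fn(x, u, 0, {}))
    if result.shape != expected_shape:
        raise ValueError(f"{name}: returned shape {result.shape}, expected {expected_shape}")

    # 3. Differentiability: reduce to a scalar and take the gradient w.r.t. x
    jax.grad(lambda x_: jnp.sum(fn(x_, u, 0, {})))(x)


# Assumed layout: x = [position (2), velocity (2)], u = [force (2)]
def velocity_dot(x, u, node, params):
    return u[0:2] / 2.0  # illustrative constant mass of 2


probe_dynamics("velocity", velocity_dot, expected_shape=(2,), n_x=4, n_u=2)  # passes

try:
    probe_dynamics("velocity", lambda x, u: u, expected_shape=(2,), n_x=4, n_u=2)
except ValueError as err:
    message = str(err)  # "velocity: expected 4 parameters, got 2"
```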

Parameters:

  • byof (dict, required): Dictionary of user-provided functions to validate.
  • states (List[State], required): List of State objects used to determine expected output shapes.
  • n_x (int, required): Total dimension of the unified state vector.
  • n_u (int, required): Total dimension of the unified control vector.
  • N (int, default None): Number of nodes in the trajectory (optional). If provided, node indices in nodal constraints and the `over` intervals of CTCS constraints are checked against it.
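When `N` is given, negative node indices are interpreted Python-style, so `-1` means the last node. The sketch below restates that normalization rule as a standalone function; `normalize_node` is a hypothetical name used only for illustration.

```python
def normalize_node(node: int, N: int) -> int:
    """Map a possibly negative node index into [0, N), or raise if it is out of range."""
    normalized = node if node >= 0 else N + node
    if not (0 <= normalized < N):
        raise ValueError(f"invalid node index {node} for N={N}")
    return normalized


N = 50
print([normalize_node(n, N) for n in (0, 10, -1, -50)])  # [0, 10, 49, 0]
# normalize_node(50, N) and normalize_node(-51, N) both raise ValueError
```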

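Raises:

  • ValueError: If byof has an unknown key; a dynamics entry names no known state; a spec lacks 'constraint_fn'; a function has the wrong number of parameters, fails its test call, returns the wrong shape, or is not differentiable with JAX; or a penalty, node list, index, bound, initial value, or interval is invalid.
  • TypeError: If a function is not callable, or a spec, node list, 'idx', or 'over' index has the wrong type.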

Example

validate_byof(byof, states, n_x=10, n_u=3, N=50) # Raises if invalid
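For orientation, the following is a hedged sketch of a `byof` dictionary that exercises all four accepted keys. The layout `x = [position (2), velocity (2)]`, `u = [force (2)]`, the state name `"velocity"`, the hardcoded slices, and the numeric limits are illustrative assumptions; in real code the `.slice` property of each State is preferable to hardcoded indices. The spec key names follow the checks in the source of `validate_byof`.

```python
import jax.numpy as jnp

N, n_x, n_u = 50, 4, 2  # assumed problem sizes


def velocity_dot(x, u, node, params):
    # d(velocity)/dt = force / mass, with an assumed mass of 2
    return u[0:2] / 2.0


def speed_cap(x, u, node, params):
    # Vector residual, g(x, u) <= 0: each velocity component within [-5, 5]
    return jnp.abs(x[2:4]) - 5.0


def force_rate_limit(X, U, params):
    # Cross-nodal: |U[k+1] - U[k]| <= 0.5 elementwise, written as g(X, U) <= 0
    return jnp.abs(jnp.diff(U, axis=0)) - 0.5


def keep_out(x, u, node, params):
    # Scalar residual: stay outside a unit disk at the origin, 1 - ||p||^2 <= 0
    return 1.0 - jnp.sum(x[0:2] ** 2)


byof = {
    "dynamics": {"velocity": velocity_dot},
    "nodal_constraints": [{"constraint_fn": speed_cap, "nodes": [0, -1]}],
    "cross_nodal_constraints": [force_rate_limit],
    "ctcs_constraints": [
        {"constraint_fn": keep_out, "penalty": "square", "over": (0, N), "bounds": (0.0, 1e-4)},
    ],
}

# Quick shape checks on zero inputs, similar in spirit to validate_byof
x, u = jnp.zeros(n_x), jnp.zeros(n_u)
X, U = jnp.zeros((N, n_x)), jnp.zeros((N, n_u))
assert velocity_dot(x, u, 0, {}).shape == (2,)
assert force_rate_limit(X, U, {}).shape == (N - 1, n_u)
assert jnp.asarray(keep_out(x, u, 0, {})).shape == ()  # CTCS residuals must be scalar
```

With a `states` list whose `velocity` entry has shape `(2,)`, a call such as `validate_byof(byof, states, n_x=4, n_u=2, N=50)` should then pass, assuming the spec key names match your openscvx version.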

Source code in openscvx/expert/validation.py
def validate_byof(
    byof: dict,
    states: List["State"],
    n_x: int,
    n_u: int,
    N: int = None,
) -> None:
    """Validate byof function signatures and shapes.

    Checks that user-provided functions have the correct signatures and return
    appropriate shapes. Performs validation before functions are used to provide
    clear error messages.

    Args:
        byof: Dictionary of user-provided functions to validate
        states: List of State objects for determining expected shapes
        n_x: Total dimension of the unified state vector
        n_u: Total dimension of the unified control vector
        N: Number of nodes in the trajectory (optional). If provided, validates
            node indices in nodal constraints and CTCS `over` intervals.

    Raises:
        ValueError: If a key, signature, output shape, index, bound, or interval
            is invalid, or a function is not differentiable with JAX
        TypeError: If a function is not callable or a spec has the wrong type

    Example:
        >>> validate_byof(byof, states, n_x=10, n_u=3, N=50)  # Raises if invalid
    """
    import jax
    import jax.numpy as jnp

    # Validate byof keys
    valid_keys = {"dynamics", "nodal_constraints", "cross_nodal_constraints", "ctcs_constraints"}
    invalid_keys = set(byof.keys()) - valid_keys
    if invalid_keys:
        raise ValueError(f"Unknown byof keys: {invalid_keys}. Valid keys: {valid_keys}")

    # Create dummy inputs for testing
    dummy_x = jnp.zeros(n_x)
    dummy_u = jnp.zeros(n_u)
    dummy_node = 0
    dummy_params = {}

    # Validate dynamics functions
    byof_dynamics = byof.get("dynamics", {})
    if byof_dynamics:
        # Build mapping from state name to expected shape
        state_shapes = {state.name: state.shape for state in states}

        for state_name, fn in byof_dynamics.items():
            if state_name not in state_shapes:
                raise ValueError(
                    f"byof dynamics '{state_name}' does not match any state name. "
                    f"Available states: {list(state_shapes.keys())}"
                )

            if not callable(fn):
                raise TypeError(f"byof dynamics '{state_name}' must be callable, got {type(fn)}")

            # Check signature
            sig = inspect.signature(fn)
            if len(sig.parameters) != 4:
                raise ValueError(
                    f"byof dynamics '{state_name}' must have signature f(x, u, node, params), "
                    f"got {len(sig.parameters)} parameters: {list(sig.parameters.keys())}"
                )

            # Test call and check output shape
            try:
                result = fn(dummy_x, dummy_u, dummy_node, dummy_params)
            except Exception as e:
                raise ValueError(
                    f"byof dynamics '{state_name}' failed on test call with "
                    f"x.shape={dummy_x.shape}, u.shape={dummy_u.shape}: {e}"
                ) from e

            expected_shape = state_shapes[state_name]
            result_shape = jnp.asarray(result).shape
            if result_shape != expected_shape:
                raise ValueError(
                    f"byof dynamics '{state_name}' returned shape {result_shape}, "
                    f"expected {expected_shape} (state '{state_name}' shape)"
                )

            # Test that gradient works (JAX compatibility check)
            try:
                jax.grad(lambda x: jnp.sum(fn(x, dummy_u, dummy_node, dummy_params)))(dummy_x)
            except Exception as e:
                raise ValueError(
                    f"byof dynamics '{state_name}' is not differentiable with JAX. "
                    f"Ensure the function uses JAX operations (jax.numpy, not numpy): {e}"
                ) from e

    # Validate nodal constraints
    for i, constraint_spec in enumerate(byof.get("nodal_constraints", [])):
        if not isinstance(constraint_spec, dict):
            raise TypeError(
                f"byof nodal_constraints[{i}] must be a dict (NodalConstraintSpec), "
                f"got {type(constraint_spec)}"
            )

        if "constraint_fn" not in constraint_spec:
            raise ValueError(f"byof nodal_constraints[{i}] missing required key 'constraint_fn'")

        fn = constraint_spec["constraint_fn"]
        if not callable(fn):
            raise TypeError(
                f"byof nodal_constraints[{i}]['constraint_fn'] must be callable, got {type(fn)}"
            )

        # Check signature
        sig = inspect.signature(fn)
        if len(sig.parameters) != 4:
            raise ValueError(
                f"byof nodal_constraints[{i}]['constraint_fn'] must have signature "
                f"f(x, u, node, params), "
                f"got {len(sig.parameters)} parameters: {list(sig.parameters.keys())}"
            )

        # Test call
        try:
            result = fn(dummy_x, dummy_u, dummy_node, dummy_params)
        except Exception as e:
            raise ValueError(
                f"byof nodal_constraints[{i}]['constraint_fn'] failed on test call with "
                f"x.shape={dummy_x.shape}, u.shape={dummy_u.shape}: {e}"
            ) from e

        # Check that result is array-like (can be scalar or vector)
        try:
            result_array = jnp.asarray(result)
        except Exception as e:
            raise ValueError(
                f"byof nodal_constraints[{i}]['constraint_fn'] must return array-like value, "
                f"got {type(result)}: {e}"
            ) from e

        # Test gradient
        try:
            jax.grad(lambda x: jnp.sum(fn(x, dummy_u, dummy_node, dummy_params)))(dummy_x)
        except Exception as e:
            raise ValueError(
                f"byof nodal_constraints[{i}]['constraint_fn'] is not differentiable with JAX: {e}"
            ) from e

        # Validate nodes if provided
        if "nodes" in constraint_spec:
            nodes = constraint_spec["nodes"]
            if not isinstance(nodes, (list, tuple)):
                raise TypeError(
                    f"byof nodal_constraints[{i}]['nodes'] must be a list or tuple, "
                    f"got {type(nodes)}"
                )
            if len(nodes) == 0:
                raise ValueError(f"byof nodal_constraints[{i}]['nodes'] cannot be empty")

            # Validate node indices if N is provided
            if N is not None:
                for node in nodes:
                    # Handle negative indices (e.g., -1 for last node)
                    normalized_node = node if node >= 0 else N + node
                    # Validate range
                    if not (0 <= normalized_node < N):
                        raise ValueError(
                            f"byof nodal_constraints[{i}]['nodes'] contains invalid index {node} "
                            f"(normalized: {normalized_node}). Valid range is [0, {N}) or "
                            f"negative indices [-{N}, -1]."
                        )

    # Validate cross-nodal constraints
    # Dummy trajectory: use N nodes when known, otherwise fall back to 10
    n_dummy = N if N is not None else 10
    dummy_X = jnp.zeros((n_dummy, n_x))
    dummy_U = jnp.zeros((n_dummy, n_u))

    for i, fn in enumerate(byof.get("cross_nodal_constraints", [])):
        if not callable(fn):
            raise TypeError(f"byof cross_nodal_constraints[{i}] must be callable, got {type(fn)}")

        # Check signature
        sig = inspect.signature(fn)
        if len(sig.parameters) != 3:
            raise ValueError(
                f"byof cross_nodal_constraints[{i}] must have signature f(X, U, params), "
                f"got {len(sig.parameters)} parameters: {list(sig.parameters.keys())}"
            )

        # Test call
        try:
            result = fn(dummy_X, dummy_U, dummy_params)
        except Exception as e:
            raise ValueError(
                f"byof cross_nodal_constraints[{i}] failed on test call with "
                f"X.shape={dummy_X.shape}, U.shape={dummy_U.shape}: {e}"
            ) from e

        # Check that result is array-like
        try:
            result_array = jnp.asarray(result)
        except Exception as e:
            raise ValueError(
                f"byof cross_nodal_constraints[{i}] must return array-like value, "
                f"got {type(result)}: {e}"
            ) from e

        # Test gradient
        try:
            jax.grad(lambda X: jnp.sum(fn(X, dummy_U, dummy_params)))(dummy_X)
        except Exception as e:
            raise ValueError(
                f"byof cross_nodal_constraints[{i}] is not differentiable with JAX: {e}"
            ) from e

    # Validate CTCS constraints
    for i, ctcs_spec in enumerate(byof.get("ctcs_constraints", [])):
        if not isinstance(ctcs_spec, dict):
            raise TypeError(f"byof ctcs_constraints[{i}] must be a dict, got {type(ctcs_spec)}")

        if "constraint_fn" not in ctcs_spec:
            raise ValueError(f"byof ctcs_constraints[{i}] missing required key 'constraint_fn'")

        fn = ctcs_spec["constraint_fn"]
        if not callable(fn):
            raise TypeError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] must be callable, got {type(fn)}"
            )

        # Check signature
        sig = inspect.signature(fn)
        if len(sig.parameters) != 4:
            raise ValueError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] must have signature "
                f"f(x, u, node, params), got {len(sig.parameters)} parameters: "
                f"{list(sig.parameters.keys())}"
            )

        # Test call
        try:
            result = fn(dummy_x, dummy_u, dummy_node, dummy_params)
        except Exception as e:
            raise ValueError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] failed on test call: {e}"
            ) from e

        # Check that result is scalar
        result_array = jnp.asarray(result)
        if result_array.shape != ():
            raise ValueError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] must return a scalar, "
                f"got shape {result_array.shape}"
            )

        # Test gradient
        try:
            jax.grad(lambda x: fn(x, dummy_u, dummy_node, dummy_params))(dummy_x)
        except Exception as e:
            raise ValueError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] is not differentiable with JAX: {e}"
            ) from e

        # Validate penalty function if provided
        if "penalty" in ctcs_spec:
            penalty_spec = ctcs_spec["penalty"]
            if callable(penalty_spec):
                # Test custom penalty function
                try:
                    test_residual = jnp.array(0.5)
                    penalty_result = penalty_spec(test_residual)
                    jnp.asarray(penalty_result)
                except Exception as e:
                    raise ValueError(
                        f"byof ctcs_constraints[{i}]['penalty'] custom function failed: {e}"
                    ) from e
            elif penalty_spec not in ["square", "l1", "huber"]:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['penalty'] must be 'square', 'l1', 'huber', "
                    f"or a callable, got {penalty_spec!r}"
                )

        # Validate idx if provided
        if "idx" in ctcs_spec:
            idx = ctcs_spec["idx"]
            if not isinstance(idx, int):
                raise TypeError(
                    f"byof ctcs_constraints[{i}]['idx'] must be an integer, got {type(idx)}"
                )
            if idx < 0:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['idx'] must be non-negative, got {idx}"
                )

        # Validate bounds if provided
        if "bounds" in ctcs_spec:
            bounds = ctcs_spec["bounds"]
            if not isinstance(bounds, (tuple, list)) or len(bounds) != 2:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['bounds'] must be a (min, max) tuple, got {bounds}"
                )
            if bounds[0] > bounds[1]:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['bounds'] min ({bounds[0]}) must be <= "
                    f"max ({bounds[1]})"
                )
        else:
            # Use default bounds for initial value validation
            bounds = (0.0, 1e-4)

        # Validate initial value is within bounds
        if "initial" in ctcs_spec:
            initial = ctcs_spec["initial"]
            if not (bounds[0] <= initial <= bounds[1]):
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['initial'] ({initial}) must be within "
                    f"bounds [{bounds[0]}, {bounds[1]}]"
                )

        # Validate over (node interval) if provided
        if "over" in ctcs_spec:
            over = ctcs_spec["over"]
            if not isinstance(over, (tuple, list)) or len(over) != 2:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['over'] must be a (start, end) tuple, got {over}"
                )
            start, end = over
            if not isinstance(start, int) or not isinstance(end, int):
                raise TypeError(
                    f"byof ctcs_constraints[{i}]['over'] indices must be integers, "
                    f"got start={type(start)}, end={type(end)}"
                )
            if start >= end:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['over'] start ({start}) must be < end ({end})"
                )
            if start < 0:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['over'] start ({start}) must be non-negative"
                )
            # Validate against trajectory length if N is provided
            if N is not None:
                if end > N:
                    raise ValueError(
                        f"byof ctcs_constraints[{i}]['over'] end ({end}) exceeds "
                        f"trajectory length ({N})"
                    )