expert

Expert-mode features for advanced users.

This module contains features for expert users who need fine-grained control and are willing to bypass higher-level abstractions.

ByofSpec

Bases: BaseModel

Bring-Your-Own-Functions specification for expert users.

Allows bypassing the symbolic layer and directly providing raw JAX functions. All fields are optional - you can mix symbolic and byof as needed.

Warning

You are responsible for:

  • Correct indexing into unified state/control vectors
  • Ensuring functions are JAX-compatible (use jax.numpy, no side effects)
  • Ensuring functions are differentiable
  • Following g(x,u) <= 0 convention for constraints
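
One way to check these responsibilities before handing a function to the solver is to evaluate it and its Jacobians on dummy inputs. The following is a minimal sketch, not part of the library: the 3-element state layout, the hard-coded `velocity_slice`, and the function name `speed_limit` are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical layout: position occupies x[0:2], velocity occupies x[2:3].
velocity_slice = slice(2, 3)

def speed_limit(x, u, node, params):
    # g(x, u) <= 0 form of "velocity <= 10"
    return x[velocity_slice][0] - 10.0

x0 = jnp.array([0.0, 0.0, 4.0])  # dummy state
u0 = jnp.array([0.1])            # dummy control

# The residual should be finite and negative when the constraint holds.
residual = speed_limit(x0, u0, 0, {})

# Forward-mode Jacobians with respect to state and control.
grad_x = jax.jacfwd(speed_limit, argnums=0)(x0, u0, 0, {})
grad_u = jax.jacfwd(speed_limit, argnums=1)(x0, u0, 0, {})
```

If either Jacobian contains NaN or fails to trace, the function is likely using non-JAX operations or a non-differentiable construct.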
Tip

Use the .slice property on State/Control objects for cleaner, more maintainable indexing instead of hardcoded indices. For example, use x[velocity.slice] instead of x[2:3]. The slice property is set after preprocessing and provides the correct indices into the unified state/control vectors.
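
To see why named slices are more robust than hardcoded indices, the toy sketch below assigns contiguous slices in declaration order. This layout rule and the helper `assign_slices` are assumptions for illustration only; the library's preprocessing determines the real slices.

```python
def assign_slices(shapes):
    """Assign contiguous slices in declaration order (assumed layout)."""
    slices, offset = {}, 0
    for name, shape in shapes.items():
        size = 1
        for dim in shape:
            size *= dim
        slices[name] = slice(offset, offset + size)
        offset += size
    return slices

# Same states as the ByofSpec example: position (2,), velocity (1,)
slices = assign_slices({"position": (2,), "velocity": (1,)})

x = [1.0, 2.0, 3.0]               # stand-in for the unified state vector
velocity = x[slices["velocity"]]  # equivalent to x[2:3] under this layout
```

If another state is later declared before velocity, a hardcoded `x[2:3]` would silently point at the wrong entries, while the named slice would follow the new layout.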

Attributes:

Name Type Description
parameters List[Parameter]

List of Parameter objects (openscvx.symbolic.expr.expr.Parameter) that are only used inside byof functions (not in any symbolic expression). Parameters already referenced in symbolic dynamics or constraints are collected automatically and do not need to be listed here. Duplicates are ignored.

dynamics Dict[str, DynamicsFunction]

Raw JAX functions for state derivatives. Maps state names to functions with signature (x, u, node, params) -> xdot_component. States here should NOT appear in symbolic dynamics dict. You can mix: some states symbolic, some in byof.

dynamics_discrete Dict[str, DynamicsFunction]

Raw JAX functions for discrete/impulsive state updates. Maps state names to functions with signature (x, u, node, params) -> x_next_component. States here should NOT appear in symbolic dynamics_discrete dict. Use when you need custom impulsive updates (e.g. delta-V) alongside or instead of symbolic ones.

nodal_constraints List[NodalConstraintSpec]

Point-wise constraints applied at specific nodes. Each item is a NodalConstraintSpec dict with:

  • constraint_fn: Constraint function (x, u, node, params) -> residual (required)
  • nodes: List of node indices (optional, defaults to all nodes)

Follows g(x,u) <= 0 convention.

cross_nodal_constraints List[CrossNodalConstraintFunction]

Constraints coupling multiple nodes (smoothness, rate limits). Signature: (X, U, params) -> residual where X is (N, n_x) and U is (N, n_u). N is the number of trajectory nodes, n_x is state dimension, n_u is control dimension. Follows g(X,U) <= 0 convention.

ctcs_constraints List[CtcsConstraintSpec]

Continuous-time constraint satisfaction via dynamics augmentation. Each adds an augmented state accumulating violation penalties. See CtcsConstraintSpec for details.
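
As an illustration of the cross-nodal signature, the sketch below writes a smoothness constraint on one control channel in the g(X,U) <= 0 form. The column layout (`theta_slice`), the budget value, and the function name are assumptions, not library defaults.

```python
import jax
import jax.numpy as jnp

# Hypothetical layout: the control of interest occupies column 0 of U.
theta_slice = slice(0, 1)
SMOOTHNESS_BUDGET = 1.0  # assumed cap on the sum of squared node-to-node changes

def smoothness(X, U, params):
    # X: (N, n_x), U: (N, n_u). Differences between consecutive nodes.
    steps = jnp.diff(U[:, theta_slice], axis=0)        # shape (N-1, 1)
    return jnp.sum(steps**2) - SMOOTHNESS_BUDGET       # <= 0 when satisfied

X = jnp.zeros((4, 3))
U = jnp.array([[0.0], [0.2], [0.9], [1.0]])

residual = smoothness(X, U, {})
grad_U = jax.jacfwd(smoothness, argnums=1)(X, U, {})   # same shape as U
```

Squared differences are used here instead of absolute values so that the residual stays differentiable everywhere.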

Example

Custom dynamics and constraints::

import jax.numpy as jnp
import openscvx as ox
from openscvx import ByofSpec

# Define states and controls
position = ox.State("position", shape=(2,))
velocity = ox.State("velocity", shape=(1,))
theta = ox.Control("theta", shape=(1,))

# Parameters used by byof functions
g = ox.Parameter("g", shape=(), value=9.81)

# Custom dynamics for one state using .slice property
def custom_velocity_dynamics(x, u, node, params):
    # Use .slice property for clean indexing
    return params["g"] * jnp.cos(u[theta.slice][0])

byof: ByofSpec = {
    "parameters": [g],
    "dynamics": {
        "velocity": custom_velocity_dynamics,
    },
    "dynamics_discrete": {
        # Optional: e.g. "velocity": lambda x, u, node, params: x[vel_sl] + u[dv_sl]
    },
    "nodal_constraints": [
        # Applied to all nodes (no "nodes" field)
        {
            "constraint_fn": lambda x, u, node, params: x[velocity.slice][0] - 10.0,
        },
        {
            "constraint_fn": lambda x, u, node, params: -x[velocity.slice][0],
        },
        # Specify nodes for selective enforcement
        {
            "constraint_fn": lambda x, u, node, params: x[velocity.slice][0],
            # velocity <= 0 at node 0; with -velocity <= 0 above, this pins it to 0
            "nodes": [0],
        },
    ],
    "cross_nodal_constraints": [
        # Constrain total velocity across trajectory: sum(velocities) >= 5
        # X.shape = (N, n_x), extract velocity column using slice
        lambda X, U, params: 5.0 - jnp.sum(X[:, velocity.slice]),
    ],
    "ctcs_constraints": [
        {
            "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 5.0,
            "penalty": "square",
        }
    ],
}
Source code in openscvx/expert/byof.py
class ByofSpec(BaseModel):
    """Bring-Your-Own-Functions specification for expert users.

    Allows bypassing the symbolic layer and directly providing raw JAX functions.
    All fields are optional - you can mix symbolic and byof as needed.

    Warning:
        You are responsible for:

        - Correct indexing into unified state/control vectors
        - Ensuring functions are JAX-compatible (use jax.numpy, no side effects)
        - Ensuring functions are differentiable
        - Following g(x,u) <= 0 convention for constraints

    Tip:
        Use the ``.slice`` property on State/Control objects for cleaner, more
        maintainable indexing instead of hardcoded indices. For example, use
        ``x[velocity.slice]`` instead of ``x[2:3]``. The slice property is set
        after preprocessing and provides the correct indices into the unified
        state/control vectors.

    Attributes:
        parameters: List of :class:`~openscvx.symbolic.expr.expr.Parameter` objects that
            are only used inside byof functions (not in any symbolic expression). Parameters
            already referenced in symbolic dynamics or constraints are collected automatically
            and do not need to be listed here. Duplicates are ignored.
        dynamics: Raw JAX functions for state derivatives. Maps state names to functions
            with signature ``(x, u, node, params) -> xdot_component``. States here should
            NOT appear in symbolic dynamics dict. You can mix: some states symbolic,
            some in byof.
        dynamics_discrete: Raw JAX functions for discrete/impulsive state updates. Maps state
            names to functions with signature ``(x, u, node, params) -> x_next_component``.
            States here should NOT appear in symbolic dynamics_discrete dict. Use when you need
            custom impulsive updates (e.g. delta-V) alongside or instead of symbolic ones.
        nodal_constraints: Point-wise constraints applied at specific nodes.
            Each item is a :class:`NodalConstraintSpec` dict with:

            - ``constraint_fn``: Constraint function ``(x, u, node, params) -> residual`` (required)
            - ``nodes``: List of node indices (optional, defaults to all nodes)

            Follows g(x,u) <= 0 convention.
        cross_nodal_constraints: Constraints coupling multiple nodes (smoothness, rate limits).
            Signature: ``(X, U, params) -> residual`` where X is (N, n_x) and U is (N, n_u).
            N is the number of trajectory nodes, n_x is state dimension, n_u is control dimension.
            Follows g(X,U) <= 0 convention.
        ctcs_constraints: Continuous-time constraint satisfaction via dynamics augmentation.
            Each adds an augmented state accumulating violation penalties.
            See :class:`CtcsConstraintSpec` for details.

    Example:
        Custom dynamics and constraints::

            import jax.numpy as jnp
            import openscvx as ox
            from openscvx import ByofSpec

            # Define states and controls
            position = ox.State("position", shape=(2,))
            velocity = ox.State("velocity", shape=(1,))
            theta = ox.Control("theta", shape=(1,))

            # Parameters used by byof functions
            g = ox.Parameter("g", shape=(), value=9.81)

            # Custom dynamics for one state using .slice property
            def custom_velocity_dynamics(x, u, node, params):
                # Use .slice property for clean indexing
                return params["g"] * jnp.cos(u[theta.slice][0])

            byof: ByofSpec = {
                "parameters": [g],
                "dynamics": {
                    "velocity": custom_velocity_dynamics,
                },
                "dynamics_discrete": {
                    # Optional: e.g. "velocity": lambda x, u, node, params: x[vel_sl] + u[dv_sl]
                },
                "nodal_constraints": [
                    # Applied to all nodes (no "nodes" field)
                    {
                        "constraint_fn": lambda x, u, node, params: x[velocity.slice][0] - 10.0,
                    },
                    {
                        "constraint_fn": lambda x, u, node, params: -x[velocity.slice][0],
                    },
                    # Specify nodes for selective enforcement
                    {
                        "constraint_fn": lambda x, u, node, params: x[velocity.slice][0],
                        # velocity <= 0 at node 0; with -velocity <= 0 above, this pins it to 0
                        "nodes": [0],
                    },
                ],
                "cross_nodal_constraints": [
                    # Constrain total velocity across trajectory: sum(velocities) >= 5
                    # X.shape = (N, n_x), extract velocity column using slice
                    lambda X, U, params: 5.0 - jnp.sum(X[:, velocity.slice]),
                ],
                "ctcs_constraints": [
                    {
                        "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 5.0,
                        "penalty": "square",
                    }
                ],
            }
    """

    parameters: List[Parameter] = []
    dynamics: Dict[str, DynamicsFunction] = {}
    dynamics_discrete: Dict[str, DynamicsFunction] = {}
    nodal_constraints: List[NodalConstraintSpec] = []
    cross_nodal_constraints: List[CrossNodalConstraintFunction] = []
    ctcs_constraints: List[CtcsConstraintSpec] = []

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

CtcsConstraintSpec

Bases: BaseModel

Specification for CTCS (Continuous-Time Constraint Satisfaction) constraint.

CTCS constraints are enforced by augmenting the dynamics with a penalty term that accumulates violations over time. Useful for path constraints that must be satisfied continuously, not just at discrete nodes.
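
The accumulation idea can be sketched in a few lines of plain Python: an augmented state integrates the penalty of the residual along the trajectory, and bounding its final value forces violations to be small. The forward-Euler integration and the sampled residuals below are assumptions for illustration; they are not how the library discretizes the augmented dynamics.

```python
def square_penalty(r):
    # Only positive residuals (violations) contribute.
    return max(r, 0.0) ** 2

def accumulate_violation(residuals, dt):
    """Forward-Euler integral of the penalty along a sampled trajectory."""
    y = 0.0
    for r in residuals:
        y += dt * square_penalty(r)
    return y

# Residuals g(x, u) sampled along two toy trajectories
satisfied = accumulate_violation([-1.0, -0.5, -0.1], dt=0.1)
violated = accumulate_violation([-1.0, 0.5, 1.0], dt=0.1)

UPPER_BOUND = 1e-4  # matches the documented default bounds[1]
within_bound = (satisfied <= UPPER_BOUND, violated <= UPPER_BOUND)
```

A trajectory that never violates the constraint leaves the augmented state at zero, while any sustained violation pushes it past the upper bound.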

Attributes:

Name Type Description
constraint_fn CtcsConstraintFunction

Function computing constraint residual with signature (x, u, node, params) -> scalar. Must return scalar. Follows g(x,u) <= 0 convention (negative = satisfied). Required field.

penalty PenaltyFunction

Penalty function for positive residuals (violations). Built-in options: "square" (max(r,0)^2, default), "l1" (max(r,0)), "huber" (Huber loss). Custom: Callable (r) -> penalty (non-negative, differentiable).

bounds Tuple[float, float]

(min, max) bounds for augmented state accumulating penalties. Default: (0.0, 1e-4). Max acts as soft constraint on total violation.

initial Optional[float]

Initial value for augmented state. Default: bounds[0] (usually 0.0).

over Optional[Tuple[int, int]]

Node interval (start, end) where constraint is active. The constraint is enforced for nodes in [start, end). If omitted, constraint is active over all nodes. Matches symbolic .over() method behavior.

idx int

Constraint group index for sharing augmented states (default: 0). All CTCS constraints (symbolic and byof) with the same idx share a single augmented state. Their penalties are summed together. Use different idx values to track different types of violations separately.
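
A custom penalty is any callable `(r) -> penalty` that is non-negative and differentiable. The cubic penalty below is a hypothetical example, not one of the built-ins; it is shown only to illustrate the expected shape of such a function.

```python
import jax
import jax.numpy as jnp

def cubic_penalty(r):
    # Zero for satisfied constraints (r <= 0), smooth growth for violations.
    return jnp.maximum(r, 0.0) ** 3

# Works elementwise on arrays of residuals
values = cubic_penalty(jnp.array([-1.0, 0.0, 2.0]))

# Derivative at a violated residual: d/dr r^3 = 3 r^2
slope_at_violation = jax.grad(cubic_penalty)(2.0)
```

Such a function would be supplied as the `"penalty"` entry of a CtcsConstraintSpec in place of a built-in name like `"square"`.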

Warning

If symbolic CTCS constraints exist with idx values [0, 1, 2], then byof idx must either:

  • Match an existing idx (e.g., 0, 1, or 2) to add to that augmented state
  • Be sequential after them (e.g., 3, 4, 5) to create new augmented states

You cannot use idx values that create gaps (e.g., if symbolic has [0, 1], you cannot use byof idx=3 without also using idx=2).
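
The no-gap rule can be checked ahead of time with a small helper. The function `find_idx_gaps` below is hypothetical and written for illustration; it is not the library's validation routine.

```python
def find_idx_gaps(symbolic_idx, byof_idx):
    """Return the idx values missing from the combined contiguous range."""
    used = set(symbolic_idx) | set(byof_idx)
    if not used:
        return []
    return [i for i in range(max(used) + 1) if i not in used]

# Symbolic has [0, 1]; byof idx=3 skips 2, which is not allowed.
gaps_bad = find_idx_gaps([0, 1], [3])

# Sharing existing states and adding the next sequential one is fine.
gaps_ok = find_idx_gaps([0, 1], [0, 1, 2])

# Appending sequentially after symbolic [0, 1, 2] is fine.
gaps_append = find_idx_gaps([0, 1, 2], [3, 4, 5])
```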

Example

Enforce position[0] <= 10.0 continuously::

# Assuming position = ox.State("position", shape=(2,))
ctcs_spec: CtcsConstraintSpec = {
    "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
    "penalty": "square",
    "bounds": (0.0, 1e-4),
    "initial": 0.0,
    "idx": 0,  # Groups with other constraints having idx=0
}

Enforce constraint only over specific node range::

ctcs_spec: CtcsConstraintSpec = {
    "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
    "over": (10, 50),  # Active only for nodes 10-49
    "penalty": "square",
}

Multiple constraints sharing an augmented state::

# If symbolic CTCS already has idx=[0, 1], then:

byof = {
    "ctcs_constraints": [
        # Add to existing symbolic idx=0 augmented state
        {
            "constraint_fn": lambda x, u, node, params: x[pos.slice][0] - 10.0,
            "idx": 0,  # Shares with symbolic idx=0
        },
        # Add to existing symbolic idx=1 augmented state
        {
            "constraint_fn": lambda x, u, node, params: x[vel.slice][0] - 5.0,
            "idx": 1,  # Shares with symbolic idx=1
        },
        # Create NEW augmented state (sequential after symbolic)
        {
            "constraint_fn": lambda x, u, node, params: x[pos.slice][1] - 8.0,
            "idx": 2,  # New state (symbolic has 0,1, so next is 2)
        },
    ]
}
Source code in openscvx/expert/byof.py
class CtcsConstraintSpec(BaseModel):
    """Specification for CTCS (Continuous-Time Constraint Satisfaction) constraint.

    CTCS constraints are enforced by augmenting the dynamics with a penalty term that
    accumulates violations over time. Useful for path constraints that must be satisfied
    continuously, not just at discrete nodes.

    Attributes:
        constraint_fn: Function computing constraint residual with signature
            ``(x, u, node, params) -> scalar``. Must return scalar.
            Follows g(x,u) <= 0 convention (negative = satisfied). Required field.
        penalty: Penalty function for positive residuals (violations).
            Built-in options: "square" (max(r,0)^2, default), "l1" (max(r,0)),
            "huber" (Huber loss). Custom: Callable ``(r) -> penalty`` (non-negative,
            differentiable).
        bounds: (min, max) bounds for augmented state accumulating penalties.
            Default: (0.0, 1e-4). Max acts as soft constraint on total violation.
        initial: Initial value for augmented state. Default: bounds[0] (usually 0.0).
        over: Node interval (start, end) where constraint is active. The constraint
            is enforced for nodes in [start, end). If omitted, constraint is active
            over all nodes. Matches symbolic `.over()` method behavior.
        idx: Constraint group index for sharing augmented states (default: 0).
            All CTCS constraints (symbolic and byof) with the same idx share a single
            augmented state. Their penalties are summed together. Use different idx values
            to track different types of violations separately.

    Warning:
        If symbolic CTCS constraints exist with idx values [0, 1, 2], then byof idx **must** either:

        - Match an existing idx (e.g., 0, 1, or 2) to add to that augmented state
        - Be sequential after them (e.g., 3, 4, 5) to create new augmented states

        You cannot use idx values that create gaps (e.g., if symbolic has [0, 1],
        you cannot use byof idx=3 without also using idx=2).

    Example:
        Enforce position[0] <= 10.0 continuously::

            # Assuming position = ox.State("position", shape=(2,))
            ctcs_spec: CtcsConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
                "penalty": "square",
                "bounds": (0.0, 1e-4),
                "initial": 0.0,
                "idx": 0,  # Groups with other constraints having idx=0
            }

        Enforce constraint only over specific node range::

            ctcs_spec: CtcsConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
                "over": (10, 50),  # Active only for nodes 10-49
                "penalty": "square",
            }

        Multiple constraints sharing an augmented state::

            # If symbolic CTCS already has idx=[0, 1], then:

            byof = {
                "ctcs_constraints": [
                    # Add to existing symbolic idx=0 augmented state
                    {
                        "constraint_fn": lambda x, u, node, params: x[pos.slice][0] - 10.0,
                        "idx": 0,  # Shares with symbolic idx=0
                    },
                    # Add to existing symbolic idx=1 augmented state
                    {
                        "constraint_fn": lambda x, u, node, params: x[vel.slice][0] - 5.0,
                        "idx": 1,  # Shares with symbolic idx=1
                    },
                    # Create NEW augmented state (sequential after symbolic)
                    {
                        "constraint_fn": lambda x, u, node, params: x[pos.slice][1] - 8.0,
                        "idx": 2,  # New state (symbolic has 0,1, so next is 2)
                    },
                ]
            }
    """

    constraint_fn: CtcsConstraintFunction
    penalty: PenaltyFunction = "square"
    bounds: Tuple[float, float] = (0.0, 1e-4)
    initial: Optional[float] = None
    over: Optional[Tuple[int, int]] = None
    idx: int = 0

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

NodalConstraintSpec

Bases: BaseModel

Specification for nodal constraint with optional node selection.

Nodal constraints are point-wise constraints evaluated at specific trajectory nodes. By default, constraints apply to all nodes, but you can restrict enforcement to specific nodes for boundary conditions, waypoints, or computational efficiency.

Attributes:

Name Type Description
constraint_fn NodalConstraintFunction

Constraint function with signature (x, u, node, params) -> residual. Follows g(x,u) <= 0 convention (negative = satisfied). Required field.

nodes Optional[List[int]]

List of integer node indices where constraint is enforced. If omitted, applies to all nodes. Negative indices supported (e.g., -1 for last). Optional field.
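
Negative indices follow Python's list convention, counting back from the last node. The helper below is a standalone sketch of that resolution, assuming N trajectory nodes; the name `resolve_nodes` is hypothetical.

```python
def resolve_nodes(nodes, N):
    """Expand a missing node list to all nodes and map negatives to 0..N-1."""
    if nodes is None:
        return list(range(N))
    return [n if n >= 0 else N + n for n in nodes]

boundary = resolve_nodes([0, -1], N=5)   # first and last node
everywhere = resolve_nodes(None, N=3)    # omitted "nodes" field
```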

Example

Boundary constraint only at first and last nodes::

nodal_spec: NodalConstraintSpec = {
    "constraint_fn": lambda x, u, node, params: x[velocity.slice][0],
    "nodes": [0, -1],  # Only at start and end
}

Waypoint constraint at middle of trajectory::

nodal_spec: NodalConstraintSpec = {
    "constraint_fn": lambda x, u, node, params: jnp.linalg.norm(
        x[position.slice] - jnp.array([5.0, 7.5])
    ) - 0.1,
    "nodes": [N // 2],
}
Source code in openscvx/expert/byof.py
class NodalConstraintSpec(BaseModel):
    """Specification for nodal constraint with optional node selection.

    Nodal constraints are point-wise constraints evaluated at specific trajectory nodes.
    By default, constraints apply to all nodes, but you can restrict enforcement to
    specific nodes for boundary conditions, waypoints, or computational efficiency.

    Attributes:
        constraint_fn: Constraint function with signature ``(x, u, node, params) -> residual``.
            Follows g(x,u) <= 0 convention (negative = satisfied). Required field.
        nodes: List of integer node indices where constraint is enforced.
            If omitted, applies to all nodes. Negative indices supported (e.g., -1 for last).
            Optional field.

    Example:
        Boundary constraint only at first and last nodes::

            nodal_spec: NodalConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: x[velocity.slice][0],
                "nodes": [0, -1],  # Only at start and end
            }

        Waypoint constraint at middle of trajectory::

            nodal_spec: NodalConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: jnp.linalg.norm(
                    x[position.slice] - jnp.array([5.0, 7.5])
                ) - 0.1,
                "nodes": [N // 2],
            }
    """

    constraint_fn: NodalConstraintFunction
    nodes: Optional[List[int]] = None

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

apply_byof(byof: ByofSpec, dynamics: Dynamics, dynamics_prop: Dynamics, dynamics_discrete: Dynamics, jax_constraints: LoweredJaxConstraints, x_unified: UnifiedState, x_prop_unified: UnifiedState, u_unified: UnifiedState, states: List[State], states_prop: List[State], N: int) -> Tuple[Dynamics, Dynamics, Dynamics, LoweredJaxConstraints, UnifiedState, UnifiedState]

Apply bring-your-own-functions (byof) to augment lowered problem.

Handles raw JAX functions provided by expert users, including:

  • dynamics: Raw JAX functions for specific state derivatives
  • dynamics_discrete: Raw JAX functions for discrete/impulsive state updates
  • nodal_constraints: Point-wise constraints at each node
  • cross_nodal_constraints: Constraints coupling multiple nodes
  • ctcs_constraints: Continuous-time constraint satisfaction via dynamics augmentation
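
The dynamics step can be pictured as overwriting selected slices of the symbolic derivative vector with user-supplied values, scaled by the time-dilation control. The toy below is a simplified sketch: the state and control layout, the slice values, and all function names are assumptions chosen for the example.

```python
import jax.numpy as jnp

# Hypothetical layout: x = [pos_x, pos_y, velocity], u = [theta, s]
velocity_slice = slice(2, 3)
theta_slice = slice(0, 1)
td_slice = slice(1, 2)  # time-dilation control s

def symbolic_f(x, u, node, params):
    # Stand-in for lowered symbolic dynamics (already scaled by s).
    s = u[td_slice][0]
    return s * jnp.array([x[2], 0.0, 0.0])

def byof_velocity(x, u, node, params):
    return params["g"] * jnp.cos(u[theta_slice][0])

def make_composite(orig_f, byof_fns, slices_map, td_sl):
    def composite_f(x, u, node, params):
        xdot = orig_f(x, u, node, params)
        s = u[td_sl]
        for name, fn in byof_fns.items():
            # Functional update: returns a new array with the slice replaced.
            xdot = xdot.at[slices_map[name]].set(s * fn(x, u, node, params))
        return xdot
    return composite_f

f = make_composite(
    symbolic_f, {"velocity": byof_velocity}, {"velocity": velocity_slice}, td_slice
)
xdot = f(jnp.array([0.0, 0.0, 3.0]), jnp.array([0.0, 2.0]), 0, {"g": 9.81})
```

Because JAX arrays are immutable, the `.at[...].set(...)` form is used instead of in-place assignment, which also keeps the composite function traceable and differentiable.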

Parameters:

Name Type Description Default
byof ByofSpec

ByofSpec with fields "dynamics", "dynamics_discrete", "nodal_constraints", "cross_nodal_constraints", "ctcs_constraints"

required
dynamics Dynamics

Lowered optimization dynamics to potentially augment

required
dynamics_prop Dynamics

Lowered propagation dynamics to potentially augment

required
dynamics_discrete Dynamics

Lowered discrete dynamics (for impulsive solver); extended when new CTCS states are added so its output dimension matches the unified state.

required
jax_constraints LoweredJaxConstraints

Lowered JAX constraints to append to

required
x_unified UnifiedState

Unified optimization state interface to potentially augment

required
x_prop_unified UnifiedState

Unified propagation state interface to potentially augment

required
u_unified UnifiedState

Unified control interface for validation

required
states List[State]

List of State objects for optimization (with _slice attributes)

required
states_prop List[State]

List of State objects for propagation (with _slice attributes)

required
N int

Number of nodes in the trajectory

required

Returns:

Type Description
Tuple[Dynamics, Dynamics, Dynamics, LoweredJaxConstraints, UnifiedState, UnifiedState]

Tuple of (dynamics, dynamics_prop, dynamics_discrete, jax_constraints, x_unified, x_prop_unified)

Example

(dynamics, dynamics_prop, dynamics_discrete, constraints, x_unified,
    x_prop_unified) = apply_byof(
    byof, dynamics, dynamics_prop, dynamics_discrete, jax_constraints,
    x_unified, x_prop_unified, u_unified, states, states_prop, N
)

Source code in openscvx/expert/lowering.py
def apply_byof(
    byof: "ByofSpec",
    dynamics: Dynamics,
    dynamics_prop: Dynamics,
    dynamics_discrete: Dynamics,
    jax_constraints: LoweredJaxConstraints,
    x_unified: "UnifiedState",
    x_prop_unified: "UnifiedState",
    u_unified: "UnifiedState",
    states: List["State"],
    states_prop: List["State"],
    N: int,
) -> Tuple[Dynamics, Dynamics, Dynamics, LoweredJaxConstraints, "UnifiedState", "UnifiedState"]:
    """Apply bring-your-own-functions (byof) to augment lowered problem.

    Handles raw JAX functions provided by expert users, including:
    - dynamics: Raw JAX functions for specific state derivatives
    - nodal_constraints: Point-wise constraints at each node
    - cross_nodal_constraints: Constraints coupling multiple nodes
    - ctcs_constraints: Continuous-time constraint satisfaction via dynamics augmentation

    Args:
        byof: ByofSpec with fields "dynamics", "dynamics_discrete", "nodal_constraints",
            "cross_nodal_constraints", "ctcs_constraints"
        dynamics: Lowered optimization dynamics to potentially augment
        dynamics_prop: Lowered propagation dynamics to potentially augment
        dynamics_discrete: Lowered discrete dynamics (for impulsive solver); extended when
            new CTCS states are added so its output dimension matches the unified state.
        jax_constraints: Lowered JAX constraints to append to
        x_unified: Unified optimization state interface to potentially augment
        x_prop_unified: Unified propagation state interface to potentially augment
        u_unified: Unified control interface for validation
        states: List of State objects for optimization (with _slice attributes)
        states_prop: List of State objects for propagation (with _slice attributes)
        N: Number of nodes in the trajectory

    Returns:
        Tuple of (dynamics, dynamics_prop, dynamics_discrete, jax_constraints,
        x_unified, x_prop_unified)

    Example:
        >>> (dynamics, dynamics_prop, dynamics_discrete, constraints, x_unified,
        ...     x_prop_unified) = apply_byof(
        ...     byof, dynamics, dynamics_prop, dynamics_discrete, jax_constraints,
        ...     x_unified, x_prop_unified, u_unified, states, states_prop, N
        ... )
    """

    # Note: byof validation happens earlier in Problem.__init__ to fail fast
    # Handle byof dynamics by splicing in raw JAX functions at the correct slices
    byof_dynamics = byof.dynamics
    if byof_dynamics:
        # Build mapping from state name to slice for optimization states
        state_slices = {state.name: state._slice for state in states}
        state_slices_prop = {state.name: state._slice for state in states_prop}

        # Time-dilation slice for multiplying byof outputs by s
        td_slice = u_unified.time_dilation_slice

        def _make_composite_dynamics(orig_f, byof_fns, slices_map, td_sl):
            """Create composite dynamics combining symbolic and byof state derivatives.

            This factory splices user-provided byof dynamics into the unified dynamics
            function at the appropriate slice indices, replacing the symbolic dynamics
            for specific states while preserving the rest. The byof outputs are
            multiplied by the time-dilation factor s to match the symbolic dynamics
            which already include s * f(x, u) via the Mul node.

            Args:
                orig_f: Original unified dynamics (x, u, node, params) -> xdot
                byof_fns: Dict mapping state names to byof dynamics functions
                slices_map: Dict mapping state names to slice objects for indexing
                td_sl: Slice for the time-dilation control in the unified u vector

            Returns:
                Composite dynamics function with byof derivatives spliced in
            """

            def composite_f(x, u, node, params):
                # Start with symbolic/default dynamics for all states
                xdot = orig_f(x, u, node, params)

                # Time-dilation factor (symbolic dynamics already include s *)
                s = u[td_sl]

                # Splice in byof dynamics for specific states, multiplied by s
                for state_name, byof_fn in byof_fns.items():
                    sl = slices_map[state_name]
                    # Replace the derivative for this state with s * byof result
                    xdot = xdot.at[sl].set(s * byof_fn(x, u, node, params))

                return xdot

            return composite_f

        # Create composite optimization dynamics
        # Jacobians are computed by the discretizer, not here.
        composite_f = _make_composite_dynamics(dynamics.f, byof_dynamics, state_slices, td_slice)
        dynamics = Dynamics(f=composite_f)

        # Create composite propagation dynamics
        composite_f_prop = _make_composite_dynamics(
            dynamics_prop.f, byof_dynamics, state_slices_prop, td_slice
        )
        dynamics_prop = Dynamics(f=composite_f_prop)

    # Handle byof dynamics_discrete by splicing in raw JAX functions at the correct slices
    byof_dynamics_discrete = byof.dynamics_discrete
    if byof_dynamics_discrete:
        state_slices = {state.name: state._slice for state in states}

        def _make_composite_dynamics_discrete(orig_f, byof_fns, slices_map):
            """Create composite discrete dynamics combining symbolic and byof state updates."""

            def composite_f(x, u, node, params):
                x_next = orig_f(x, u, node, params)
                for state_name, byof_fn in byof_fns.items():
                    sl = slices_map[state_name]
                    x_next = x_next.at[sl].set(byof_fn(x, u, node, params))
                return x_next

            return composite_f

        dynamics_discrete = Dynamics(
            f=_make_composite_dynamics_discrete(
                dynamics_discrete.f, byof_dynamics_discrete, state_slices
            )
        )

    # Handle nodal constraints
    # Note: Validation happens earlier in Problem.__init__ via validate_byof
    for constraint_spec in byof.nodal_constraints:
        fn = constraint_spec.constraint_fn
        nodes = constraint_spec.nodes if constraint_spec.nodes is not None else list(range(N))

        # Normalize negative node indices (validation already done in validate_byof)
        normalized_nodes = [node if node >= 0 else N + node for node in nodes]

        constraint = LoweredNodalConstraint(
            func=jax.vmap(fn, in_axes=(0, 0, None, None)),
            grad_g_x=jax.vmap(jacfwd(fn, argnums=0), in_axes=(0, 0, None, None)),
            grad_g_u=jax.vmap(jacfwd(fn, argnums=1), in_axes=(0, 0, None, None)),
            nodes=normalized_nodes,
        )
        jax_constraints.nodal.append(constraint)

    # Handle cross-nodal constraints
    for fn in byof.cross_nodal_constraints:
        constraint = LoweredCrossNodeConstraint(
            func=fn,
            grad_g_X=jacfwd(fn, argnums=0),
            grad_g_U=jacfwd(fn, argnums=1),
        )
        jax_constraints.cross_node.append(constraint)

    # Handle CTCS constraints by augmenting dynamics
    # Built-in penalty functions
    def _penalty_square(r):
        return jnp.maximum(r, 0.0) ** 2

    def _penalty_l1(r):
        return jnp.maximum(r, 0.0)

    def _penalty_huber(r, delta=1.0):
        abs_r = jnp.maximum(r, 0.0)
        return jnp.where(abs_r <= delta, 0.5 * abs_r**2, delta * (abs_r - 0.5 * delta))

    _PENALTY_FUNCTIONS = {
        "square": _penalty_square,
        "l1": _penalty_l1,
        "huber": _penalty_huber,
    }

    # Determine which symbolic CTCS idx values already exist
    # Symbolic augmented states are named "_ctcs_aug_{i}" where i is sequential
    # and corresponds to sorted symbolic idx values (0, 1, 2, ...)
    symbolic_ctcs_idx = []
    for state in states:
        if state.name.startswith("_ctcs_aug_"):
            try:
                aug_idx = int(state.name.split("_")[-1])
                symbolic_ctcs_idx.append(aug_idx)
            except (ValueError, IndexError):
                pass

    # Symbolic CTCS creates augmented states with sequential idx: 0, 1, 2, ...
    # so max_symbolic_idx = len(symbolic_ctcs_idx) - 1 (or -1 if none exist)
    max_symbolic_idx = len(symbolic_ctcs_idx) - 1 if symbolic_ctcs_idx else -1

    # Build idx -> augmented_state_slice mapping for existing symbolic CTCS
    # Augmented states appear after regular states in the unified vector
    # We'll determine the slice by finding the state in the states list
    idx_to_aug_slice = {}
    for state in states:
        if state.name.startswith("_ctcs_aug_"):
            try:
                aug_idx = int(state.name.split("_")[-1])
                # The actual idx value IS the sequential index for symbolic CTCS
                # (they're created with idx 0, 1, 2, ... in sorted order)
                idx_to_aug_slice[aug_idx] = state._slice
            except (ValueError, IndexError, AttributeError):
                pass

    # Group BYOF CTCS constraints by idx
    byof_ctcs_groups = {}
    for ctcs_spec in byof.ctcs_constraints:
        if ctcs_spec.idx not in byof_ctcs_groups:
            byof_ctcs_groups[ctcs_spec.idx] = []
        byof_ctcs_groups[ctcs_spec.idx].append(ctcs_spec)

    # Validate that byof idx values don't create gaps
    # All idx must form contiguous sequence: [0, 1, 2, ..., max_idx]
    if byof_ctcs_groups:
        all_idx = sorted(set(range(max_symbolic_idx + 1)) | set(byof_ctcs_groups.keys()))
        expected_idx = list(range(len(all_idx)))
        if all_idx != expected_idx:
            raise ValueError(
                "BYOF CTCS idx values create a non-contiguous sequence. "
                f"Symbolic CTCS has idx=[{', '.join(map(str, range(max_symbolic_idx + 1)))}], "
                f"combined with byof idx={sorted(byof_ctcs_groups.keys())} gives {all_idx}. "
                f"Expected contiguous sequence {expected_idx}. "
                "Byof idx must either match existing symbolic idx or be sequential after them."
            )

    # Process each idx group
    for idx in sorted(byof_ctcs_groups.keys()):
        specs = byof_ctcs_groups[idx]

        # Collect all penalty functions for this idx
        penalty_fns = []
        for spec in specs:
            constraint_fn = spec.constraint_fn
            penalty_spec = spec.penalty
            over_interval = spec.over

            if callable(penalty_spec):
                penalty_func = penalty_spec
            else:
                penalty_func = _PENALTY_FUNCTIONS[penalty_spec]

            # Create a combined constraint+penalty function
            def _make_penalty_fn(cons_fn, pen_func, over):
                """Factory to capture constraint, penalty functions, and node interval.

                Args:
                    cons_fn: Constraint function (x, u, node, params) -> scalar residual
                    pen_func: Penalty function (residual) -> penalty value
                    over: Optional (start, end) tuple for conditional activation

                Returns:
                    Penalty function that conditionally activates based on node interval
                """

                def penalty_fn(x, u, node, params):
                    # Compute penalty for the constraint violation
                    residual = cons_fn(x, u, node, params)
                    penalty_value = pen_func(residual)

                    # Apply conditional logic if over interval is specified
                    if over is not None:
                        start_node, end_node = over
                        # Extract scalar from node (which may be array or scalar)
                        # Keep as JAX array for tracing compatibility
                        node_scalar = jnp.atleast_1d(node)[0]
                        is_active = (start_node <= node_scalar) & (node_scalar < end_node)

                        # Use jax.lax.cond for JAX-traceable conditional evaluation
                        # Penalty is active only when node is in [start, end)
                        return cond(
                            is_active,
                            lambda _: penalty_value,
                            lambda _: jnp.zeros_like(penalty_value),
                            operand=None,
                        )
                    else:
                        # Always active if no interval specified
                        return penalty_value

                return penalty_fn

            penalty_fns.append(_make_penalty_fn(constraint_fn, penalty_func, over_interval))

        # Time-dilation slice for multiplying byof CTCS penalties by s
        td_slice = u_unified.time_dilation_slice

        if idx in idx_to_aug_slice:
            # This idx already exists from symbolic CTCS - add penalties to existing state
            aug_slice = idx_to_aug_slice[idx]

            def _make_ctcs_addition(orig_f, pen_fns, aug_sl, td_sl):
                """Create dynamics that adds penalties to existing augmented state.

                The penalty is multiplied by the time-dilation factor s to match
                the symbolic dynamics which already include s * f(x, u).

                Args:
                    orig_f: Original dynamics function
                    pen_fns: List of penalty functions to add
                    aug_sl: Slice of the augmented state to modify
                    td_sl: Slice for the time-dilation control

                Returns:
                    Modified dynamics function
                """

                def modified_f(x, u, node, params):
                    xdot = orig_f(x, u, node, params)

                    # Sum all penalties for this idx, scaled by time-dilation
                    s = u[td_sl]
                    total_penalty = s * sum(pen_fn(x, u, node, params) for pen_fn in pen_fns)

                    # Add to existing augmented state derivative
                    current_deriv = xdot[aug_sl]
                    xdot = xdot.at[aug_sl].set(current_deriv + total_penalty)

                    return xdot

                return modified_f

            # Modify both optimization and propagation dynamics
            # Jacobians are computed by the discretizer, not here.
            dynamics.f = _make_ctcs_addition(dynamics.f, penalty_fns, aug_slice, td_slice)
            dynamics_prop.f = _make_ctcs_addition(dynamics_prop.f, penalty_fns, aug_slice, td_slice)

        else:
            # New idx - create new augmented state
            # Use bounds/initial from first spec in this group
            first_spec = specs[0]
            bounds = first_spec.bounds
            initial = first_spec.initial if first_spec.initial is not None else bounds[0]

            def _make_ctcs_new_state(orig_f, pen_fns, td_sl):
                """Create dynamics augmented with new CTCS state.

                The penalty is multiplied by the time-dilation factor s to match
                the symbolic dynamics which already include s * f(x, u).

                Args:
                    orig_f: Original dynamics function
                    pen_fns: List of penalty functions to sum
                    td_sl: Slice for the time-dilation control

                Returns:
                    Augmented dynamics function
                """

                def augmented_f(x, u, node, params):
                    xdot = orig_f(x, u, node, params)

                    # Sum all penalties for this new idx, scaled by time-dilation
                    s = u[td_sl]
                    total_penalty = s * sum(pen_fn(x, u, node, params) for pen_fn in pen_fns)

                    # Append as new augmented state derivative
                    return jnp.concatenate([xdot, jnp.atleast_1d(total_penalty)])

                return augmented_f

            def _extend_discrete_dynamics(orig_discrete_f, aug_sl):
                """Extend discrete dynamics so output dim matches unified state.

                For impulsive/discrete updates, new CTCS augmented states are
                pass-through: x_next[aug] = x_curr[aug], so we append x[aug_sl]
                to the discrete dynamics output.
                """

                def extended_f(x, u, node, params):
                    out = orig_discrete_f(x, u, node, params)
                    return jnp.concatenate([out, jnp.atleast_1d(x[aug_sl])])

                return extended_f

            # Augment optimization dynamics
            # Jacobians are computed by the discretizer, not here.
            aug_f = _make_ctcs_new_state(dynamics.f, penalty_fns, td_slice)
            dynamics = Dynamics(f=aug_f)

            # Augment propagation dynamics
            aug_f_prop = _make_ctcs_new_state(dynamics_prop.f, penalty_fns, td_slice)
            dynamics_prop = Dynamics(f=aug_f_prop)

            # Extend discrete dynamics so impulsive solver sees full state dimension
            current_dim = x_unified.shape[0]
            aug_slice = slice(current_dim, current_dim + 1)
            dynamics_discrete = Dynamics(
                f=_extend_discrete_dynamics(dynamics_discrete.f, aug_slice)
            )

            # Create State objects for the new augmented states
            # This is necessary for CVXPy variable creation and other bookkeeping
            from openscvx.symbolic.expr.state import State

            # Create augmented state for optimization
            aug_state = State(f"_ctcs_aug_{idx}", shape=(1,))
            aug_state.min = np.array([bounds[0]])
            aug_state.max = np.array([bounds[1]])
            aug_state.initial = np.array([initial])
            aug_state.final = [("free", 0.0)]
            aug_state.guess = np.full((N, 1), initial)

            # Set _slice attribute for the new state (reuses the slice computed
            # above for the discrete-dynamics extension)
            aug_state._slice = aug_slice

            # Append to states list (in-place modification visible to caller)
            states.append(aug_state)

            # Create augmented state for propagation
            aug_state_prop = State(f"_ctcs_aug_{idx}", shape=(1,))
            aug_state_prop.min = np.array([bounds[0]])
            aug_state_prop.max = np.array([bounds[1]])
            aug_state_prop.initial = np.array([initial])
            aug_state_prop.final = [("free", 0.0)]
            aug_state_prop.guess = np.full((N, 1), initial)

            # Set _slice attribute for the propagation state
            current_dim_prop = x_prop_unified.shape[0]
            aug_state_prop._slice = slice(current_dim_prop, current_dim_prop + 1)

            # Append to states_prop list
            states_prop.append(aug_state_prop)

            # Add new augmented states to both unified state interfaces
            x_unified.append(
                min=bounds[0],
                max=bounds[1],
                guess=initial,
                initial=initial,
                final=0.0,
                augmented=True,
            )
            x_prop_unified.append(
                min=bounds[0],
                max=bounds[1],
                guess=initial,
                initial=initial,
                final=0.0,
                augmented=True,
            )

    return dynamics, dynamics_prop, dynamics_discrete, jax_constraints, x_unified, x_prop_unified
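
The contiguity rule that the listing above enforces on CTCS `idx` values can be illustrated in isolation. The following is a minimal plain-Python sketch, not the library's implementation: `check_ctcs_idx` is a hypothetical helper name, and `max_symbolic_idx` is assumed to be -1 when no symbolic CTCS states exist, as in the listing.

```python
def check_ctcs_idx(max_symbolic_idx, byof_idx):
    """Return the combined sorted idx list, or raise if it has gaps.

    Hypothetical helper that mirrors the check in the listing above.
    """
    # Symbolic CTCS occupies idx 0..max_symbolic_idx (nothing when it is -1)
    symbolic = set(range(max_symbolic_idx + 1))
    all_idx = sorted(symbolic | set(byof_idx))
    expected = list(range(len(all_idx)))
    if all_idx != expected:
        raise ValueError(
            f"non-contiguous CTCS idx: got {all_idx}, expected {expected}"
        )
    return all_idx


# Reusing idx 1 and adding idx 2 after symbolic idx 0..1 is accepted
combined = check_ctcs_idx(1, [1, 2])

# Skipping idx 2 leaves a gap and is rejected
try:
    check_ctcs_idx(1, [3])
    gap_rejected = False
except ValueError:
    gap_rejected = True
```

A byof `idx` that matches an existing symbolic `idx` adds its penalty to that augmented state, while a new `idx` must be the next integer in the sequence.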

validate_byof(byof: Union[ByofSpec, dict], states: List[State], n_x: int, n_u: int, N: int = None, parameters: dict = None) -> None

Validate byof function signatures and shapes.

Checks that user-provided functions have the correct signatures and return appropriate shapes. Validation runs before the functions are used, so problems surface as clear error messages.

Accepts either a ByofSpec or a raw dict. A raw dict is first validated through ByofSpec.model_validate, which catches unknown keys and missing fields.

Parameters:

Name Type Description Default
byof Union[ByofSpec, dict]

ByofSpec or raw dict with user-provided functions

required
states List[State]

List of State objects for determining expected shapes

required
n_x int

Total dimension of the unified state vector

required
n_u int

Total dimension of the unified control vector

required
N int

Number of nodes in the trajectory (optional). If provided, validates node indices in nodal constraints and checks that CTCS over intervals end within the trajectory.

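None

parameters dict

Dict of parameter values passed as the params argument when each function is test-called (optional). If omitted, an empty dict is used.

None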

Raises:

Type Description
ValueError

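If any function has an invalid signature, fails its test call, returns the wrong shape, or is not differentiable with JAX. Also raised for invalid node indices and invalid CTCS idx, bounds, initial, or over values.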

TypeError

If functions are not callable
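
When N is provided, node indices in nodal constraints are range-checked, with negative indices counting from the end of the trajectory. A minimal sketch of that check follows. It is illustrative only: `normalize_node` is a hypothetical helper, not part of openscvx, and the error text is paraphrased.

```python
def normalize_node(node, N):
    """Map a possibly negative node index into [0, N) or raise ValueError.

    Hypothetical helper for illustration only.
    """
    # Negative indices count from the end, e.g. -1 is the last node
    normalized = node if node >= 0 else N + node
    if not (0 <= normalized < N):
        raise ValueError(
            f"invalid node index {node} (normalized: {normalized}); "
            f"valid range is [0, {N}) or [-{N}, -1]"
        )
    return normalized


N = 50
last = normalize_node(-1, N)    # last node
first = normalize_node(-50, N)  # first node, addressed from the end

try:
    normalize_node(50, N)       # one past the end
    out_of_range_rejected = False
except ValueError:
    out_of_range_rejected = True
```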

Example

validate_byof(byof, states, n_x=10, n_u=3, N=50) # Raises if invalid
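
To make the checks concrete, here is a hedged sketch of the three tests a byof dynamics function has to pass: signature arity, output shape on a test call, and differentiability. The function `velocity_dot`, the state layout, and the `drag` parameter are assumptions invented for this illustration. The checks are written out by hand rather than by calling validate_byof.

```python
import inspect

import jax
import jax.numpy as jnp


def velocity_dot(x, u, node, params):
    """Hypothetical byof dynamics: derivative of a 1-D velocity state."""
    # Assumed layout for illustration: x[2:3] is velocity, u[0:1] is thrust
    return u[0:1] - params["drag"] * x[2:3]


n_x, n_u = 4, 2
dummy_x = jnp.zeros(n_x)
dummy_u = jnp.zeros(n_u)
params = {"drag": 0.1}

# 1. Signature: exactly four parameters (x, u, node, params)
n_args = len(inspect.signature(velocity_dot).parameters)

# 2. Test call: the output shape must equal the state's shape, here (1,)
out_shape = jnp.asarray(velocity_dot(dummy_x, dummy_u, 0, params)).shape

# 3. Differentiability: gradient of the summed output with respect to x
grad_x = jax.grad(
    lambda x: jnp.sum(velocity_dot(x, dummy_u, 0, params))
)(dummy_x)
```

In real use, prefer `x[velocity.slice]` over hardcoded indices such as `x[2:3]`, as the Tip on ByofSpec recommends. The hardcoded slices here only keep the sketch self-contained.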

Source code in openscvx/expert/validation.py
def validate_byof(
    byof: Union[ByofSpec, dict],
    states: List["State"],
    n_x: int,
    n_u: int,
    N: int = None,
    parameters: dict = None,
) -> None:
    """Validate byof function signatures and shapes.

    Checks that user-provided functions have the correct signatures and return
    appropriate shapes. Validation runs before the functions are used, so
    problems surface as clear error messages.

    Accepts either a :class:`ByofSpec` or a raw dict. A raw dict is first
    validated through ``ByofSpec.model_validate``, which catches unknown keys
    and missing fields.

    Args:
        byof: ByofSpec or raw dict with user-provided functions
        states: List of State objects for determining expected shapes
        n_x: Total dimension of the unified state vector
        n_u: Total dimension of the unified control vector
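        N: Number of nodes in the trajectory (optional). If provided, validates
            node indices in nodal constraints and checks that CTCS ``over``
            intervals end within the trajectory.
        parameters: Dict of parameter values passed as ``params`` when each
            function is test-called (optional). If omitted, an empty dict is used.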

    Raises:
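        ValueError: If any function has an invalid signature, fails its test
            call, returns the wrong shape, or is not differentiable with JAX.
            Also raised for invalid node indices and invalid CTCS ``idx``,
            ``bounds``, ``initial``, or ``over`` values.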
        TypeError: If functions are not callable

    Example:
        >>> validate_byof(byof, states, n_x=10, n_u=3, N=50)  # Raises if invalid
    """
    import jax
    import jax.numpy as jnp

    # Convert raw dict to ByofSpec (validates keys, required fields, types)
    if not isinstance(byof, ByofSpec):
        byof = ByofSpec.model_validate(byof)

    # Validate byof parameters
    from openscvx.symbolic.expr.parameter import Parameter

    for i, param in enumerate(byof.parameters):
        if not isinstance(param, Parameter):
            raise TypeError(f"byof parameters[{i}] must be a Parameter object, got {type(param)}")

    # Create dummy inputs for testing
    dummy_x = jnp.zeros(n_x)
    dummy_u = jnp.zeros(n_u)
    dummy_node = 0
    dummy_params = dict(parameters) if parameters else {}

    # Validate dynamics functions
    byof_dynamics = byof.dynamics
    if byof_dynamics:
        # Build mapping from state name to expected shape
        state_shapes = {state.name: state.shape for state in states}

        for state_name, fn in byof_dynamics.items():
            if state_name not in state_shapes:
                raise ValueError(
                    f"byof dynamics '{state_name}' does not match any state name. "
                    f"Available states: {list(state_shapes.keys())}"
                )

            if not callable(fn):
                raise TypeError(f"byof dynamics '{state_name}' must be callable, got {type(fn)}")

            # Check signature
            sig = inspect.signature(fn)
            if len(sig.parameters) != 4:
                raise ValueError(
                    f"byof dynamics '{state_name}' must have signature f(x, u, node, params), "
                    f"got {len(sig.parameters)} parameters: {list(sig.parameters.keys())}"
                )

            # Test call and check output shape
            try:
                result = fn(dummy_x, dummy_u, dummy_node, dummy_params)
            except Exception as e:
                raise ValueError(
                    f"byof dynamics '{state_name}' failed on test call with "
                    f"x.shape={dummy_x.shape}, u.shape={dummy_u.shape}: {e}"
                ) from e

            expected_shape = state_shapes[state_name]
            result_shape = jnp.asarray(result).shape
            if result_shape != expected_shape:
                raise ValueError(
                    f"byof dynamics '{state_name}' returned shape {result_shape}, "
                    f"expected {expected_shape} (state '{state_name}' shape)"
                )

            # Test that gradient works (JAX compatibility check)
            try:
                jax.grad(lambda x: jnp.sum(fn(x, dummy_u, dummy_node, dummy_params)))(dummy_x)
            except Exception as e:
                raise ValueError(
                    f"byof dynamics '{state_name}' is not differentiable with JAX. "
                    f"Ensure the function uses JAX operations (jax.numpy, not numpy): {e}"
                ) from e

    # Validate dynamics_discrete functions (same signature and shape as dynamics)
    byof_dynamics_discrete = byof.dynamics_discrete
    if byof_dynamics_discrete:
        state_shapes = {state.name: state.shape for state in states}
        for state_name, fn in byof_dynamics_discrete.items():
            if state_name not in state_shapes:
                raise ValueError(
                    f"byof dynamics_discrete '{state_name}' does not match any state name. "
                    f"Available states: {list(state_shapes.keys())}"
                )
            if not callable(fn):
                raise TypeError(
                    f"byof dynamics_discrete '{state_name}' must be callable, got {type(fn)}"
                )
            sig = inspect.signature(fn)
            if len(sig.parameters) != 4:
                raise ValueError(
                    f"byof dynamics_discrete '{state_name}' must have signature "
                    f"f(x, u, node, params), got {len(sig.parameters)} parameters: "
                    f"{list(sig.parameters.keys())}"
                )
            try:
                result = fn(dummy_x, dummy_u, dummy_node, dummy_params)
            except Exception as e:
                raise ValueError(
                    f"byof dynamics_discrete '{state_name}' failed on test call with "
                    f"x.shape={dummy_x.shape}, u.shape={dummy_u.shape}: {e}"
                ) from e
            expected_shape = state_shapes[state_name]
            result_shape = jnp.asarray(result).shape
            if result_shape != expected_shape:
                raise ValueError(
                    f"byof dynamics_discrete '{state_name}' returned shape {result_shape}, "
                    f"expected {expected_shape} (state '{state_name}' shape)"
                )
            try:
                jax.grad(lambda x: jnp.sum(fn(x, dummy_u, dummy_node, dummy_params)))(dummy_x)
            except Exception as e:
                raise ValueError(
                    f"byof dynamics_discrete '{state_name}' is not differentiable with JAX. "
                    f"Ensure the function uses JAX operations (jax.numpy, not numpy): {e}"
                ) from e

    # Validate nodal constraints
    for i, constraint_spec in enumerate(byof.nodal_constraints):
        fn = constraint_spec.constraint_fn
        if not callable(fn):
            raise TypeError(
                f"byof nodal_constraints[{i}]['constraint_fn'] must be callable, got {type(fn)}"
            )

        # Check signature
        sig = inspect.signature(fn)
        if len(sig.parameters) != 4:
            raise ValueError(
                f"byof nodal_constraints[{i}]['constraint_fn'] must have signature "
                f"f(x, u, node, params), "
                f"got {len(sig.parameters)} parameters: {list(sig.parameters.keys())}"
            )

        # Test call
        try:
            result = fn(dummy_x, dummy_u, dummy_node, dummy_params)
        except Exception as e:
            raise ValueError(
                f"byof nodal_constraints[{i}]['constraint_fn'] failed on test call with "
                f"x.shape={dummy_x.shape}, u.shape={dummy_u.shape}: {e}"
            ) from e

        # Check that result is array-like (can be scalar or vector)
        try:
            result_array = jnp.asarray(result)
        except Exception as e:
            raise ValueError(
                f"byof nodal_constraints[{i}]['constraint_fn'] must return array-like value, "
                f"got {type(result)}: {e}"
            ) from e

        # Test gradient
        try:
            jax.grad(lambda x: jnp.sum(fn(x, dummy_u, dummy_node, dummy_params)))(dummy_x)
        except Exception as e:
            raise ValueError(
                f"byof nodal_constraints[{i}]['constraint_fn'] is not differentiable with JAX: {e}"
            ) from e

        # Validate nodes if provided
        if constraint_spec.nodes is not None:
            nodes = constraint_spec.nodes
            if len(nodes) == 0:
                raise ValueError(f"byof nodal_constraints[{i}]['nodes'] cannot be empty")

            # Validate node indices if N is provided
            if N is not None:
                for node in nodes:
                    # Handle negative indices (e.g., -1 for last node)
                    normalized_node = node if node >= 0 else N + node
                    # Validate range
                    if not (0 <= normalized_node < N):
                        raise ValueError(
                            f"byof nodal_constraints[{i}]['nodes'] contains invalid index {node} "
                            f"(normalized: {normalized_node}). Valid range is [0, {N}) or "
                            f"negative indices [-{N}, -1]."
                        )

    # Validate cross-nodal constraints
    dummy_X = jnp.zeros((10, n_x))  # Dummy trajectory with 10 nodes
    dummy_U = jnp.zeros((10, n_u))

    for i, fn in enumerate(byof.cross_nodal_constraints):
        if not callable(fn):
            raise TypeError(f"byof cross_nodal_constraints[{i}] must be callable, got {type(fn)}")

        # Check signature
        sig = inspect.signature(fn)
        if len(sig.parameters) != 3:
            raise ValueError(
                f"byof cross_nodal_constraints[{i}] must have signature f(X, U, params), "
                f"got {len(sig.parameters)} parameters: {list(sig.parameters.keys())}"
            )

        # Test call
        try:
            result = fn(dummy_X, dummy_U, dummy_params)
        except Exception as e:
            raise ValueError(
                f"byof cross_nodal_constraints[{i}] failed on test call with "
                f"X.shape={dummy_X.shape}, U.shape={dummy_U.shape}: {e}"
            ) from e

        # Check that result is array-like
        try:
            result_array = jnp.asarray(result)
        except Exception as e:
            raise ValueError(
                f"byof cross_nodal_constraints[{i}] must return array-like value, "
                f"got {type(result)}: {e}"
            ) from e

        # Test gradient
        try:
            jax.grad(lambda X: jnp.sum(fn(X, dummy_U, dummy_params)))(dummy_X)
        except Exception as e:
            raise ValueError(
                f"byof cross_nodal_constraints[{i}] is not differentiable with JAX: {e}"
            ) from e

    # Validate CTCS constraints
    for i, ctcs_spec in enumerate(byof.ctcs_constraints):
        fn = ctcs_spec.constraint_fn
        if not callable(fn):
            raise TypeError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] must be callable, got {type(fn)}"
            )

        # Check signature
        sig = inspect.signature(fn)
        if len(sig.parameters) != 4:
            raise ValueError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] must have signature "
                f"f(x, u, node, params), got {len(sig.parameters)} parameters: "
                f"{list(sig.parameters.keys())}"
            )

        # Test call
        try:
            result = fn(dummy_x, dummy_u, dummy_node, dummy_params)
        except Exception as e:
            raise ValueError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] failed on test call: {e}"
            ) from e

        # Check that result is scalar
        result_array = jnp.asarray(result)
        if result_array.shape != ():
            raise ValueError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] must return a scalar, "
                f"got shape {result_array.shape}"
            )

        # Test gradient
        try:
            jax.grad(lambda x: fn(x, dummy_u, dummy_node, dummy_params))(dummy_x)
        except Exception as e:
            raise ValueError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] is not differentiable with JAX: {e}"
            ) from e

        # Validate penalty function if callable
        penalty_spec = ctcs_spec.penalty
        if callable(penalty_spec):
            try:
                test_residual = jnp.array(0.5)
                penalty_result = penalty_spec(test_residual)
                jnp.asarray(penalty_result)
            except Exception as e:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['penalty'] custom function failed: {e}"
                ) from e

        # Validate idx
        if ctcs_spec.idx < 0:
            raise ValueError(
                f"byof ctcs_constraints[{i}]['idx'] must be non-negative, got {ctcs_spec.idx}"
            )

        # Validate bounds
        bounds = ctcs_spec.bounds
        if bounds[0] > bounds[1]:
            raise ValueError(
                f"byof ctcs_constraints[{i}]['bounds'] min ({bounds[0]}) must be <= "
                f"max ({bounds[1]})"
            )

        # Validate initial value is within bounds
        if ctcs_spec.initial is not None:
            initial = ctcs_spec.initial
            if not (bounds[0] <= initial <= bounds[1]):
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['initial'] ({initial}) must be within "
                    f"bounds [{bounds[0]}, {bounds[1]}]"
                )

        # Validate over (node interval) if provided
        if ctcs_spec.over is not None:
            start, end = ctcs_spec.over
            if start >= end:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['over'] start ({start}) must be < end ({end})"
                )
            if start < 0:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['over'] start ({start}) must be non-negative"
                )
            # Validate against trajectory length if N is provided
            if N is not None:
                if end > N:
                    raise ValueError(
                        f"byof ctcs_constraints[{i}]['over'] end ({end}) exceeds "
                        f"trajectory length ({N})"
                    )