preprocessing
Validation and preprocessing utilities for symbolic expressions.
This module provides preprocessing and validation functions for symbolic expressions in trajectory optimization problems. These utilities ensure that expressions are well-formed and constraints are properly specified before compilation to solvers.
The preprocessing pipeline includes:
- Shape validation: Ensure all expressions have compatible shapes
- Variable name validation: Check for unique, non-reserved variable names
- Constraint validation: Verify constraints appear only at root level
- Dynamics validation: Check that dynamics match state dimensions
- Time parameter validation: Validate time configuration
- Slice assignment: Assign contiguous memory slices to variables
These functions are typically called automatically during problem construction, but can also be used manually for debugging or custom problem setups.
Example
Validating expressions before problem construction:
import openscvx as ox
x = ox.State("x", shape=(3,))
u = ox.Control("u", shape=(2,))
# Build dynamics and constraints
dynamics = {
    "x": u  # Will fail validation - dimension mismatch!
}
# Validate dimensions before creating problem
from openscvx.symbolic.preprocessing import validate_dynamics_dict_dimensions
try:
    validate_dynamics_dict_dimensions(dynamics, [x])
except ValueError as e:
    print(f"Validation error: {e}")
collect_and_assign_slices(states: List[State], controls: List[Control], *, start_index: int = 0) -> Tuple[list[State], list[Control]]
Assign contiguous memory slices to states and controls.
This function assigns slice objects to states and controls that determine their positions in the flat decision variable vector. Variables can have either:
- Auto-assigned slices: Automatically assigned contiguously based on order
- Manual slices: User-specified slices that must be contiguous and non-overlapping
If any variables have manual slices, they must:
- Start at index 0 (or start_index if specified)
- Be contiguous and non-overlapping
- Match the variable's flattened dimension
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `states` | `List[State]` | List of State objects in canonical order | required |
| `controls` | `List[Control]` | List of Control objects in canonical order | required |
| `start_index` | `int` | Starting index for slice assignment (default: 0) | `0` |
Returns:
| Type | Description |
|---|---|
| `Tuple[list[State], list[Control]]` | Tuple of (states, controls) with slice attributes assigned |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If manual slices are invalid (wrong size, overlapping, not starting at 0) |
Example
x = ox.State("x", shape=(3,))
u = ox.Control("u", shape=(2,))
states, controls = collect_and_assign_slices([x], [u])
print(x._slice)  # slice(0, 3)
print(u._slice)  # slice(0, 2)
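The contiguity rules described above can be illustrated with a small standalone sketch. It does not use openscvx: `Var` and `assign_slices` are hypothetical names, and the handling of a mix of manual and auto-assigned variables (auto-assigned ones continue after the manual ones) is an assumption, not something the documentation states.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Var:
    """Hypothetical stand-in for a State or Control (not the openscvx class)."""
    name: str
    size: int                       # flattened dimension
    slice_: Optional[slice] = None  # None means "auto-assign"


def assign_slices(variables: List[Var], start_index: int = 0) -> List[Var]:
    """Check manual slices, then auto-assign the rest contiguously."""
    cursor = start_index
    manual = [v for v in variables if v.slice_ is not None]
    for v in sorted(manual, key=lambda v: v.slice_.start):
        s = v.slice_
        if s.start != cursor:
            # A gap, an overlap, or a wrong starting index all land here.
            raise ValueError(f"{v.name}: expected start {cursor}, got {s.start}")
        if s.stop - s.start != v.size:
            raise ValueError(f"{v.name}: slice length does not match size {v.size}")
        cursor = s.stop
    for v in variables:
        if v.slice_ is None:
            # Assumption: auto-assigned variables follow the manual ones.
            v.slice_ = slice(cursor, cursor + v.size)
            cursor += v.size
    return variables
```

The sketch handles a single group of variables. The example above shows `u._slice` as `slice(0, 2)`, which suggests that states and controls are each numbered from `start_index` independently, so the real function would presumably apply this kind of logic once per group.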
Source code in openscvx/symbolic/preprocessing.py
convert_dynamics_dict_to_expr(dynamics: Dict[str, Expr], states: List[State]) -> Tuple[Dict[str, Expr], Expr]
Convert dynamics dictionary to concatenated expression in canonical order.
Converts a dictionary-based dynamics specification to a single concatenated expression that represents the full ODE system x_dot = f(x, u, t). The dynamics are ordered according to the states list to ensure consistent variable ordering.
This function also normalizes scalar values (int, float) to Constant expressions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dynamics` | `Dict[str, Expr]` | Dictionary mapping state names to their dynamics expressions | required |
| `states` | `List[State]` | List of State objects defining the canonical order | required |
Returns:
| Type | Description |
|---|---|
| `Tuple[Dict[str, Expr], Expr]` | Tuple of (updated dynamics dictionary with scalars converted to Constant expressions, concatenated dynamics expression ordered by the states list) |
Example
Convert dynamics dict to a single expression:
x = ox.State("x", shape=(3,))
y = ox.State("y", shape=(2,))
dynamics_dict = {"x": x * 2, "y": 1.0} # Scalar for y
converted_dict, concat_expr = convert_dynamics_dict_to_expr(
    dynamics_dict, [x, y]
)
# converted_dict["y"] is now Constant(1.0)
# concat_expr is Concat(x * 2, Constant(1.0))
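The two steps described for this function (scalar normalization, then concatenation in the order of the states list) can be sketched without openscvx using plain lists in place of expressions. `order_and_concat` is a hypothetical helper, and one-element lists stand in for `Constant` here; this is an illustration of the ordering idea only.

```python
from typing import Dict, List, Sequence, Tuple, Union

Value = Union[int, float, Sequence[float]]


def order_and_concat(
    dynamics: Dict[str, Value], state_names: List[str]
) -> Tuple[Dict[str, List[float]], List[float]]:
    """Normalize scalars, then concatenate in the order of state_names."""
    normalized: Dict[str, List[float]] = {}
    for name, value in dynamics.items():
        if isinstance(value, (int, float)):
            # A bare scalar becomes a one-element vector.
            normalized[name] = [float(value)]
        else:
            normalized[name] = [float(v) for v in value]
    # The order comes from state_names, not from dict insertion order.
    flat = [c for name in state_names for c in normalized[name]]
    return normalized, flat
```

Because the output order is driven by the list of state names, two dictionaries with the same entries in different insertion orders give the same concatenated result.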
Source code in openscvx/symbolic/preprocessing.py
fill_default_guesses(states: List[State], N: int) -> None
Fill in default linspace guesses for states with guess=None.
For states with both initial and final conditions set, generates a linear interpolation from initial to final values.
This function modifies states in-place.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `states` | `List[State]` | List of State objects to fill guesses for | required |
| `N` | `int` | Number of discretization nodes | required |
Source code in openscvx/symbolic/preprocessing.py
validate_and_normalize_constraint_nodes(exprs: Union[Expr, list[Expr]], n_nodes: int)
Validate and normalize constraint node specifications.
This function validates and normalizes node specifications for constraint wrappers.
For NodalConstraint:
- nodes should be a list of specific node indices: [2, 4, 6, 8]
- Validates all nodes are within the valid range [0, n_nodes)
For CTCS (Continuous-Time Constraint Satisfaction) constraints:
- nodes should be a tuple of (start, end): (0, 10)
- None is replaced with (0, n_nodes) to apply over entire trajectory
- Validation ensures tuple has exactly 2 elements and start < end
- Validates indices are within trajectory bounds
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `exprs` | `Union[Expr, list[Expr]]` | Single expression or list of expressions to validate | required |
| `n_nodes` | `int` | Total number of nodes in the trajectory | required |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If node specifications are invalid (out of range, malformed, etc.) |
Example
x = ox.State("x", shape=(3,))
constraint = (x <= 5).at([0, 10, 20])  # NodalConstraint
validate_and_normalize_constraint_nodes([constraint], n_nodes=50)  # OK
ctcs_constraint = (x <= 5).over((0, 100))  # CTCS
validate_and_normalize_constraint_nodes([ctcs_constraint], n_nodes=50)
# Raises ValueError: Range exceeds trajectory length
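The two node rules can be sketched as a pair of standalone checks. `check_nodal` and `normalize_range` are hypothetical helpers, not openscvx functions, and treating `end` as an exclusive bound that may equal `n_nodes` is an assumption inferred from the `None` to `(0, n_nodes)` replacement.

```python
from typing import List, Optional, Tuple


def check_nodal(nodes: List[int], n_nodes: int) -> List[int]:
    """Every listed node index must fall in [0, n_nodes)."""
    for k in nodes:
        if not 0 <= k < n_nodes:
            raise ValueError(f"node {k} is outside [0, {n_nodes})")
    return nodes


def normalize_range(
    nodes: Optional[Tuple[int, int]], n_nodes: int
) -> Tuple[int, int]:
    """Replace None with the full range, then validate (start, end)."""
    if nodes is None:
        return (0, n_nodes)
    if len(nodes) != 2:
        raise ValueError("range must be a (start, end) tuple")
    start, end = nodes
    if not start < end:
        raise ValueError("range must satisfy start < end")
    if start < 0 or end > n_nodes:  # assumption: end is exclusive
        raise ValueError("range exceeds trajectory length")
    return (start, end)
```

With `n_nodes=50`, the list `[0, 10, 20]` passes, `None` becomes `(0, 50)`, and `(0, 100)` is rejected, which matches the behaviour shown in the example above.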
Source code in openscvx/symbolic/preprocessing.py
validate_boundary_conditions(states: List[State]) -> None
Validate that all states have initial and final boundary conditions set.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `states` | `List[State]` | List of State objects to validate | required |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If any state is missing initial or final conditions |
Source code in openscvx/symbolic/preprocessing.py
validate_bounds(variables: List[Variable]) -> None
Validate that all variables have min and max bounds set.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `variables` | `List[Variable]` | List of Variable objects (State or Control) to validate | required |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If any variable is missing min or max bounds |
Source code in openscvx/symbolic/preprocessing.py
validate_constraints_at_root(exprs: Union[Expr, list[Expr]])
Validate that constraints only appear at the root level of expression trees.
Constraints and constraint wrappers (CTCS, NodalConstraint, CrossNodeConstraint) must only appear as top-level expressions, not nested within other expressions. However, constraints inside constraint wrappers are allowed (e.g., the constraint inside CTCS(x <= 5)).
This ensures constraints are properly processed during problem compilation and prevents ambiguous constraint specifications.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `exprs` | `Union[Expr, list[Expr]]` | Single expression or list of expressions to validate | required |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If any constraint or constraint wrapper is found at depth > 0 |
Example
x = ox.State("x", shape=(3,))
constraint = x <= 5
validate_constraints_at_root([constraint])  # OK - constraint at root
bad_expr = ox.Sum(x <= 5)  # Constraint nested inside Sum
validate_constraints_at_root([bad_expr])  # Raises ValueError
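The root-level rule amounts to a depth-tracking tree walk in which a wrapper does not increase the depth of the constraint it wraps. The sketch below uses a hypothetical `Node` class with a string `kind` tag instead of the real expression classes, so the names and the tagging scheme are assumptions made for illustration.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """Hypothetical expression node; `kind` replaces real class checks."""
    kind: str
    children: List["Node"] = field(default_factory=list)


WRAPPER_KINDS = {"ctcs", "nodal", "cross_node"}
CONSTRAINT_KINDS = {"constraint"} | WRAPPER_KINDS


def check_root_only(node: Node, depth: int = 0) -> None:
    """Raise if a constraint or wrapper appears below the root level."""
    if node.kind in CONSTRAINT_KINDS and depth > 0:
        raise ValueError(f"{node.kind} found at depth {depth}")
    for child in node.children:
        # The constraint directly inside a wrapper still counts as root level.
        child_depth = depth if node.kind in WRAPPER_KINDS else depth + 1
        check_root_only(child, child_depth)
```

A `ctcs` node wrapping a `constraint` passes, while a `sum` node with a `constraint` child fails, mirroring the `ox.Sum(x <= 5)` case above.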
Source code in openscvx/symbolic/preprocessing.py
validate_cross_node_constraint(cross_node_constraint: CrossNodeConstraint, n_nodes: int) -> None
Validate cross-node constraint bounds and variable consistency.
This function performs two validations in a single tree traversal:
- Bounds checking: Ensures all NodeReference indices are within [0, n_nodes). Cross-node constraints reference fixed trajectory nodes (e.g., position.at(5)), and this validates those indices are valid. Negative indices are normalized (e.g., -1 becomes n_nodes-1) before checking.
- Variable consistency: Ensures that if ANY variable uses .at(), then ALL state/control variables must use .at(). Mixing causes shape mismatches during lowering because:
  - Variables with .at(k) extract single-node values: X[k, :] → shape (n_x,)
  - Variables without .at() expect full trajectory: X[:, :] → shape (N, n_x)
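Both checks can be sketched on a flattened view of the constraint, where each variable occurrence is recorded as a `(name, index)` pair and `None` marks a variable used without `.at()`. This representation and the helper name `check_cross_node` are assumptions made for the sketch; the real function walks an expression tree.

```python
from typing import List, Optional, Tuple


def check_cross_node(
    refs: List[Tuple[str, Optional[int]]], n_nodes: int
) -> List[Tuple[str, int]]:
    """Validate (name, node index) pairs; index None means no .at()."""
    with_at = [r for r in refs if r[1] is not None]
    without_at = [r for r in refs if r[1] is None]
    if with_at and without_at:
        raise ValueError("constraint mixes .at() and non-.at() variables")
    normalized = []
    for name, k in with_at:
        # Negative indices count from the end, e.g. -1 -> n_nodes - 1.
        idx = k + n_nodes if k < 0 else k
        if not 0 <= idx < n_nodes:
            raise ValueError(f"{name}.at({k}) is outside [0, {n_nodes})")
        normalized.append((name, idx))
    return normalized
```

With `n_nodes=10`, the pair `("pos", -1)` is normalized to node 9, whereas `("pos", 10)` is rejected because it falls outside `[0, 10)`.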
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `cross_node_constraint` | `CrossNodeConstraint` | The CrossNodeConstraint to validate | required |
| `n_nodes` | `int` | Total number of trajectory nodes | required |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If any NodeReference accesses nodes outside [0, n_nodes) |
| `ValueError` | If constraint mixes .at() and non-.at() variables |
Example
Valid cross-node constraint:
from openscvx.symbolic.expr import CrossNodeConstraint
position = ox.State("pos", shape=(3,))
# Valid: all variables use .at(), indices in bounds
constraint = CrossNodeConstraint(position.at(5) - position.at(4) <= 0.1)
validate_cross_node_constraint(constraint, n_nodes=10) # OK
Invalid - out of bounds:
# Invalid: node 10 is out of bounds for n_nodes=10
bad_bounds = CrossNodeConstraint(position.at(0) == position.at(10))
validate_cross_node_constraint(bad_bounds, n_nodes=10) # Raises ValueError
Invalid - mixed .at() usage:
velocity = ox.State("vel", shape=(3,))
# Invalid: position uses .at(), velocity doesn't
bad_mixed = CrossNodeConstraint(position.at(5) - velocity <= 0.1)
validate_cross_node_constraint(bad_mixed, n_nodes=10) # Raises ValueError
Source code in openscvx/symbolic/preprocessing.py
validate_dynamics_dict(dynamics: Dict[str, Expr], states: List[State], byof_dynamics: Optional[Dict[str, callable]] = None) -> None
Validate that dynamics dictionary keys match state names exactly.
Ensures that the dynamics dictionary (combined with optional byof dynamics) has exactly the same keys as the state names, with no missing states, no extra keys, and no overlap between symbolic and byof dynamics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dynamics` | `Dict[str, Expr]` | Dictionary mapping state names to their dynamics expressions | required |
| `states` | `List[State]` | List of State objects | required |
| `byof_dynamics` | `Optional[Dict[str, callable]]` | Optional dictionary mapping state names to raw JAX functions. States in byof_dynamics should NOT appear in dynamics dict. | `None` |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If there's a mismatch between state names and dynamics keys, or if a state appears in both dynamics and byof_dynamics. |
Example
x = ox.State("x", shape=(3,))
y = ox.State("y", shape=(2,))
dynamics = {"x": x * 2, "y": y + 1}
validate_dynamics_dict(dynamics, [x, y])  # OK
bad_dynamics = {"x": x * 2}  # Missing "y"
validate_dynamics_dict(bad_dynamics, [x, y])  # Raises ValueError
# With byof_dynamics (expert user mode)
dynamics = {"x": x * 2}  # Only symbolic for x
byof_dynamics = {"y": some_jax_fn}  # Raw JAX for y
validate_dynamics_dict(dynamics, [x, y], byof_dynamics)  # OK
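The key-matching rule reduces to three set comparisons: no overlap between symbolic and byof keys, no missing states, and no extra keys. The helper below is a hypothetical sketch that works on names only; `check_dynamics_keys` is not an openscvx function.

```python
from typing import Iterable


def check_dynamics_keys(
    dynamics_keys: Iterable[str],
    state_names: Iterable[str],
    byof_keys: Iterable[str] = (),
) -> None:
    """Raise unless symbolic and byof keys together cover the states exactly."""
    symbolic, byof, expected = set(dynamics_keys), set(byof_keys), set(state_names)
    overlap = symbolic & byof
    if overlap:
        raise ValueError(f"states in both dynamics and byof_dynamics: {sorted(overlap)}")
    provided = symbolic | byof
    missing, extra = expected - provided, provided - expected
    if missing or extra:
        raise ValueError(f"missing: {sorted(missing)}, extra: {sorted(extra)}")
```

Reporting missing and extra keys in a single message tends to help when a state has simply been misspelled, since the typo shows up as one of each.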
Source code in openscvx/symbolic/preprocessing.py
validate_dynamics_dict_dimensions(dynamics: Dict[str, Expr], states: List[State]) -> None
Validate that each dynamics expression matches its corresponding state shape.
For dictionary-based dynamics specification, ensures that each state's dynamics expression has the same shape as the state itself. This validates that each component of x_dot = f(x, u, t) has the correct dimension.
Scalars are normalized to shape (1,) for comparison, matching Concat behavior.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dynamics` | `Dict[str, Expr]` | Dictionary mapping state names to their dynamics expressions | required |
| `states` | `List[State]` | List of State objects | required |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If any dynamics expression dimension doesn't match its state shape |
Example
x = ox.State("x", shape=(3,))
y = ox.State("y", shape=(2,))
u = ox.Control("u", shape=(3,))
dynamics = {"x": u, "y": y + 1}
validate_dynamics_dict_dimensions(dynamics, [x, y])  # OK
bad_dynamics = {"x": u, "y": u}  # y dynamics has wrong shape
validate_dynamics_dict_dimensions(bad_dynamics, [x, y])  # Raises ValueError
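The per-state comparison, including the scalar normalization to shape `(1,)`, can be sketched on shape tuples alone. `check_shapes` is a hypothetical helper, and representing a scalar by the empty tuple `()` is an assumption borrowed from array conventions.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def check_shapes(dynamics_shapes: Dict[str, Shape], state_shapes: Dict[str, Shape]) -> None:
    """Each state's dynamics shape must equal the state's own shape."""
    for name, state_shape in state_shapes.items():
        dyn_shape = dynamics_shapes[name]
        if dyn_shape == ():
            # Scalars are treated as shape (1,) for the comparison.
            dyn_shape = (1,)
        if dyn_shape != state_shape:
            raise ValueError(
                f"dynamics for '{name}' has shape {dyn_shape}, expected {state_shape}"
            )
```

Under this rule a scalar dynamics entry is accepted only for a state of shape `(1,)`.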
Source code in openscvx/symbolic/preprocessing.py
validate_dynamics_dimension(dynamics_expr: Union[Expr, list[Expr]], states: Union[State, list[State]]) -> None
Validate that dynamics expression dimensions match state dimensions.
Ensures that the total dimension of all dynamics expressions matches the total dimension of all states. Each dynamics expression must be a 1D vector, and their combined dimension must equal the sum of all state dimensions.
This is essential for ensuring the ODE system x_dot = f(x, u, t) is well-formed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dynamics_expr` | `Union[Expr, list[Expr]]` | Single dynamics expression or list of dynamics expressions. Combined, they represent x_dot = f(x, u, t) for all states. | required |
| `states` | `Union[State, list[State]]` | Single state variable or list of state variables that the dynamics describe. | required |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If dimensions don't match or if any dynamics is not a 1D vector |
Example
x = ox.State("x", shape=(3,))
y = ox.State("y", shape=(2,))
dynamics = ox.Concat(x * 2, y + 1)  # Shape (5,) - matches total state dim
validate_dynamics_dimension(dynamics, [x, y])  # OK
bad_dynamics = x  # Shape (3,) - doesn't match total dim of 5
validate_dynamics_dimension(bad_dynamics, [x, y])  # Raises ValueError
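The total-dimension rule compares two sums after checking that every dynamics expression is 1D. The sketch below works on shape tuples; `check_total_dimension` is a hypothetical name, and using the product of each state's shape as its flattened dimension is an assumption.

```python
import math
from typing import List, Tuple

Shape = Tuple[int, ...]


def check_total_dimension(dynamics_shapes: List[Shape], state_shapes: List[Shape]) -> None:
    """All dynamics must be 1D, and their lengths must sum to the state total."""
    for shape in dynamics_shapes:
        if len(shape) != 1:
            raise ValueError(f"dynamics with shape {shape} is not a 1D vector")
    total_dynamics = sum(shape[0] for shape in dynamics_shapes)
    total_states = sum(math.prod(shape) for shape in state_shapes)
    if total_dynamics != total_states:
        raise ValueError(
            f"dynamics dimension {total_dynamics} != state dimension {total_states}"
        )
```

For states of shape `(3,)` and `(2,)`, a single dynamics vector of shape `(5,)` passes and one of shape `(3,)` does not, as in the example above.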
Source code in openscvx/symbolic/preprocessing.py
validate_guesses(variables: List[Variable]) -> None
Validate that all variables have initial guesses set.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `variables` | `List[Variable]` | List of Variable objects (State or Control) to validate | required |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If any variable is missing a guess |
Source code in openscvx/symbolic/preprocessing.py
validate_input_types(dynamics: any, states: any, controls: any, constraints: any, N: any, time: any) -> None
Validate that all user-facing inputs have correct types.
This catches common user errors like passing a single State or Control instead of a list, or passing wrong types for dynamics, N, or time. Should be called before any other validation in the preprocessing pipeline.
Raises:
| Type | Description |
|---|---|
| `TypeError` | If any input has the wrong type |
| `ValueError` | If N is not positive |
Source code in openscvx/symbolic/preprocessing.py
validate_propagation_input_types(dynamics_prop_extra: any, states_prop_extra: any) -> None
Validate types for optional propagation inputs.
These parameters must either both be None or both be provided. When provided, dynamics_prop_extra must be a dict and states_prop_extra must be a list of State objects.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dynamics_prop_extra` | `any` | Should be None or a dict mapping state names to expressions | required |
| `states_prop_extra` | `any` | Should be None or a list of State objects | required |
Raises:
| Type | Description |
|---|---|
| `TypeError` | If either input has the wrong type |
| `ValueError` | If only one of the two is provided |
Example
distance = ox.State("distance", shape=(1,))
# Wrong: passing bare State instead of list
validate_propagation_input_types({"distance": expr}, distance)
# Raises TypeError: 'states_prop_extra' must be a list ...
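The "both None or both provided" rule followed by the type checks can be sketched as below. `check_propagation_inputs` is a hypothetical helper; it checks only the container types (dict and list), not the element types, and the order of the checks is an assumption.

```python
from typing import Any


def check_propagation_inputs(dynamics_prop_extra: Any, states_prop_extra: Any) -> None:
    """Both arguments must be None, or a dict and a list respectively."""
    if (dynamics_prop_extra is None) != (states_prop_extra is None):
        # Exactly one of the two was supplied.
        raise ValueError("provide both propagation inputs or neither")
    if dynamics_prop_extra is None:
        return
    if not isinstance(dynamics_prop_extra, dict):
        raise TypeError("'dynamics_prop_extra' must be a dict")
    if not isinstance(states_prop_extra, list):
        raise TypeError("'states_prop_extra' must be a list")
```

Passing a dict together with a bare object instead of a list reaches the last check and raises `TypeError`, which corresponds to the mistake shown in the example above.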
Source code in openscvx/symbolic/preprocessing.py
validate_shapes(exprs: Union[Expr, list[Expr]]) -> None
Validate shapes for a single expression or list of expressions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `exprs` | `Union[Expr, list[Expr]]` | Single expression or list of expressions to validate | required |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If any expression has invalid shapes |
Source code in openscvx/symbolic/preprocessing.py
validate_variable_names(exprs: Iterable[Expr], *, reserved_prefix: str = '_', reserved_names: Set[str] = None) -> None
Validate variable names for uniqueness and reserved name conflicts.
This function ensures that all State and Control variable names are:
1. Unique across distinct variable instances
2. Not starting with the reserved prefix (default: "_")
3. Not colliding with explicitly reserved names
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `exprs` | `Iterable[Expr]` | Iterable of expression trees to scan for variables | required |
| `reserved_prefix` | `str` | Prefix that user variables cannot start with (default: "_") | `'_'` |
| `reserved_names` | `Set[str]` | Set of explicitly reserved names that cannot be used (default: None) | `None` |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If any variable name violates uniqueness or reserved name rules |
Example
x1 = ox.State("x", shape=(3,))
x2 = ox.State("x", shape=(2,))  # Same name, different object
validate_variable_names([x1 + x2])  # Raises ValueError: Duplicate name 'x'
bad = ox.State("_internal", shape=(2,))
validate_variable_names([bad])  # Raises ValueError: Reserved prefix '_'
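The three naming rules can be sketched on a flat list of variables. `V` and `check_names` are hypothetical names, and the real function scans expression trees rather than a list. The detail worth noting is that uniqueness is checked across distinct instances, so the same object appearing twice is not a duplicate.

```python
from typing import Dict, Iterable, Optional, Set


class V:
    """Hypothetical stand-in for a named State or Control."""

    def __init__(self, name: str) -> None:
        self.name = name


def check_names(
    variables: Iterable[V],
    reserved_prefix: str = "_",
    reserved_names: Optional[Set[str]] = None,
) -> None:
    """Raise on reserved prefixes, reserved names, or duplicate names."""
    reserved = reserved_names or set()
    first_seen: Dict[str, V] = {}
    for v in variables:
        if v.name.startswith(reserved_prefix):
            raise ValueError(f"'{v.name}' starts with reserved prefix '{reserved_prefix}'")
        if v.name in reserved:
            raise ValueError(f"'{v.name}' is a reserved name")
        # Compare by identity: the same object appearing twice is fine.
        if first_seen.setdefault(v.name, v) is not v:
            raise ValueError(f"duplicate name '{v.name}'")
```

Comparing by identity instead of by name alone lets one variable appear in several expressions without tripping the duplicate check.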