jax
JAX backend for lowering symbolic expressions to executable functions.
This package implements the JAX lowering backend that converts symbolic
expression AST nodes into JAX functions with automatic differentiation
support. The lowering uses a visitor pattern where each expression type
has a corresponding visitor function registered via @visitor.
The visitor functions are split across submodules that mirror the
openscvx.symbolic.expr package structure. Importing this package
triggers registration of all visitors.
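A decorator that writes into a module-level table when the decorated function is defined is one way to get this register-on-import behavior. Importing a submodule is then enough to register every visitor it contains. This is a minimal sketch, not openscvx's code. `VISITORS`, `Constant`, and `Parameter` are hypothetical stand-ins, and the real `@visitor` signature may differ.

```python
from typing import Any, Callable, Dict

# Hypothetical registry: expression class -> visitor function.
VISITORS: Dict[type, Callable] = {}


def visitor(expr_type: type) -> Callable[[Callable], Callable]:
    """Register the decorated function as the visitor for expr_type."""
    def register(fn: Callable) -> Callable:
        VISITORS[expr_type] = fn  # runs at definition (i.e. import) time
        return fn
    return register


# Hypothetical leaf expression types, not the real openscvx classes.
class Constant:
    def __init__(self, value: float):
        self.value = value


class Parameter:
    def __init__(self, name: str):
        self.name = name


# In the real package these would sit in submodules; merely defining
# (or importing) them is what fills VISITORS.
@visitor(Constant)
def visit_constant(lowerer: Any, expr: Constant) -> Callable:
    return lambda x, u, node, params: expr.value


@visitor(Parameter)
def visit_parameter(lowerer: Any, expr: Parameter) -> Callable:
    # Parameter values are looked up at call time from the params dict.
    return lambda x, u, node, params: params[expr.name]
```

For example, `VISITORS[Parameter](None, Parameter("g"))(None, None, 0, {"g": 9.81})` returns `9.81`.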
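Example:
```python
from openscvx.symbolic.lowerers.jax import JaxLowerer

lowerer = JaxLowerer()
# expr is any symbolic Expr; x_val and u_val are the numeric state and control vectors.
f = lowerer.lower(expr)
result = f(x_val, u_val, node=0, params={})
```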
JaxLowerer
JAX backend for lowering symbolic expressions to executable functions.
This class implements the visitor pattern for converting symbolic expression AST nodes to JAX functions. Each expression type has a corresponding visitor function decorated with @visitor that handles the lowering logic.
The lowering process is recursive: each visitor lowers its child expressions first, then composes them into a JAX operation. All lowered functions have a standardized signature (x, u, node, params) -> result.
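The recursive composition can be illustrated without JAX. Each node becomes a closure with the shared `(x, u, node, params)` signature, and parent nodes call their children's closures. This sketch uses plain Python floats so that it runs without JAX, and it uses an `isinstance` chain instead of the `@visitor` registry so that the recursion stays in focus. The node classes are hypothetical. Leaves read a single scalar entry, which simplifies the slices used by openscvx.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Sequence

# Every lowered function shares this signature: (x, u, node, params) -> result.
Lowered = Callable[[Sequence[float], Sequence[float], int, Dict[str, Any]], Any]


# Hypothetical minimal AST. Leaves read one scalar entry of x or u, a
# simplification of the slices that openscvx assigns to State/Control.
@dataclass
class State:
    index: int


@dataclass
class Control:
    index: int


@dataclass
class Parameter:
    name: str


@dataclass
class Add:
    left: Any
    right: Any


@dataclass
class Mul:
    left: Any
    right: Any


def lower(expr: Any) -> Lowered:
    """Recursively compile an expression tree into a single closure."""
    if isinstance(expr, State):
        i = expr.index
        return lambda x, u, node, params: x[i]
    if isinstance(expr, Control):
        i = expr.index
        return lambda x, u, node, params: u[i]
    if isinstance(expr, Parameter):
        name = expr.name
        return lambda x, u, node, params: params[name]
    if isinstance(expr, (Add, Mul)):
        # Lower the children first, then compose their closures.
        f, g = lower(expr.left), lower(expr.right)
        if isinstance(expr, Add):
            return lambda x, u, node, params: f(x, u, node, params) + g(x, u, node, params)
        return lambda x, u, node, params: f(x, u, node, params) * g(x, u, node, params)
    raise NotImplementedError(f"no lowering rule for {type(expr).__name__}")


# x0**2 + w * u0**2, with the weight w supplied through params.
expr = Add(Mul(State(0), State(0)), Mul(Parameter("w"), Mul(Control(0), Control(0))))
f = lower(expr)
```

Calling `f([3.0], [2.0], 0, {"w": 0.1})` evaluates `3**2 + 0.1 * 2**2`, which is approximately `9.4`. In the real lowerer, the leaves return JAX arrays, so the composed function stays traceable by JAX transformations.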
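Note
This is a stateless lowerer: all state lives in the expression tree, so one instance can be reused to lower any number of expressions.
Example
Set up the JaxLowerer and lower an expression to a JAX function:
```python
import openscvx as ox
from openscvx.symbolic.lowerers.jax import JaxLowerer

# x and u are symbolic state and control expressions; x_val and u_val are their numeric values.
lowerer = JaxLowerer()
expr = ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2
f = lowerer.lower(expr)
result = f(x_val, u_val, node=0, params={})
```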
Source code in openscvx/symbolic/lowerers/jax/_lowerer.py
lower(expr: Expr) -> Callable
Lower a symbolic expression to a JAX function.
Main entry point for lowering. Delegates to `dispatch()`, which looks up the registered visitor for the expression's type.
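The `lower()` to `dispatch()` split can be pictured as a type-keyed lookup that raises `NotImplementedError` when no visitor is registered. This matches the Raises table for this method. This is a hypothetical sketch: `SketchLowerer` and its `register` classmethod are illustrative names, not the JaxLowerer API. The sketch also does an exact-type lookup, and whether the real `dispatch()` also consults base classes is not stated here.

```python
from typing import Any, Callable, Dict


class SketchLowerer:
    """Hypothetical stand-in for JaxLowerer: lower() delegates to dispatch()."""

    # Visitor table keyed by expression class. It lives on the class, so an
    # instance carries no per-expression state and can be reused freely.
    _visitors: Dict[type, Callable] = {}

    @classmethod
    def register(cls, expr_type: type) -> Callable[[Callable], Callable]:
        def decorate(fn: Callable) -> Callable:
            cls._visitors[expr_type] = fn
            return fn
        return decorate

    def lower(self, expr: Any) -> Callable:
        return self.dispatch(expr)

    def dispatch(self, expr: Any) -> Callable:
        # Exact-type lookup; unregistered types are rejected.
        fn = self._visitors.get(type(expr))
        if fn is None:
            raise NotImplementedError(
                f"No visitor registered for {type(expr).__name__}"
            )
        return fn(self, expr)


class Constant:
    def __init__(self, value: float):
        self.value = value


class Negate:
    def __init__(self, operand: Any):
        self.operand = operand


@SketchLowerer.register(Constant)
def _lower_constant(lowerer: SketchLowerer, expr: Constant) -> Callable:
    return lambda x, u, node, params: expr.value


@SketchLowerer.register(Negate)
def _lower_negate(lowerer: SketchLowerer, expr: Negate) -> Callable:
    inner = lowerer.lower(expr.operand)  # recurse through the same lowerer
    return lambda x, u, node, params: -inner(x, u, node, params)
```

With `lowerer = SketchLowerer()`, the call `lowerer.lower(Negate(Constant(1.5)))(None, None, 0, {})` returns `-1.5`. The call `lowerer.lower(object())` raises `NotImplementedError`.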
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `expr` | `Expr` | Symbolic expression to lower (any `Expr` subclass) | required |
Returns:
| Type | Description |
|---|---|
| `Callable` | JAX function with signature `(x, u, node, params) -> result` |
Raises:
| Type | Description |
|---|---|
| `NotImplementedError` | If no visitor exists for the expression type |
| `ValueError` | If the expression is malformed (e.g., a `State` without a slice) |
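The `ValueError` case can be illustrated under one assumption, suggested by the table above: each `State` must carry a slice into the stacked state vector `x` before it can be lowered. In this sketch, `x_slice` and `assign_slices` are hypothetical names. This section does not describe how openscvx actually assigns slices.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass
class State:
    name: str
    shape: Tuple[int, ...]
    # Hypothetical field: where this state lives in the stacked vector x.
    # It stays None until the states are stacked.
    x_slice: Optional[slice] = None


def assign_slices(states: List[State]) -> None:
    """Give each state a contiguous slice of x, in order (hypothetical helper)."""
    start = 0
    for s in states:
        size = math.prod(s.shape)
        s.x_slice = slice(start, start + size)
        start += size


def lower_state(expr: State) -> Callable:
    if expr.x_slice is None:
        # Malformed input: without a slice there is no way to read from x.
        raise ValueError(f"State {expr.name!r} has no slice assigned")
    sl = expr.x_slice
    return lambda x, u, node, params: x[sl]
```

Calling `lower_state(State("pos", (2,)))` raises `ValueError`. After `assign_slices([pos, vel])` runs for two 2-vectors, `vel` reads entries `2:4` of `x`.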