
jax

JAX backend for lowering symbolic expressions to executable functions.

This package implements the JAX lowering backend that converts symbolic expression AST nodes into JAX functions with automatic differentiation support. The lowering uses a visitor pattern where each expression type has a corresponding visitor function registered via @visitor.

The visitor functions are split across submodules that mirror the openscvx.symbolic.expr package structure. Importing this package triggers registration of all visitors.

Example::

    from openscvx.symbolic.lowerers.jax import JaxLowerer

    lowerer = JaxLowerer()
    f = lowerer.lower(expr)
    result = f(x_val, u_val, node=0, params={})

JaxLowerer

JAX backend for lowering symbolic expressions to executable functions.

This class implements the visitor pattern for converting symbolic expression AST nodes to JAX functions. Each expression type has a corresponding visitor function decorated with @visitor that handles the lowering logic.

The lowering process is recursive: each visitor lowers its child expressions first, then composes them into a JAX operation. All lowered functions have a standardized signature (x, u, node, params) -> result.
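The recursive composition can be made concrete with a self-contained sketch in which every lowered node becomes a closure with the same `(x, u, node, params)` signature. The node classes (`State`, `Control`, `Parameter`, `Add`, `Mul`) and `ToyLowerer` are hypothetical stand-ins, dispatch is a plain `isinstance` chain rather than a `@visitor` registry, and looking parameters up by name in `params` is an assumption about how that argument is used. Because the lowerer keeps no instance state, one instance can lower several expressions in turn.

```python
from dataclasses import dataclass
from typing import Any, Callable

import jax.numpy as jnp


# Hypothetical expression nodes; real openscvx classes carry more metadata.
@dataclass
class State:
    indices: slice  # entries of x occupied by this state


@dataclass
class Control:
    indices: slice  # entries of u occupied by this control


@dataclass
class Parameter:
    name: str  # assumed to be looked up in the params dict at call time


@dataclass
class Add:
    left: Any
    right: Any


@dataclass
class Mul:
    left: Any
    right: Any


class ToyLowerer:
    """Stateless: each lowered function closes over the expression, not the lowerer."""

    def lower(self, expr) -> Callable:
        if isinstance(expr, State):
            return lambda x, u, node, params: x[expr.indices]
        if isinstance(expr, Control):
            return lambda x, u, node, params: u[expr.indices]
        if isinstance(expr, Parameter):
            return lambda x, u, node, params: params[expr.name]
        if isinstance(expr, (Add, Mul)):
            # Lower the children first, then compose them into one JAX operation.
            f_l, f_r = self.lower(expr.left), self.lower(expr.right)
            op = jnp.add if isinstance(expr, Add) else jnp.multiply
            return lambda x, u, node, params: op(f_l(x, u, node, params), f_r(x, u, node, params))
        raise NotImplementedError(type(expr).__name__)


lowerer = ToyLowerer()
x = jnp.array([1.0, 2.0, 3.0])
u = jnp.array([0.5])

# The same lowerer instance lowers two unrelated expressions.
f1 = lowerer.lower(Add(State(slice(0, 2)), Control(slice(0, 1))))
f2 = lowerer.lower(Mul(Parameter("k"), State(slice(2, 3))))
print(f1(x, u, node=0, params={}))          # [1.5 2.5]
print(f2(x, u, node=0, params={"k": 2.0}))  # [6.]
```

Keeping the signature uniform means a parent never needs to know what kind of child it has; it simply calls each child's function with the same four arguments.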

Note

This is a stateless lowerer: all state lives in the expression tree.

Example

Set up the JaxLowerer and lower an expression to a JAX function::

    lowerer = JaxLowerer()
    expr = ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2
    f = lowerer.lower(expr)
    result = f(x_val, u_val, node=0, params={})

Note

The lowerer is stateless and can be reused for multiple expressions.
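As a cross-check for the example above, the lowered function should compute the same value as the hand-written cost below. This is a reference sketch, not output from openscvx, and it assumes that `ox.Norm` is the Euclidean 2-norm and that `x` and `u` cover the full state and control vectors. Because the function is built from `jax.numpy` primitives, `jax.grad` can differentiate it, which illustrates the automatic differentiation support mentioned in the package description.

```python
import jax
import jax.numpy as jnp


def cost_reference(x, u, node, params):
    # Hand-written equivalent of ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2,
    # assuming Norm is the Euclidean norm; node and params are unused here.
    return jnp.linalg.norm(x) ** 2 + 0.1 * jnp.linalg.norm(u) ** 2


x_val = jnp.array([3.0, 4.0])
u_val = jnp.array([1.0, 0.0])

# Mathematically 5**2 + 0.1 * 1**2 = 25.1 (float32, so approximately).
value = cost_reference(x_val, u_val, node=0, params={})

# Gradients with respect to x and u; analytically 2*x and 0.2*u.
grad_x, grad_u = jax.grad(cost_reference, argnums=(0, 1))(x_val, u_val, 0, {})
print(value, grad_x, grad_u)
```

Comparing a lowered function against a reference like this at a few random points is a cheap way to catch a mis-registered visitor.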

Source code in openscvx/symbolic/lowerers/jax/_lowerer.py
class JaxLowerer:
    """JAX backend for lowering symbolic expressions to executable functions.

    This class implements the visitor pattern for converting symbolic expression
    AST nodes to JAX functions. Each expression type has a corresponding visitor
    function decorated with @visitor that handles the lowering logic.

    The lowering process is recursive: each visitor lowers its child expressions
    first, then composes them into a JAX operation. All lowered functions have
    a standardized signature (x, u, node, params) -> result.

    Note:
        This is a stateless lowerer - all state is in the expression tree.

    Example:
        Set up the JaxLowerer and lower an expression to a JAX function::

            lowerer = JaxLowerer()
            expr = ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2
            f = lowerer.lower(expr)
            result = f(x_val, u_val, node=0, params={})

    Note:
        The lowerer is stateless and can be reused for multiple expressions.
    """

    def lower(self, expr: Expr) -> Callable:
        """Lower a symbolic expression to a JAX function.

        Main entry point for lowering. Delegates to dispatch() which looks up
        the appropriate visitor method based on the expression type.

        Args:
            expr: Symbolic expression to lower (any Expr subclass)

        Returns:
            JAX function with signature (x, u, node, params) -> result

        Raises:
            NotImplementedError: If no visitor exists for the expression type
            ValueError: If the expression is malformed (e.g., State without slice)
        """
        return dispatch(self, expr)

lower(expr: Expr) -> Callable

Lower a symbolic expression to a JAX function.

Main entry point for lowering. Delegates to dispatch() which looks up the appropriate visitor method based on the expression type.

Parameters:

Name  Type  Description                                       Default
expr  Expr  Symbolic expression to lower (any Expr subclass)  required

Returns:

Type      Description
Callable  JAX function with signature (x, u, node, params) -> result

Raises:

Type                 Description
NotImplementedError  If no visitor exists for the expression type
ValueError           If the expression is malformed (e.g., a State without a slice)
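The two documented failure modes can be illustrated with a small dispatcher. This sketch is hypothetical: the `Expr`/`State`/`Unsupported` classes, the `indices` attribute standing in for a state's slice, the `VISITORS` table, and the error messages are all stand-ins, and openscvx's own checks may be worded or placed differently. An unregistered expression type surfaces as `NotImplementedError`, while a `State` that was never assigned a slice is rejected with `ValueError` before any function is built.

```python
from typing import Callable, Optional


class Expr:
    """Toy base class for expression nodes."""


class State(Expr):
    def __init__(self, name: str, indices: Optional[slice] = None):
        self.name = name
        self.indices = indices  # hypothetical: assigned once states are stacked into x


class Unsupported(Expr):
    """An expression type with no registered visitor."""


def visit_state(lowerer, expr: State) -> Callable:
    # Malformed input: without a slice there is no way to index into x.
    if expr.indices is None:
        raise ValueError(f"State '{expr.name}' has no slice")
    return lambda x, u, node, params: x[expr.indices]


VISITORS = {State: visit_state}


def dispatch(lowerer, expr: Expr) -> Callable:
    fn = VISITORS.get(type(expr))
    if fn is None:
        raise NotImplementedError(f"No visitor for {type(expr).__name__}")
    return fn(lowerer, expr)


for bad in (Unsupported(), State("pos")):
    try:
        dispatch(None, bad)
    except (NotImplementedError, ValueError) as err:
        print(f"{type(err).__name__}: {err}")

f = dispatch(None, State("pos", slice(0, 2)))
print(f([1.0, 2.0, 3.0], None, 0, {}))  # [1.0, 2.0]
```

Raising at lowering time rather than at call time means malformed expressions fail once, up front, instead of inside a traced JAX function.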
