
logic

JAX visitors for logic expressions.

Visitors: All, Any, Cond
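The three visitors map logical expressions onto JAX operations. The sketch below shows one plausible lowering. It assumes that All and Any are conjunction and disjunction over a list of boolean predicates, and that Cond selects between two branch functions with jax.lax.cond so that the result remains traceable under jax.jit. The helpers lower_all, lower_any, and lower_cond are hypothetical names chosen for illustration, not openscvx's actual visitor API.

```python
import jax
import jax.numpy as jnp

# Hypothetical lowering helpers; names and signatures are illustrative only,
# not the real openscvx visitor interface.

def lower_all(predicates):
    """All: true only if every predicate is true (logical AND)."""
    return jnp.all(jnp.stack([jnp.asarray(p, dtype=bool) for p in predicates]))

def lower_any(predicates):
    """Any: true if at least one predicate is true (logical OR)."""
    return jnp.any(jnp.stack([jnp.asarray(p, dtype=bool) for p in predicates]))

def lower_cond(pred, true_fn, false_fn, operand):
    """Cond: pick a branch with jax.lax.cond instead of Python `if`,
    so the choice can depend on traced values inside jax.jit."""
    return jax.lax.cond(pred, true_fn, false_fn, operand)

x = jnp.array(2.0)
both_hold = lower_all([x > 0, x < 5])       # True: 0 < 2 < 5
either_holds = lower_any([x < 0, x > 1])    # True: 2 > 1
y = lower_cond(both_hold, lambda v: v * 2.0, lambda v: -v, x)
```

Using jnp.all/jnp.any and jax.lax.cond rather than Python `and`/`or`/`if` matters because Python control flow cannot branch on traced array values inside a jitted function.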

get_default_float_dtype()

Get the current default float dtype for conditional branches.

Source code in openscvx/symbolic/lowerers/jax/logic.py
def get_default_float_dtype():
    """Get the current default float dtype for conditional branches."""
    return _DEFAULT_FLOAT_DTYPE
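A shared default dtype is useful because jax.lax.cond requires both branches to return outputs of identical type; a branch returning an integer constant next to a branch returning a float array fails to trace. The minimal sketch below casts both branch values to one float dtype before returning them. DEFAULT_FLOAT_DTYPE is a hard-coded stand-in for what get_default_float_dtype() would return, and safe_cond is a hypothetical helper, not part of openscvx.

```python
import jax
import jax.numpy as jnp

# Stand-in for get_default_float_dtype(); hard-coded to float32 here.
DEFAULT_FLOAT_DTYPE = jnp.float32

def safe_cond(pred, true_value, false_value):
    """Select between two values with jax.lax.cond, casting each branch
    output to the default float dtype so the branch types always match."""
    return jax.lax.cond(
        pred,
        lambda ops: jnp.asarray(ops[0], dtype=DEFAULT_FLOAT_DTYPE),
        lambda ops: jnp.asarray(ops[1], dtype=DEFAULT_FLOAT_DTYPE),
        (true_value, false_value),
    )

# Without the casts, the int literal 1 and the float array 0.5 would give
# branches with different output dtypes, which jax.lax.cond rejects.
out = safe_cond(jnp.array(True), 1, jnp.array(0.5))
```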

set_default_float_dtype(dtype: str) -> None

Set the default float dtype for conditional branches.

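This is called by Problem.__init__ to configure the dtype used in jax.lax.cond branches. The name is matched case-insensitively: "float64", "f64", and "double" select jnp.float64, and any other string selects jnp.float32.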

Parameters:

Name   Type  Description                         Default
dtype  str   String like "float32" or "float64"  required
Source code in openscvx/symbolic/lowerers/jax/logic.py
def set_default_float_dtype(dtype: str) -> None:
    """Set the default float dtype for conditional branches.

    This is called by Problem.__init__ to configure the dtype used in jax.lax.cond branches.

    Args:
        dtype: String like "float32" or "float64"
    """
    global _DEFAULT_FLOAT_DTYPE
    dtype_lower = dtype.lower()
    if dtype_lower in ("float64", "f64", "double"):
        _DEFAULT_FLOAT_DTYPE = jnp.float64
    else:
        _DEFAULT_FLOAT_DTYPE = jnp.float32
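The name-to-dtype mapping above can be exercised in isolation. The sketch below copies the branch logic into a pure function, resolve_float_dtype (a hypothetical name used only for illustration), so its behavior can be checked without touching the module-level global. One consequence worth noting is that unrecognized strings, including misspellings such as "flaot64", silently fall back to float32. Also, JAX only produces 64-bit arrays when x64 mode is enabled (for example via jax.config.update("jax_enable_x64", True)); otherwise a request for jnp.float64 is truncated to float32 with a warning. Whether openscvx enables x64 elsewhere is not shown in this section.

```python
import jax.numpy as jnp

def resolve_float_dtype(dtype: str):
    """Illustrative copy of the mapping in set_default_float_dtype:
    case-insensitive, with everything unrecognized falling back to float32."""
    if dtype.lower() in ("float64", "f64", "double"):
        return jnp.float64
    return jnp.float32

double_alias = resolve_float_dtype("DOUBLE")     # jnp.float64
short_alias = resolve_float_dtype("f64")         # jnp.float64
fallback = resolve_float_dtype("bfloat16")       # jnp.float32 (not special-cased)
```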