logic
JAX visitors for logic expressions.
Visitors: All, Any, Cond
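The page names the visitors but does not show what they compute. Below is a minimal sketch, assuming each visitor has already reduced its child expressions to JAX values. The functions `eval_all`, `eval_any` and `eval_cond` are hypothetical stand-ins, not this module's API: a conjunction, a disjunction and a branch selected with `jax.lax.cond`.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-ins for what the All, Any and Cond visitors might
# evaluate once their child expressions are already JAX values.

def eval_all(*conds):
    # Logical AND over every child condition.
    return jnp.all(jnp.stack([jnp.asarray(c, dtype=bool) for c in conds]))


def eval_any(*conds):
    # Logical OR over every child condition.
    return jnp.any(jnp.stack([jnp.asarray(c, dtype=bool) for c in conds]))


def eval_cond(pred, on_true, on_false):
    # jax.lax.cond traces both branches, so they must return values with
    # matching shapes and dtypes.
    return jax.lax.cond(pred, lambda t, f: t, lambda t, f: f, on_true, on_false)
```

Because `jax.lax.cond` requires both branches to agree on shape and dtype, a shared float dtype for branch outputs is useful. That is the setting `set_default_float_dtype` controls.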
get_default_float_dtype()
set_default_float_dtype(dtype: str) -> None
Set the default float dtype for conditional branches.
This is called by Problem.init to configure the dtype used in jax.lax.cond branches.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dtype | str | String like "float32" or "float64" | required |
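The getter and setter pair might be used as in the sketch below. It assumes the module keeps the dtype name in a module-level variable and that both branches cast their result to that dtype. The names `_DEFAULT_FLOAT_DTYPE` and `cond_with_default_dtype` are hypothetical, and this is not the module's actual code.

```python
import jax
import jax.numpy as jnp

# Hypothetical module-level state; the real module may store this differently.
_DEFAULT_FLOAT_DTYPE = "float32"


def set_default_float_dtype(dtype: str) -> None:
    """Store the dtype name used for jax.lax.cond branch outputs."""
    global _DEFAULT_FLOAT_DTYPE
    jnp.dtype(dtype)  # raises TypeError for an unknown dtype name
    _DEFAULT_FLOAT_DTYPE = dtype


def get_default_float_dtype():
    return jnp.dtype(_DEFAULT_FLOAT_DTYPE)


def cond_with_default_dtype(pred, on_true, on_false):
    # Hypothetical helper: cast both branch results to the same dtype so
    # jax.lax.cond does not reject mismatched outputs (e.g. int vs float).
    dtype = get_default_float_dtype()
    return jax.lax.cond(
        pred,
        lambda t, f: jnp.asarray(t, dtype=dtype),
        lambda t, f: jnp.asarray(f, dtype=dtype),
        on_true,
        on_false,
    )
```

Without the casts, returning an integer from one branch and a float from the other makes `jax.lax.cond` raise a type error. JAX only produces true `"float64"` arrays when 64-bit mode (`jax_enable_x64`) is enabled. Otherwise the cast silently falls back to `float32` with a warning.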