
stl

JAX visitors and GMSR math for STL (Signal Temporal Logic) expressions.

Visitors: Or, And, IfThen, IntegerVariable

Lowers symbolic STL expression nodes to JAX functions using Generalized Mean-based Smooth Robustness (GMSR) parameterizations.

The GMSR helper functions (AND, OR, IfThen, etc.) are pure JAX math implementations used by the visitor functions below.

Author: Samet Uzun

References: [https://doi.org/10.48550/arxiv.2405.10996] [https://doi.org/10.2514/6.2025-1895]

AND(y: ArrayLike, c: float = 1e-08) -> Array

Smooth conjunction: AND(y) <= 0 iff y_i <= 0 for all i.

Source code in openscvx/symbolic/lowerers/jax/stl.py
def AND(y: ArrayLike, c: float = 1e-8) -> Array:
    """Smooth conjunction: AND(y) <= 0  iff  y_i <= 0 for all i."""
    y = jnp.asarray(y)

    positive_part = jnp.maximum(y, 0.0)
    negative_part = jnp.maximum(-y, 0.0)

    mp = jnp.mean(positive_part**2) + c
    m0 = _root_sum_of_product_terms(negative_part**2, c)

    return jnp.sqrt(mp) - jnp.sqrt(m0)
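The private helper `_root_sum_of_product_terms` is not shown on this page, so the sketch below swaps in a hypothetical stand-in, `root_sum_of_product_terms_sketch(x, c) = (c**n + prod(x))**(1/n)`. It was chosen only because it reproduces the documented sign property, and the library's helper may be defined differently. Under that assumption, `AND_sketch` (a renamed copy of `AND`) is negative when every entry is strictly negative, approximately zero on the boundary, and positive as soon as one entry is positive.

```python
import jax.numpy as jnp

def root_sum_of_product_terms_sketch(x, c):
    # HYPOTHETICAL stand-in for the private helper _root_sum_of_product_terms.
    # Assumed form: n-th root of (c**n + prod(x)). It equals c when any x_i == 0
    # and exceeds c when every x_i > 0.
    x = jnp.ravel(x)
    n = x.size
    return (c**n + jnp.prod(x)) ** (1.0 / n)

def AND_sketch(y, c=1e-8):
    # Same structure as AND above, but using the hypothetical helper.
    y = jnp.asarray(y, dtype=jnp.float32)
    positive_part = jnp.maximum(y, 0.0)
    negative_part = jnp.maximum(-y, 0.0)
    mp = jnp.mean(positive_part**2) + c
    m0 = root_sum_of_product_terms_sketch(negative_part**2, c)
    return jnp.sqrt(mp) - jnp.sqrt(m0)

all_satisfied = AND_sketch([-1.0, -2.0])  # every y_i < 0: negative
on_boundary = AND_sketch([0.0, -2.0])     # all y_i <= 0, one equal to 0: about 0
one_violated = AND_sketch([0.5, -2.0])    # one y_i > 0: positive
print(all_satisfied, on_boundary, one_violated)
```

The square root of the power mean of the positive parts competes against the square root of the product-based term of the negative parts. A single positive entry zeroes that product, which is what flips the sign.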

AND_lite(y: ArrayLike, c: float = 1e-08) -> Array

Lite conjunction (positive part only): AND_lite(y) = 0 iff y_i <= 0 for all i.

Source code in openscvx/symbolic/lowerers/jax/stl.py
def AND_lite(y: ArrayLike, c: float = 1e-8) -> Array:
    """Lite conjunction (positive part only): AND_lite(y) = 0  iff  y_i <= 0 for all i."""
    y = jnp.asarray(y)

    mp = jnp.mean(jnp.maximum(y, 0.0) ** 2) + c
    return jnp.sqrt(mp) - jnp.sqrt(c)
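`AND_lite` is fully shown above, so it can be run as is. The sketch copies it, adds the imports it needs, and uses `jax.grad` to illustrate how it behaves as a penalty. When every y_i <= 0, both the value and the gradient are zero. Otherwise the value is positive and only the violating entries receive gradient. Using it as a penalty term is an illustrative choice, not something this page prescribes.

```python
import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

def AND_lite(y: ArrayLike, c: float = 1e-8) -> Array:
    """Lite conjunction (positive part only): AND_lite(y) = 0  iff  y_i <= 0 for all i."""
    y = jnp.asarray(y)
    mp = jnp.mean(jnp.maximum(y, 0.0) ** 2) + c
    return jnp.sqrt(mp) - jnp.sqrt(c)

satisfied = jnp.array([-1.0, -2.0])  # every y_i <= 0
violated = jnp.array([1.0, -2.0])    # y_0 > 0

grad_fn = jax.grad(AND_lite)
print(AND_lite(satisfied), grad_fn(satisfied))  # zero value, zero gradient
print(AND_lite(violated), grad_fn(violated))    # positive value, gradient only on y_0
```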

OR(y: ArrayLike, c: float = 1e-08) -> Array

Smooth disjunction: OR(y) <= 0 iff y_i <= 0 for some i.

Source code in openscvx/symbolic/lowerers/jax/stl.py
def OR(y: ArrayLike, c: float = 1e-8) -> Array:
    """Smooth disjunction: OR(y) <= 0  iff  y_i <= 0 for some i."""
    y = jnp.asarray(y)
    return -AND(-y, c=c)
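`OR` is obtained from `AND` by De Morgan duality, OR(y) = -AND(-y). The sketch below reuses the structure of `AND` with a hypothetical stand-in for `_root_sum_of_product_terms`, namely `(c**n + prod(x))**(1/n)`. This is an assumption, since the real helper is not shown here. The sketch checks the sign convention. It also shows that, unlike a hard min over the entries, the smooth disjunction sends gradient to every entry.

```python
import jax
import jax.numpy as jnp

def root_sum_of_product_terms_sketch(x, c):
    # HYPOTHETICAL stand-in for _root_sum_of_product_terms: (c**n + prod(x))**(1/n).
    x = jnp.ravel(x)
    n = x.size
    return (c**n + jnp.prod(x)) ** (1.0 / n)

def AND_sketch(y, c=1e-8):
    # Same structure as AND, using the hypothetical helper.
    y = jnp.asarray(y, dtype=jnp.float32)
    mp = jnp.mean(jnp.maximum(y, 0.0) ** 2) + c
    m0 = root_sum_of_product_terms_sketch(jnp.maximum(-y, 0.0) ** 2, c)
    return jnp.sqrt(mp) - jnp.sqrt(m0)

def OR_sketch(y, c=1e-8):
    # De Morgan duality, exactly as in OR above.
    y = jnp.asarray(y, dtype=jnp.float32)
    return -AND_sketch(-y, c=c)

some_satisfied = OR_sketch([-1.0, 2.0])  # y_0 <= 0: negative
none_satisfied = OR_sketch([1.0, 2.0])   # no y_i <= 0: positive
grad_none = jax.grad(OR_sketch)(jnp.array([1.0, 2.0]))  # both entries get gradient
print(some_satisfied, none_satisfied, grad_none)
```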

OR_lite(y: ArrayLike, c: float = 1e-08) -> Array

Lite disjunction (positive part only): OR_lite(y) = 0 iff y_i <= 0 for some i.

Source code in openscvx/symbolic/lowerers/jax/stl.py
def OR_lite(y: ArrayLike, c: float = 1e-8) -> Array:
    """Lite disjunction (positive part only): OR_lite(y) = 0  iff  y_i <= 0 for some i."""
    y = jnp.asarray(y)

    m0 = _root_sum_of_product_terms(jnp.maximum(y, 0.0) ** 2, c)
    return jnp.sqrt(m0) - jnp.sqrt(c)
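The sketch below checks the documented property of `OR_lite` under a hypothetical stand-in for `_root_sum_of_product_terms`, `(c**n + prod(x))**(1/n)`, which is an assumption and not the library's code. `OR_lite` should be numerically zero when at least one entry is nonpositive and strictly positive when every entry is positive. Under this assumption it is never negative, which a coarse grid scan confirms.

```python
import jax.numpy as jnp

def root_sum_of_product_terms_sketch(x, c):
    # HYPOTHETICAL stand-in for _root_sum_of_product_terms: (c**n + prod(x))**(1/n).
    # Any zero entry collapses the product, leaving only the c**n term.
    x = jnp.ravel(x)
    n = x.size
    return (c**n + jnp.prod(x)) ** (1.0 / n)

def OR_lite_sketch(y, c=1e-8):
    # Same structure as OR_lite above, using the hypothetical helper.
    y = jnp.asarray(y, dtype=jnp.float32)
    m0 = root_sum_of_product_terms_sketch(jnp.maximum(y, 0.0) ** 2, c)
    return jnp.sqrt(m0) - jnp.sqrt(c)

print(OR_lite_sketch([-1.0, 2.0]))  # one entry <= 0: about 0
print(OR_lite_sketch([0.0, 3.0]))   # boundary entry: about 0
print(OR_lite_sketch([1.0, 2.0]))   # all entries > 0: positive

# Scan a grid of 2-vectors; the minimum should not go (meaningfully) below zero.
grid = jnp.linspace(-2.0, 2.0, 9)
values = jnp.array([[OR_lite_sketch([a, b]) for b in grid] for a in grid])
print(float(values.min()))
```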

gmsr_IfThen(y: ArrayLike, c: float = 1e-08) -> Array

Smooth implication: IfThen(y) <= 0 iff (y_0 <= 0 => y_1 <= 0).

Source code in openscvx/symbolic/lowerers/jax/stl.py
def gmsr_IfThen(y: ArrayLike, c: float = 1e-8) -> Array:
    """Smooth implication: IfThen(y) <= 0  iff  (y_0 <= 0 => y_1 <= 0)."""
    y = jnp.asarray(y)
    return OR(jnp.array([-y[0], y[1]]), c=c)
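The implication is encoded as OR([-y_0, y_1]), that is, "the antecedent fails, or the consequent holds". The sketch below evaluates all four sign combinations. As before, it assumes a hypothetical stand-in for `_root_sum_of_product_terms`, `(c**n + prod(x))**(1/n)`, and inlines `AND` inside `OR_sketch` for brevity. Only a satisfied antecedent (y_0 <= 0) together with a violated consequent (y_1 > 0) should come out positive.

```python
import jax.numpy as jnp

def root_sum_of_product_terms_sketch(x, c):
    # HYPOTHETICAL stand-in for _root_sum_of_product_terms: (c**n + prod(x))**(1/n).
    x = jnp.ravel(x)
    n = x.size
    return (c**n + jnp.prod(x)) ** (1.0 / n)

def OR_sketch(y, c=1e-8):
    # OR(y) = -AND(-y) with AND inlined: negating y swaps its positive and
    # negative parts.
    y = jnp.asarray(y, dtype=jnp.float32)
    mp = jnp.mean(jnp.maximum(-y, 0.0) ** 2) + c
    m0 = root_sum_of_product_terms_sketch(jnp.maximum(y, 0.0) ** 2, c)
    return -(jnp.sqrt(mp) - jnp.sqrt(m0))

def IfThen_sketch(y, c=1e-8):
    # Same as gmsr_IfThen above: (y_0 <= 0 => y_1 <= 0) encoded as OR([-y_0, y_1]).
    y = jnp.asarray(y, dtype=jnp.float32)
    return OR_sketch(jnp.array([-y[0], y[1]]), c=c)

# y_i <= 0 means "predicate i holds".
cases = {
    "antecedent holds, consequent holds": [-1.0, -1.0],
    "antecedent holds, consequent fails": [-1.0, 1.0],
    "antecedent fails, consequent holds": [1.0, -1.0],
    "antecedent fails, consequent fails": [1.0, 1.0],
}
results = {name: float(IfThen_sketch(y)) for name, y in cases.items()}
for name, value in results.items():
    print(f"{name}: {value:+.4f}")
```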

gmsr_IfThen_lite(y: ArrayLike, c: float = 1e-08) -> Array

Lite implication: IfThen_lite(y) = 0 iff (y_0 <= 0 => y_1 <= 0).

It can enforce a continuous-time implication through a periodic auxiliary state z:

$$\dot{z}(t) = \mathrm{IfThen\_lite}\big([y_0(t),\, y_1(t)]\big), \qquad z(0) = z(T)$$

Source code in openscvx/symbolic/lowerers/jax/stl.py
def gmsr_IfThen_lite(y: ArrayLike, c: float = 1e-8) -> Array:
    """Lite implication: IfThen_lite(y) = 0  iff  (y_0 <= 0 => y_1 <= 0).

    Can enforce continuous-time implication via periodic auxiliary state:
        z_dot(t) = IfThen_lite([y_0(t), y_1(t)])
        z(0) = z(T)
    """
    y = jnp.asarray(y)
    return OR_lite(jnp.array([-y[0], y[1]]), c=c)
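The periodic-state construction works because `IfThen_lite` is nonnegative. This holds at least under the hypothetical stand-in for `_root_sum_of_product_terms` used below, `(c**n + prod(x))**(1/n)`, which is an assumption. A nonnegative z_dot makes z nondecreasing, so z(0) = z(T) can hold only if `IfThen_lite` vanishes along the whole trajectory. The sketch samples two illustrative trajectories on [0, 2π] and integrates z_dot with the trapezoidal rule, using a hypothetical helper `periodic_drift`. It reports the drift z(T) - z(0).

```python
import jax
import jax.numpy as jnp

def root_sum_of_product_terms_sketch(x, c):
    # HYPOTHETICAL stand-in for _root_sum_of_product_terms: (c**n + prod(x))**(1/n).
    x = jnp.ravel(x)
    n = x.size
    return (c**n + jnp.prod(x)) ** (1.0 / n)

def OR_lite_sketch(y, c=1e-8):
    m0 = root_sum_of_product_terms_sketch(jnp.maximum(y, 0.0) ** 2, c)
    return jnp.sqrt(m0) - jnp.sqrt(c)

def IfThen_lite_sketch(y, c=1e-8):
    # Same as gmsr_IfThen_lite above: OR_lite([-y_0, y_1]).
    return OR_lite_sketch(jnp.array([-y[0], y[1]]), c=c)

def periodic_drift(y0, y1, t):
    # HYPOTHETICAL helper: z_dot(t) = IfThen_lite([y_0(t), y_1(t)]) evaluated on
    # the grid, then integrated with the trapezoidal rule to get z(T) - z(0).
    z_dot = jax.vmap(IfThen_lite_sketch)(jnp.stack([y0, y1], axis=1))
    return jnp.sum(0.5 * (z_dot[1:] + z_dot[:-1]) * jnp.diff(t))

t = jnp.linspace(0.0, 2.0 * jnp.pi, 201)
y0 = jnp.sin(t)  # antecedent holds (y_0 <= 0) on [pi, 2*pi]

drift_ok = periodic_drift(y0, -jnp.ones_like(t), t)       # consequent always holds
drift_bad = periodic_drift(y0, 0.5 * jnp.ones_like(t), t)  # consequent never holds
print(drift_ok, drift_bad)
```

A zero drift is compatible with the periodicity condition. A positive drift means no periodic z can satisfy it, which is how the constraint rules out violating trajectories.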

integer_variable(y: ArrayLike, values: ArrayLike, c: float = 1e-08) -> Array

Smooth discrete constraint: returns 0 iff y equals one of values.

Source code in openscvx/symbolic/lowerers/jax/stl.py
def integer_variable(y: ArrayLike, values: ArrayLike, c: float = 1e-8) -> Array:
    """Smooth discrete constraint: returns 0 iff y equals one of values."""
    y = jnp.asarray(y)
    values = jnp.asarray(values)
    return OR(_smooth_equality(y - values, c=c), c=c)
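This page shows neither `_smooth_equality` nor `_root_sum_of_product_terms`, so the sketch below assumes two hypothetical stand-ins. The first is `smooth_equality_sketch(d, c) = sqrt(d**2 + c) - sqrt(c)`, which is zero exactly when d == 0 and positive otherwise. The second is `(c**n + prod(x))**(1/n)` for the root helper. Under those assumptions, OR over the per-value mismatches is approximately zero when y equals one of the allowed values and positive otherwise, matching the documented behavior.

```python
import jax.numpy as jnp

def root_sum_of_product_terms_sketch(x, c):
    # HYPOTHETICAL stand-in for _root_sum_of_product_terms: (c**n + prod(x))**(1/n).
    x = jnp.ravel(x)
    n = x.size
    return (c**n + jnp.prod(x)) ** (1.0 / n)

def smooth_equality_sketch(d, c=1e-8):
    # HYPOTHETICAL stand-in for _smooth_equality: zero iff d == 0, positive otherwise.
    return jnp.sqrt(d**2 + c) - jnp.sqrt(c)

def OR_sketch(y, c=1e-8):
    # OR(y) = -AND(-y) with AND inlined (positive/negative parts swap under negation).
    mp = jnp.mean(jnp.maximum(-y, 0.0) ** 2) + c
    m0 = root_sum_of_product_terms_sketch(jnp.maximum(y, 0.0) ** 2, c)
    return jnp.sqrt(m0) - jnp.sqrt(mp)

def integer_variable_sketch(y, values, c=1e-8):
    # Same structure as integer_variable above: y - values broadcasts over the
    # allowed values, and OR asks whether any mismatch is zero.
    y = jnp.asarray(y, dtype=jnp.float32)
    values = jnp.asarray(values, dtype=jnp.float32)
    return OR_sketch(smooth_equality_sketch(y - values, c=c), c=c)

allowed = [0.0, 1.0, 2.0, 3.0]
on_value = integer_variable_sketch(2.0, allowed)  # y matches an allowed value
between = integer_variable_sketch(1.5, allowed)   # y lies between allowed values
print(on_value, between)
```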