
mjx

MuJoCo MJX dynamics adapters for OpenSCvx BYOF.

The recommended entry point is mjx_byof, which returns a complete byof["dynamics"] dict and automatically handles free-joint quaternion kinematics, so no separate imports are required:

byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}

For models without free joints (cartpoles, manipulators, etc.) the returned dict contains only "qvel", and qpos kinematics must still be specified symbolically via dynamics={"qpos": qvel}. For models with free joints (drones, humanoids) "qpos" is included automatically and no symbolic dynamics entry is needed.
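
To illustrate why free joints need dedicated kinematics: a free joint stores a position plus a unit quaternion in qpos (7 entries) but only a linear and an angular velocity in qvel (6 entries), so the time derivative of qpos is not simply qvel. The plain-Python sketch below shows the standard quaternion kinematics for a single free joint. It assumes MuJoCo's conventions (quaternion ordered [w, x, y, z], angular velocity expressed in the body frame). The helper names quat_mul and free_joint_qpos_dot are hypothetical and are not the library's internal implementation.

```python
def quat_mul(a, b):
    """Hamilton product of two quaternions in [w, x, y, z] order."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return [
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ]


def free_joint_qpos_dot(qpos, qvel):
    """Hypothetical helper: d(qpos)/dt for a single free joint.

    qpos = [x, y, z, qw, qx, qy, qz]   (7 entries)
    qvel = [vx, vy, vz, wx, wy, wz]    (6 entries)
    Assumes the angular velocity is expressed in the body frame.
    """
    lin_vel, ang_vel = qvel[:3], qvel[3:]
    quat = qpos[3:]
    # Quaternion kinematics: q_dot = 0.5 * quat_mul(q, [0, omega])
    quat_dot = [0.5 * c for c in quat_mul(quat, [0.0, *ang_vel])]
    return list(lin_vel) + quat_dot


qpos = [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0]   # identity orientation
qvel = [0.1, 0.0, 0.0, 0.0, 0.0, 2.0]        # rotation about body z at 2 rad/s
qpos_dot = free_joint_qpos_dot(qpos, qvel)   # 7 entries, matching len(qpos)
```

The result has the same length as qpos rather than qvel, which is why a symbolic dynamics={"qpos": qvel} entry cannot be used for such models.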

The lower-level mjx_dynamics is also public for advanced users who need direct access to the BYOF callable for the qvel derivative.

Note

Time dilation is handled automatically by the BYOF lowering pipeline; all functions return physical (un-dilated) quantities.

mjx_byof(mjx_model: Any, *, qpos: State | slice, qvel: State | slice, ctrl: Control | slice, return_component: str = 'qacc', extra_postprocess: Optional[Callable[[Any], Any]] = None) -> dict

Return a complete byof["dynamics"] dict for a MuJoCo MJX model.

This is the recommended high-level entry-point. It inspects the model's nq and nv to detect free joints and automatically includes the quaternion kinematics callable for qpos when needed.
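
The nq/nv check described above can be sketched in isolation. The helper dynamics_keys and the SimpleNamespace stand-ins are hypothetical and used only for illustration; a real call would pass a model produced by mujoco.mjx.put_model.

```python
from types import SimpleNamespace


def dynamics_keys(model):
    """Hypothetical helper mirroring the nq/nv check described above."""
    keys = ["qvel"]                      # the qvel derivative always comes from MJX
    if int(model.nq) > int(model.nv):    # quaternion coordinates make nq exceed nv
        keys.append("qpos")              # qpos kinematics are supplied as a callable too
    return keys


# Stand-ins for MJX models; only the nq and nv attributes are used here.
cartpole = SimpleNamespace(nq=2, nv=2)   # slide + hinge joint
quadrotor = SimpleNamespace(nq=7, nv=6)  # one free joint

cartpole_keys = dynamics_keys(cartpole)
quadrotor_keys = dynamics_keys(quadrotor)
```

Under these assumptions the cartpole stand-in yields only the "qvel" key, while the quadrotor stand-in yields both "qvel" and "qpos".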

Parameters:

Name Type Description Default
mjx_model Any

A model produced by mujoco.mjx.put_model.

required
qpos State | slice

Position state (or slice). Length must equal mjx_model.nq.

required
qvel State | slice

Velocity state (or slice). Length must equal mjx_model.nv.

required
ctrl Control | slice

Control variable (or slice). Length must equal mjx_model.nu.

required
return_component str

Passed to mjx_dynamics. "qacc" (default) uses the generalized acceleration as the qvel derivative; "qvel" returns qvel directly (rarely needed).

'qacc'
extra_postprocess Optional[Callable[[Any], Any]]

Optional callable applied to the MJX data object after mjx.forward. Passed through to mjx_dynamics.

None

Returns:

Type Description
dict

A dict suitable for use as byof["dynamics"]. For models without free joints (nq == nv) only "qvel" is included; position kinematics should still be provided symbolically via dynamics={"qpos": qvel}. For models with free joints (nq > nv) both "qpos" and "qvel" are included and no symbolic dynamics entry is needed.

Example

Cartpole (nq == nv, no free joint):

byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
problem = ox.Problem(
    dynamics={"qpos": qvel},   # still required for non-free models
    byof=byof, ...
)

Quadrotor / drone (nq > nv, one free joint):

byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
problem = ox.Problem(
    dynamics={},               # qpos handled automatically
    byof=byof, ...
)
Source code in openscvx/integrations/mjx.py
def mjx_byof(
    mjx_model: Any,
    *,
    qpos: "State | slice",
    qvel: "State | slice",
    ctrl: "Control | slice",
    return_component: str = "qacc",
    extra_postprocess: Optional[Callable[[Any], Any]] = None,
) -> dict:
    """Return a complete ``byof["dynamics"]`` dict for a MuJoCo MJX model.

    This is the recommended high-level entry-point.  It inspects the model's
    ``nq`` and ``nv`` to detect free joints and automatically includes the
    quaternion kinematics callable for ``qpos`` when needed.

    Args:
        mjx_model: A model produced by :func:`mujoco.mjx.put_model`.
        qpos: Position state (or slice). Length must equal ``mjx_model.nq``.
        qvel: Velocity state (or slice). Length must equal ``mjx_model.nv``.
        ctrl: Control variable (or slice). Length must equal ``mjx_model.nu``.
        return_component: Passed to :func:`mjx_dynamics`. ``"qacc"``
            (default) uses the generalized acceleration as the ``qvel``
            derivative; ``"qvel"`` returns qvel directly (rarely needed).
        extra_postprocess: Optional callable applied to the MJX ``data``
            object after ``mjx.forward``. Passed through to
            :func:`mjx_dynamics`.

    Returns:
        A dict suitable for use as ``byof["dynamics"]``.
        For models **without** free joints (``nq == nv``) only ``"qvel"`` is
        included; position kinematics should still be provided symbolically
        via ``dynamics={"qpos": qvel}``.
        For models **with** free joints (``nq > nv``) both ``"qpos"`` and
        ``"qvel"`` are included and no symbolic ``dynamics`` entry is needed.

    Example:
        Cartpole (nq == nv, no free joint)::

            byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
            problem = ox.Problem(
                dynamics={"qpos": qvel},   # still required for non-free models
                byof=byof, ...
            )

        Quadrotor / drone (nq > nv, one free joint)::

            byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
            problem = ox.Problem(
                dynamics={},               # qpos handled automatically
                byof=byof, ...
            )
    """
    nq = int(mjx_model.nq)
    nv = int(mjx_model.nv)

    result: dict = {
        "qvel": mjx_dynamics(
            mjx_model,
            qpos=qpos,
            qvel=qvel,
            ctrl=ctrl,
            return_component=return_component,
            extra_postprocess=extra_postprocess,
        ),
    }

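    # A free joint has 7 qpos entries (position + quaternion) but 6 qvel entries,
    # so each one adds exactly 1 to nq - nv. Note: ball joints (4 qpos, 3 qvel)
    # also add 1, so this count assumes the model contains none.
    n_free = nq - nv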
    if n_free > 0:
        result["qpos"] = _free_joint_qpos_dynamics(
            qpos=qpos,
            qvel=qvel,
            n_free_joints=n_free,
        )

    return result

mjx_dynamics(mjx_model: Any, *, qpos: State | slice, qvel: State | slice, ctrl: Control | slice, return_component: str = 'qacc', extra_postprocess: Optional[Callable[[Any], Any]] = None) -> Callable

Wrap a MuJoCo MJX model as a BYOF dynamics function.

Parameters:

Name Type Description Default
mjx_model Any

A model produced by mujoco.mjx.put_model. Must be a JAX pytree (the standard MJX representation).

required
qpos State | slice

Position state (or slice into the unified x vector). Length must equal mjx_model.nq.

required
qvel State | slice

Velocity state (or slice). Length must equal mjx_model.nv.

required
ctrl Control | slice

Control variable (or slice into the unified u vector). Length must equal mjx_model.nu.

required
return_component str

Which MJX field to return. "qacc" (default) returns the generalized acceleration qacc for use as the qvel state derivative. "qvel" returns qvel for use as the qpos state derivative (rarely needed because the qpos derivative is normally specified symbolically).

'qacc'
extra_postprocess Optional[Callable[[Any], Any]]

Optional callable applied to the MJX data after mjx.forward. Useful for computing custom outputs (e.g. site positions) used elsewhere.

None
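
As a rough illustration of the extra_postprocess contract: the source assigns the callable's return value back to data before reading qacc or qvel, so the callable should return a data object rather than None. The sketch below uses a hypothetical FakeData stand-in and a hypothetical clip_qacc callable so that it runs without MuJoCo; clamping accelerations is purely illustrative, not a recommended modelling choice.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class FakeData:
    """Stand-in for the MJX data object (hypothetical, two fields only)."""
    qvel: tuple
    qacc: tuple


def clip_qacc(data, limit=50.0):
    """Example postprocess: clamp accelerations and return the updated data."""
    clipped = tuple(max(-limit, min(limit, a)) for a in data.qacc)
    return replace(data, qacc=clipped)   # must return the data object


data = FakeData(qvel=(0.0, 1.0), qacc=(3.0, -120.0))
data = clip_qacc(data)                   # mirrors data = extra_postprocess(data)
```

With a real MJX data object, the analogous update would presumably go through data.replace(...), as the source does for qpos, qvel, and ctrl.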

Returns:

Type Description
Callable

A function f(x, u, node, params) -> jnp.ndarray matching the BYOF dynamics signature.

Raises:

Type Description
ImportError

If mujoco.mjx is not installed.

ValueError

If return_component is not one of the allowed values.

Note

MJX's contact solver uses lax.while_loop, which is not reverse-mode differentiable. For contact-free systems (manipulators, cartpoles, etc.) disable contacts before uploading the model:

mj_model.opt.disableflags |= mujoco.mjtDisableBit.mjDSBL_CONTACT
mjx_model = mjx.put_model(mj_model)
Example

Cartpole swing-up dynamics:

import mujoco
import mujoco.mjx as mjx
import openscvx as ox
from openscvx.integrations import mjx_dynamics

mj_model = mujoco.MjModel.from_xml_path("cartpole.xml")
mjx_model = mjx.put_model(mj_model)

qpos = ox.State("qpos", shape=(mjx_model.nq,))
qvel = ox.State("qvel", shape=(mjx_model.nv,))
ctrl = ox.Control("ctrl", shape=(mjx_model.nu,))

qvel_dynamics = mjx_dynamics(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)

problem = ox.Problem(
    dynamics={"qpos": qvel},
    byof={"dynamics": {"qvel": qvel_dynamics}},
    states=[qpos, qvel],
    controls=[ctrl],
    ...
)
Source code in openscvx/integrations/mjx.py
def mjx_dynamics(
    mjx_model: Any,
    *,
    qpos: "State | slice",
    qvel: "State | slice",
    ctrl: "Control | slice",
    return_component: str = "qacc",
    extra_postprocess: Optional[Callable[[Any], Any]] = None,
) -> Callable:
    """Wrap a MuJoCo MJX model as a BYOF dynamics function.

    Args:
        mjx_model: A model produced by :func:`mujoco.mjx.put_model`. Must be a
            JAX pytree (the standard MJX representation).
        qpos: Position state (or slice into the unified ``x`` vector).
            Length must equal ``mjx_model.nq``.
        qvel: Velocity state (or slice). Length must equal ``mjx_model.nv``.
        ctrl: Control variable (or slice into the unified ``u`` vector).
            Length must equal ``mjx_model.nu``.
        return_component: Which MJX field to return. ``"qacc"`` (default)
            returns the generalized acceleration ``qacc`` for use as the
            ``qvel`` state derivative. ``"qvel"`` returns ``qvel`` for use as
            the ``qpos`` state derivative (rarely needed because that is
            already symbolic).
        extra_postprocess: Optional callable applied to the MJX ``data`` after
            ``mjx.forward``. Useful for computing custom outputs (e.g. site
            positions) used elsewhere.

    Returns:
        A function ``f(x, u, node, params) -> jnp.ndarray`` matching the BYOF
        dynamics signature.

    Raises:
        ImportError: If ``mujoco.mjx`` is not installed.
        ValueError: If ``return_component`` is not one of the allowed values.

    Note:
        MJX's contact solver uses ``lax.while_loop``, which is **not**
        reverse-mode differentiable. For contact-free systems (manipulators,
        cartpoles, etc.) disable contacts before uploading the model::

            mj_model.opt.disableflags |= mujoco.mjtDisableBit.mjDSBL_CONTACT
            mjx_model = mjx.put_model(mj_model)

    Example:
        Cartpole swing-up dynamics::

            import mujoco
            import mujoco.mjx as mjx
            import openscvx as ox
            from openscvx.integrations import mjx_dynamics

            mj_model = mujoco.MjModel.from_xml_path("cartpole.xml")
            mjx_model = mjx.put_model(mj_model)

            qpos = ox.State("qpos", shape=(mjx_model.nq,))
            qvel = ox.State("qvel", shape=(mjx_model.nv,))
            ctrl = ox.Control("ctrl", shape=(mjx_model.nu,))

            qvel_dynamics = mjx_dynamics(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)

            problem = ox.Problem(
                dynamics={"qpos": qvel},
                byof={"dynamics": {"qvel": qvel_dynamics}},
                states=[qpos, qvel],
                controls=[ctrl],
                ...
            )
    """
    try:
        import mujoco.mjx as mjx
    except ImportError as e:
        raise ImportError(
            "mujoco.mjx is required for mjx_dynamics. Install with: pip install openscvx[mjx]"
        ) from e

    if return_component not in ("qacc", "qvel"):
        raise ValueError(f"return_component must be 'qacc' or 'qvel', got {return_component!r}")

    # Store the raw args; slices are resolved lazily on first call so that
    # mjx_dynamics() can be called before Problem construction assigns .slice.
    _qpos_arg = qpos
    _qvel_arg = qvel
    _ctrl_arg = ctrl
    _resolved: list = []  # populated on first call: [qpos_slice, qvel_slice, ctrl_slice]

    def f(x, u, node, params):
        del node, params  # MJX dynamics are stateless w.r.t. node and OpenSCvx params
        if not _resolved:
            _resolved.append(_resolve_slice(_qpos_arg, "qpos"))
            _resolved.append(_resolve_slice(_qvel_arg, "qvel"))
            _resolved.append(_resolve_slice(_ctrl_arg, "ctrl"))
        qpos_slice, qvel_slice, ctrl_slice = _resolved

        qpos_val = x[qpos_slice]
        qvel_val = x[qvel_slice]
        ctrl_val = u[ctrl_slice]

        data = mjx.make_data(mjx_model)
        data = data.replace(qpos=qpos_val, qvel=qvel_val, ctrl=ctrl_val)
        data = mjx.forward(mjx_model, data)

        if extra_postprocess is not None:
            data = extra_postprocess(data)

        if return_component == "qacc":
            return jnp.asarray(data.qacc)
        return jnp.asarray(data.qvel)

    return f
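
The lazy slice resolution used in the source above (slices are looked up on the first call and then cached in a list captured by the closure) can be illustrated in isolation. This is a minimal sketch: FakeState, resolve, and make_reader are hypothetical stand-ins, and the real _resolve_slice helper may behave differently.

```python
class FakeState:
    """Stand-in for a State whose slice is assigned later (hypothetical)."""

    def __init__(self):
        self.slice = None


def resolve(arg):
    """Return a slice from either a slice or an object carrying .slice."""
    if isinstance(arg, slice):
        return arg
    if arg.slice is None:
        raise ValueError("slice has not been assigned yet")
    return arg.slice


def make_reader(state):
    cache = []                      # filled on the first call only

    def read(x):
        if not cache:
            cache.append(resolve(state))
        return x[cache[0]]

    return read


state = FakeState()
read = make_reader(state)           # safe: nothing is resolved yet
state.slice = slice(2, 4)           # assigned later, e.g. during problem setup
values = read([10, 11, 12, 13, 14])
```

Deferring the lookup lets the callable be created before the slices exist. The trade-off is that the cached slice is never refreshed, so the same callable keeps using the first resolved layout.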