mjx
MuJoCo MJX dynamics adapters for OpenSCvx BYOF.
The recommended entry point is :func:mjx_byof, which returns a complete
byof["dynamics"] dict and automatically handles free-joint quaternion
kinematics, so no separate imports are required:

    byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
For models without free joints (cartpoles, manipulators, etc.) the
returned dict contains only "qvel", and qpos kinematics must still be
specified symbolically via dynamics={"qpos": qvel}. For models with
free joints (drones, humanoids) "qpos" is included automatically and no
symbolic dynamics entry is needed.
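The rule above can be summarised in a few lines. This is an illustrative sketch only: `expected_byof_keys` and `symbolic_qpos_needed` are hypothetical helpers, not part of OpenSCvx, and they assume that free joints are detected purely by comparing `nq` with `nv` (a free joint contributes 7 position coordinates but only 6 velocity coordinates, so `nq > nv`).

```python
def expected_byof_keys(nq: int, nv: int) -> set:
    """Hypothetical helper: keys the returned dynamics dict is described as holding."""
    if nq < nv:
        # Assumption: a valid model never has fewer positions than velocities.
        raise ValueError(f"expected nq >= nv, got nq={nq}, nv={nv}")
    if nq == nv:
        # No free joint: only the velocity derivative comes from MJX.
        return {"qvel"}
    # nq > nv: quaternion kinematics for qpos are included as well.
    return {"qpos", "qvel"}


def symbolic_qpos_needed(nq: int, nv: int) -> bool:
    """True when dynamics={"qpos": qvel} must still be supplied symbolically."""
    return "qpos" not in expected_byof_keys(nq, nv)


# Cartpole-like model (nq == nv) versus a single free body (nq = 7, nv = 6).
cartpole_needs_symbolic = symbolic_qpos_needed(2, 2)
drone_needs_symbolic = symbolic_qpos_needed(7, 6)
```

Under these assumptions the cartpole case still needs the symbolic `qpos` entry, while the free-body case does not.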
The lower-level :func:mjx_dynamics is also public for advanced users who
need direct access to the BYOF callable for the qvel derivative.
Note
Time dilation is handled automatically by the BYOF lowering pipeline; all functions return physical (un-dilated) quantities.
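To illustrate what "un-dilated" means: free-final-time formulations commonly integrate over a normalised time τ and scale the physical derivative by a dilation factor. The sketch below assumes that common form, dx/dτ = s · f(x, u). The names `dilate` and `s` are hypothetical and do not describe the actual OpenSCvx lowering code; the point is only that a BYOF callable returns f(x, u) itself, not s · f(x, u).

```python
from typing import Callable, List, Sequence

# Physical dynamics: (state, control) -> time derivative of the state.
Dynamics = Callable[[Sequence[float], Sequence[float]], List[float]]


def dilate(f: Dynamics) -> Callable[[Sequence[float], Sequence[float], float], List[float]]:
    """Hypothetical wrapper: turn physical dynamics into dilated dynamics."""

    def f_dilated(x: Sequence[float], u: Sequence[float], s: float) -> List[float]:
        # dx/dtau = s * dx/dt, where s is the (assumed) time-dilation factor.
        return [s * xi_dot for xi_dot in f(x, u)]

    return f_dilated


def double_integrator(x: Sequence[float], u: Sequence[float]) -> List[float]:
    """Physical (un-dilated) dynamics: position rate is velocity, velocity rate is u."""
    _pos, vel = x
    return [vel, u[0]]


# The user-facing function stays physical; the wrapper applies the scaling.
f_tau = dilate(double_integrator)
physical = double_integrator([0.0, 2.0], [3.0])
dilated = f_tau([0.0, 2.0], [3.0], 5.0)
```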
mjx_byof(mjx_model: Any, *, qpos: State | slice, qvel: State | slice, ctrl: Control | slice, return_component: str = 'qacc', extra_postprocess: Optional[Callable[[Any], Any]] = None) -> dict
Return a complete byof["dynamics"] dict for a MuJoCo MJX model.
This is the recommended high-level entry-point. It inspects the model's
nq and nv to detect free joints and automatically includes the
quaternion kinematics callable for qpos when needed.
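For intuition, the quaternion kinematics of a single free joint can be sketched in plain Python. This is not the OpenSCvx implementation, and `quat_mul` and `free_joint_qpos_dot` are hypothetical names. The sketch assumes MuJoCo's free-joint conventions: `qpos = [x, y, z, qw, qx, qy, qz]` and `qvel = [vx, vy, vz, wx, wy, wz]`, with linear velocity in the world frame and angular velocity in the body frame, so the quaternion rate is q̇ = ½ q ⊗ (0, ω).

```python
def quat_mul(a, b):
    """Hamilton product of two quaternions given as (w, x, y, z)."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return (
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    )


def free_joint_qpos_dot(qpos, qvel):
    """Time derivative of a 7-element free-joint qpos from a 6-element qvel."""
    quat = tuple(qpos[3:7])      # orientation (w, x, y, z)
    lin_vel = tuple(qvel[0:3])   # world-frame linear velocity
    omega = tuple(qvel[3:6])     # body-frame angular velocity
    # q_dot = 0.5 * q (x) (0, omega)
    quat_dot = tuple(0.5 * c for c in quat_mul(quat, (0.0,) + omega))
    # Position rate is the linear velocity; orientation rate is quat_dot.
    return lin_vel + quat_dot


# Identity orientation, moving along +x while spinning about the body z-axis.
qpos_dot = free_joint_qpos_dot(
    [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0, 0.0, 2.0],
)
```

Note that the result has 7 entries while `qvel` has 6, which is why `dynamics={"qpos": qvel}` cannot be used for models with free joints.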
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mjx_model | Any | A model produced by :func: | required |
| qpos | State \| slice | Position state (or slice). Length must equal | required |
| qvel | State \| slice | Velocity state (or slice). Length must equal | required |
| ctrl | Control \| slice | Control variable (or slice). Length must equal | required |
| return_component | str | Passed to :func: | 'qacc' |
| extra_postprocess | Optional[Callable[[Any], Any]] | Optional callable applied to the MJX | None |
Returns:
| Type | Description |
|---|---|
| dict | A dict suitable for use as |
| dict | For models without free joints ( |
| dict | included; position kinematics should still be provided symbolically |
| dict | via |
| dict | For models with free joints ( |
| dict | |
Example
Cartpole (nq == nv, no free joint)::
    byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
    problem = ox.Problem(
        dynamics={"qpos": qvel},  # still required for non-free models
        byof=byof,
        ...
    )
Quadrotor / drone (nq > nv, one free joint)::
    byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
    problem = ox.Problem(
        dynamics={},  # qpos handled automatically
        byof=byof,
        ...
    )
Source code in openscvx/integrations/mjx.py
mjx_dynamics(mjx_model: Any, *, qpos: State | slice, qvel: State | slice, ctrl: Control | slice, return_component: str = 'qacc', extra_postprocess: Optional[Callable[[Any], Any]] = None) -> Callable
Wrap a MuJoCo MJX model as a BYOF dynamics function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mjx_model | Any | A model produced by :func: | required |
| qpos | State \| slice | Position state (or slice into the unified | required |
| qvel | State \| slice | Velocity state (or slice). Length must equal | required |
| ctrl | Control \| slice | Control variable (or slice into the unified | required |
| return_component | str | Which MJX field to return. | 'qacc' |
| extra_postprocess | Optional[Callable[[Any], Any]] | Optional callable applied to the MJX | None |
Returns:
| Type | Description |
|---|---|
| Callable | A function |
| Callable | dynamics signature. |
Raises:
| Type | Description |
|---|---|
| ImportError | If |
| ValueError | If |
Note
MJX's contact solver uses lax.while_loop, which is not
reverse-mode differentiable. For contact-free systems (manipulators,
cartpoles, etc.) disable contacts before uploading the model::
    mj_model.opt.disableflags |= mujoco.mjtDisableBit.mjDSBL_CONTACT
    mjx_model = mjx.put_model(mj_model)
Example
Cartpole swing-up dynamics::
    import mujoco
    import mujoco.mjx as mjx

    import openscvx as ox
    from openscvx.integrations import mjx_dynamics

    mj_model = mujoco.MjModel.from_xml_path("cartpole.xml")
    mjx_model = mjx.put_model(mj_model)

    qpos = ox.State("qpos", shape=(mjx_model.nq,))
    qvel = ox.State("qvel", shape=(mjx_model.nv,))
    ctrl = ox.Control("ctrl", shape=(mjx_model.nu,))

    qvel_dynamics = mjx_dynamics(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)

    problem = ox.Problem(
        dynamics={"qpos": qvel},
        byof={"dynamics": {"qvel": qvel_dynamics}},
        states=[qpos, qvel],
        controls=[ctrl],
        ...
    )
Source code in openscvx/integrations/mjx.py