Cartpole MJX
Cartpole swing-up using MuJoCo MJX dynamics.
This example demonstrates how to plug a MuJoCo MJX model directly into an OpenSCvx problem via the BYOF dynamics interface. The cart slides along the x-axis and an unactuated pole is attached to it by a hinge; the optimization computes a control sequence that swings the pole from hanging downward (theta = pi) to upright (theta = 0).
Requires: pip install openscvx[mjx]
The MJX model is defined inline as an XML string so the example is fully
self-contained. For larger models, mujoco.MjModel.from_xml_path works
exactly the same way.
File: examples/mjx/cartpole_mjx.py
import os
import sys
import numpy as np
# Make the directory two levels above this file's folder (the repository root
# for examples/mjx/cartpole_mjx.py) importable, so `openscvx` can be imported
# when the example is run from a source checkout.
# NOTE(review): `append` puts this path last on sys.path, so an already
# installed `openscvx` package (if any) takes precedence over the checkout.
current_dir = os.path.dirname(os.path.abspath(__file__))
grandparent_dir = os.path.dirname(os.path.dirname(current_dir))
sys.path.append(grandparent_dir)
# MuJoCo/MJX is an optional extra: fail fast with an install hint on stderr
# and exit status 1 instead of surfacing a raw ImportError traceback.
try:
    import mujoco
    import mujoco.mjx as mjx
except ImportError:
    print(
        "MuJoCo MJX is not installed. Install with: pip install openscvx[mjx]",
        file=sys.stderr,
    )
    sys.exit(1)
import openscvx as ox
from openscvx.integrations import mjx_byof
# Inline MJCF description of the cartpole:
#   - "slider": slide joint along x carrying the cart (box, mass 1.0), limited
#     to the range [-3, 3].
#   - "hinge": unlimited hinge about y carrying the pole, a capsule running from
#     the pivot to 0.6 along +z (mass 0.1), so a hinge angle of 0 puts the pole
#     tip straight up.
#   - One motor on the slider with gear 20 and ctrl limited to [-1, 1]; ctrl is
#     therefore a normalized command (presumably up to +/-20 N on the slider with
#     MuJoCo's default motor gain -- confirm if the scaling matters).
# The <option> timestep/integrator configure MuJoCo's own stepping; whether the
# OpenSCvx integration uses them is not visible from this file.
CARTPOLE_XML = """
<mujoco model="cartpole">
<option gravity="0 0 -9.81" timestep="0.01" integrator="Euler"/>
<worldbody>
<body name="cart" pos="0 0 0">
<joint name="slider" type="slide" axis="1 0 0" limited="true" range="-3 3"/>
<geom name="cart_geom" type="box" size="0.2 0.15 0.1" mass="1.0" rgba="0.4 0.4 0.8 1"/>
<body name="pole" pos="0 0 0">
<joint name="hinge" type="hinge" axis="0 1 0" limited="false"/>
<geom name="pole_geom" type="capsule" fromto="0 0 0 0 0 0.6"
size="0.04" mass="0.1" rgba="0.8 0.4 0.4 1"/>
</body>
</body>
</worldbody>
<actuator>
<motor joint="slider" gear="20" ctrlrange="-1 1" ctrllimited="true"/>
</actuator>
</mujoco>
"""
# Compile the MJCF into a host-side MuJoCo model.
mj_model = mujoco.MjModel.from_xml_string(CARTPOLE_XML)

# Contacts are switched off: MJX's contact solver is built on lax.while_loop,
# which JAX cannot differentiate in reverse mode, and the optimizer needs
# differentiable dynamics.
# NOTE(review): the slider joint is still limited="true" and joint limits are
# not disabled here; confirm whether MJX's limit handling shares the same
# solver loop before relying on reverse-mode gradients.
mj_model.opt.disableflags |= mujoco.mjtDisableBit.mjDSBL_CONTACT

# Device (JAX) copy of the model used by the MJX dynamics.
mjx_model = mjx.put_model(mj_model)

# Sizes of qpos, qvel and ctrl, read from the model as plain Python ints.
n_q, n_v, n_u = (int(size) for size in (mjx_model.nq, mjx_model.nv, mjx_model.nu))
# Discretization: number of nodes and the nominal time horizon (s).
n = 60
total_time = 3.0

# Generalized positions [cart x, pole angle theta]. Start hanging (theta = pi)
# and finish upright (theta = 0) with the cart back at the origin.
qpos = ox.State("qpos", shape=(n_q,))
qpos.min, qpos.max = np.array([-3.0, -2.0 * np.pi]), np.array([3.0, 2.0 * np.pi])
qpos.initial, qpos.final = np.array([0.0, np.pi]), np.zeros(2)

# Generalized velocities [cart x-rate, theta-rate]; at rest at both ends.
qvel = ox.State("qvel", shape=(n_v,))
qvel.min, qvel.max = np.array([-10.0, -15.0]), np.array([10.0, 15.0])
qvel.initial, qvel.final = np.zeros(2), np.zeros(2)

# Normalized actuator command; bounds mirror ctrlrange="-1 1" in the XML.
ctrl = ox.Control("ctrl", shape=(n_u,))
ctrl.min, ctrl.max = np.array([-1.0]), np.array([1.0])
ctrl.guess = np.zeros((n, n_u))

states = [qpos, qvel]
controls = [ctrl]
# For slide and hinge joints the position derivative is just the velocity, so
# qpos is integrated from qvel directly; qvel's derivative is presumably
# supplied by the MJX-backed BYOF dynamics below.
dynamics = {"qpos": qvel}

byof: ox.ByofSpec = {
    "dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl),
}

# Continuous-time (CTCS) box constraints: an upper then a lower bound for each
# variable, states first, then controls.
constraints = [
    bound_constraint
    for var in (*states, *controls)
    for bound_constraint in (ox.ctcs(var <= var.max), ox.ctcs(var.min <= var))
]

# Initial guess: cart parked at the origin, pole angle swept linearly from
# hanging (pi) to upright (0), zero velocities throughout.
theta_guess = np.linspace(np.pi, 0.0, n)
qpos.guess = np.stack([np.zeros_like(theta_guess), theta_guess], axis=1)
qvel.guess = np.zeros((n, n_v))
# Free final time: minimized, starting from the nominal horizon and bounded to
# [0, 2 * total_time]; time-dilation bounds are scaled from the same horizon.
time = ox.Time(
    initial=0.0,
    final=ox.Minimize(total_time),
    min=0.0,
    max=2.0 * total_time,
    time_dilation_min=0.05 * total_time,
    time_dilation_max=2.0 * total_time,
)

# Solver weights (by name: proximal term, cost, virtual control) with a fixed,
# non-adapting proximal weight -- semantics per OpenSCvx's algorithm settings.
scp_settings = {
    "lam_prox": 1e-1,
    "lam_cost": 1e-2,
    "lam_vc": 1e0,
    "autotuner": ox.ConstantProximalWeight(),
}

problem = ox.Problem(
    dynamics=dynamics,
    states=states,
    controls=controls,
    time=time,
    constraints=constraints,
    N=n,
    byof=byof,
    algorithm=scp_settings,
    float_dtype="float64",
)
# Pivot-to-tip pole length; must match the pole geom's
# fromto="0 0 0 0 0 0.6" in CARTPOLE_XML.
POLE_LENGTH = 0.6


def visualize(results) -> None:
    """Animate the optimized cartpole trajectory in a Viser 3D scene.

    The scene uses the MuJoCo XML convention:
      - Cart slides along the x-axis at z = 0.
      - Pole hangs from the cart pivot; hinge rotates around y.
      - theta = 0 → pole upright (tip at +z).
      - theta = pi → pole hanging (tip at -z).

    The sidebar shows a phase portrait (theta vs. theta-dot) with a marker at
    the current state, and the control history with an animated time cursor.
    Starts a Viser server and blocks forever (``server.sleep_forever()``).

    Args:
        results: Post-processed solution; ``results.trajectory`` must provide
            "time", "qpos" (cart x, theta), "qvel" and "ctrl" arrays.
    """
    import plotly.graph_objects as go
    import viser

    from openscvx.plotting.viser import (
        add_animated_trail,
        add_animation_controls,
        compute_velocity_colors,
    )
    from openscvx.plotting.viser.plotly_integration import add_animated_plotly_vline

    # ── Pull data from results.trajectory ─────────────────────────────────────
    t_vec = results.trajectory["time"].flatten()  # (N_fine,)
    q_traj = results.trajectory["qpos"]  # (N_fine, 2)
    qd_traj = results.trajectory["qvel"]  # (N_fine, 2)
    u_traj = results.trajectory["ctrl"]  # (N_fine, 1)

    cart_x = q_traj[:, 0]
    theta = q_traj[:, 1]
    theta_deg = np.rad2deg(theta)
    theta_dot = qd_traj[:, 1]

    # Pole tip in world frame (hinge rotates around y-axis)
    tip_x = cart_x + POLE_LENGTH * np.sin(theta)
    tip_z = POLE_LENGTH * np.cos(theta)
    # 3D positions for trail (tip of the pole)
    tip_pos = np.column_stack([tip_x, np.zeros_like(tip_x), tip_z])

    # ── Viser server ───────────────────────────────────────────────────────────
    server = viser.ViserServer()
    server.scene.set_up_direction("+z")

    # Static rail along x
    rail_pts = np.array([[-3.5, 0, 0], [3.5, 0, 0]], dtype=np.float32)
    server.scene.add_line_segments(
        "/rail",
        points=rail_pts[None],  # shape (1, 2, 3)
        colors=np.array([80, 80, 80], dtype=np.uint8),  # (3,) broadcast to all segments
        line_width=4.0,
    )

    # ── Animated cart (box) ────────────────────────────────────────────────────
    cart_handle = server.scene.add_box(
        "/cart",
        dimensions=(0.4, 0.3, 0.2),
        position=(float(cart_x[0]), 0.0, 0.0),
        color=(80, 100, 200),
    )

    # ── Animated pole (line segment from pivot to tip) ─────────────────────────
    pole_pts = np.array(
        [[[float(cart_x[0]), 0.0, 0.0], [float(tip_x[0]), 0.0, float(tip_z[0])]]],
        dtype=np.float32,
    )
    pole_handle = server.scene.add_line_segments(
        "/pole",
        points=pole_pts,
        colors=np.array([200, 80, 80], dtype=np.uint8),  # (3,) broadcast to all segments
        line_width=6.0,
    )

    # ── Animated trail of the pole tip ────────────────────────────────────────
    tip_colors = compute_velocity_colors(tip_pos)
    _, update_trail = add_animated_trail(server, tip_pos, tip_colors, point_size=0.02)

    # ── Sidebar plots (phase portrait + control) ───────────────────────────────
    fig_phase = go.Figure()
    fig_phase.add_trace(
        go.Scatter(
            x=theta_deg.tolist(),
            y=theta_dot.tolist(),
            mode="lines",
            line={"color": "royalblue", "width": 2},
            name="Phase",
        )
    )
    # Trace 1: the current (theta, theta-dot) point, moved every frame. A time
    # cursor would be meaningless here because this plot's x-axis is an angle.
    fig_phase.add_trace(
        go.Scatter(
            x=[float(theta_deg[0])],
            y=[float(theta_dot[0])],
            mode="markers",
            marker={"color": "crimson", "size": 10},
            name="Current",
        )
    )
    fig_phase.update_layout(
        title="Phase portrait (θ vs θ̇)",
        xaxis_title="θ (deg)",
        yaxis_title="θ̇ (rad/s)",
        margin={"l": 40, "r": 10, "t": 40, "b": 40},
    )

    fig_ctrl = go.Figure()
    fig_ctrl.add_trace(
        go.Scatter(
            x=t_vec.tolist(),
            y=u_traj[:, 0].tolist(),
            mode="lines",
            line={"color": "darkorange", "width": 2},
            name="Control",
        )
    )
    fig_ctrl.update_layout(
        title="Control force (normalized)",
        xaxis_title="Time (s)",
        yaxis_title="u",
        margin={"l": 40, "r": 10, "t": 40, "b": 40},
    )

    with server.gui.add_folder("Plots"):
        phase_handle = server.gui.add_plotly(figure=fig_phase)
        # The control plot's x-axis is time, so a time cursor is appropriate.
        _, update_ctrl = add_animated_plotly_vline(server, fig_ctrl, t_vec, folder_name=None)

    def update_phase(frame_idx: int) -> None:
        marker = fig_phase.data[1]
        marker.x = [float(theta_deg[frame_idx])]
        marker.y = [float(theta_dot[frame_idx])]
        # Assigning .figure makes viser re-send the updated plot to clients.
        phase_handle.figure = fig_phase

    # ── Per-frame cart + pole update ───────────────────────────────────────────
    def update_cartpole(frame_idx: int) -> None:
        cx = float(cart_x[frame_idx])
        tx = float(tip_x[frame_idx])
        tz = float(tip_z[frame_idx])
        cart_handle.position = (cx, 0.0, 0.0)
        pole_handle.points = np.array([[[cx, 0.0, 0.0], [tx, 0.0, tz]]], dtype=np.float32)

    # ── Animation controls (play / pause / scrub) ──────────────────────────────
    add_animation_controls(
        server,
        t_vec,
        [update_cartpole, update_trail, update_phase, update_ctrl],
    )

    server.sleep_forever()
if __name__ == "__main__":
    header = "Cartpole swing-up via MuJoCo MJX dynamics"
    print(header)
    print("=" * 60)
    print(f"nq = {n_q}, nv = {n_v}, nu = {n_u}, N = {n}")
    print()

    # Build, solve and post-process the trajectory optimization.
    problem.initialize()
    problem.solve()
    results = problem.post_process()

    # Hinge entries (index 1) of qpos/qvel at the last node; both targets are 0.
    theta_end, theta_rate_end = (
        float(results.nodes[key][-1, 1]) for key in ("qpos", "qvel")
    )

    print()
    print(f"Final pole angle: {np.rad2deg(theta_end):.2f} deg (target: 0.0)")
    print(f"Final pole angle rate: {theta_rate_end:.4f} rad/s (target: 0.0)")
    print()

    visualize(results)