Triple Cartpole MJX
Triple-link cartpole swing-up using MuJoCo MJX dynamics.
Extends the single-link cartpole to three serial links (a triple pendulum on a cart). The optimizer drives all three links from the hanging equilibrium (θ₁=π, θ₂=0, θ₃=0) to the unstable upright equilibrium (θ₁=θ₂=θ₃=0) using a single horizontal force applied to the cart — a classic underactuated control benchmark.
State: qpos = [cart_x, θ₁, θ₂, θ₃], qvel = [ẋ, θ̇₁, θ̇₂, θ̇₃], where θ₂ and θ₃ are measured relative to the preceding link. Control: ctrl = [F_cart] (normalised to ctrl ∈ [−3, 3]; gear=60 → max ±180 N).
Link lengths: L₁=0.5 m, L₂=0.4 m, L₃=0.3 m.
Requires: pip install openscvx[mjx]
File: examples/mjx/triple_cartpole_mjx.py
import os
import sys
import numpy as np
# This file lives in examples/mjx/, so two directory levels up is the repository
# root; appending it lets the example run straight from a source checkout.
current_dir = os.path.dirname(os.path.abspath(__file__))
grandparent_dir = os.path.dirname(os.path.dirname(current_dir))
sys.path.append(grandparent_dir)

# MJX is an optional extra: bail out with an actionable message when it is missing.
try:
    import mujoco
    import mujoco.mjx as mjx
except ImportError:
    print("MuJoCo MJX is not installed. Install with: pip install openscvx[mjx]", file=sys.stderr)
    raise SystemExit(1)
import openscvx as ox
from openscvx import ByofSpec, Problem
from openscvx.integrations import mjx_byof
# Link lengths (m). They are interpolated into the MJCF capsule/body positions
# below and reused by the forward-kinematics helper ``fk_joints``.
L1, L2, L3 = 0.5, 0.4, 0.3  # link lengths (m)
# MJCF: a cart on a slider carrying three serially nested hinge links; because
# each link body is a child of the previous one, hinge2/hinge3 angles are
# relative to their parent link.
TRIPLE_CARTPOLE_XML = f"""
<mujoco model="triple_cartpole">
<option gravity="0 0 -9.81" timestep="0.005" integrator="Euler"/>
<worldbody>
<body name="cart" pos="0 0 0">
<joint name="slider" type="slide" axis="1 0 0" limited="true" range="-10 10"/>
<geom name="cart_geom" type="box" size="0.25 0.15 0.1"
mass="2.0" rgba="0.35 0.35 0.75 1"/>
<!-- Link 1 — pivot at cart centre (z=0) -->
<body name="link1" pos="0 0 0">
<joint name="hinge1" type="hinge" axis="0 1 0" limited="false"/>
<geom name="pole1" type="capsule" fromto="0 0 0 0 0 {L1}"
size="0.06" mass="2.0" rgba="0.85 0.3 0.3 1"/>
<!-- Link 2 — pivot at tip of link 1 -->
<body name="link2" pos="0 0 {L1}">
<joint name="hinge2" type="hinge" axis="0 1 0" limited="false"/>
<geom name="pole2" type="capsule" fromto="0 0 0 0 0 {L2}"
size="0.035" mass="1.25" rgba="0.3 0.8 0.3 1"/>
<!-- Link 3 — pivot at tip of link 2 -->
<body name="link3" pos="0 0 {L2}">
<joint name="hinge3" type="hinge" axis="0 1 0" limited="false"/>
<geom name="pole3" type="capsule" fromto="0 0 0 0 0 {L3}"
size="0.03" mass="0.75" rgba="0.3 0.3 0.85 1"/>
</body>
</body>
</body>
</body>
</worldbody>
<actuator>
<!-- Single actuator on the cart slider; gear=60 scales ctrl ∈ [−3,3] to a force in [−180,180] N -->
<motor joint="slider" name="cart_force" gear="60"
ctrlrange="-3 3" ctrllimited="true"/>
</actuator>
</mujoco>
"""
mj_model = mujoco.MjModel.from_xml_string(TRIPLE_CARTPOLE_XML)
# The model needs no collision handling, so the contact pipeline is switched off
# before the model is transferred to MJX.
mj_model.opt.disableflags |= mujoco.mjtDisableBit.mjDSBL_CONTACT
mjx_model = mjx.put_model(mj_model)

# Dimensions as plain ints: qpos = [cart_x, θ1, θ2, θ3], qvel has the same size
# (slide + hinge joints only, no quaternions), and one cart-force actuator.
n_q, n_v, n_u = (int(dim) for dim in (mjx_model.nq, mjx_model.nv, mjx_model.nu))

n = 60  # SCP nodes; a denser grid resolves the motion near the unstable upright better
total_time = 2.5  # horizon (s)
# ── State / control definitions ───────────────────────────────────────────────
qpos = ox.State("qpos", shape=(n_q,))
# NOTE(review): the MJCF slider joint is limited to range="-10 10" (m), so the
# ±100 m cart bound below is looser than the joint range — confirm intent.
qpos.min = np.array([-100.0, -2 * np.pi, -2 * np.pi, -2 * np.pi])
qpos.max = np.array([100.0, 2 * np.pi, 2 * np.pi, 2 * np.pi])
# θ₁ = π with θ₂ = θ₃ = 0 (relative angles): the whole chain hangs straight down.
qpos.initial = np.array([0.0, np.pi, 0.0, 0.0])
qpos.final = [ox.Free(0.0), 0.0, 0.0, 0.0]  # cart x left free; all links upright

qvel = ox.State("qvel", shape=(n_v,))
qvel.min = np.array([-12.0, -12.0, -12.0, -12.0])
qvel.max = np.array([12.0, 12.0, 12.0, 12.0])
qvel.initial = np.zeros(n_v)  # start at rest
qvel.final = [0.0, 0.0, 0.0, 0.0]  # finish at rest

ctrl = ox.Control("ctrl", shape=(n_u,))
# Take the control box straight from the actuator ``ctrlrange`` declared in the
# MJCF (ctrllimited="true"), so the optimizer bounds and the model cannot drift
# apart. ``actuator_ctrlrange`` has one (lo, hi) row per actuator.
ctrl.min = mj_model.actuator_ctrlrange[:, 0].astype(np.float64)
ctrl.max = mj_model.actuator_ctrlrange[:, 1].astype(np.float64)
ctrl.guess = np.zeros((n, n_u))

states = [qpos, qvel]
controls = [ctrl]
# ── Dynamics: position kinematics symbolically, velocity via MJX ──────────────
# qpos' = qvel is only correct when nq == nv. Free/ball joints use quaternion
# coordinates (nq > nv) and would need a proper integration map, so fail loudly
# instead of silently building wrong kinematics if the model ever changes.
if n_q != n_v:
    raise ValueError(f"qpos' = qvel requires nq == nv, got nq={n_q}, nv={n_v}")
dynamics: dict = {"qpos": qvel}
# The velocity derivative is supplied by MJX forward dynamics through ``mjx_byof``.
byof: ByofSpec = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
# ── Constraints (CTCS on state / control bounds) ───────────────────────────────
# For every state, enforce its box bounds continuously in time: upper bound
# first, then lower bound.
constraints = [
    bound
    for state in states
    for bound in (ox.ctcs(state <= state.max), ox.ctcs(state.min <= state))
]
# ── Initial guess: linearly swing θ₁ from π → 0, others stay 0 ───────────────
# Column layout of qpos is [cart_x, θ₁, θ₂, θ₃]; only θ₁ moves in the guess.
th1_guess = np.linspace(np.pi, 0.0, n)
qpos.guess = np.hstack([np.zeros((n, 1)), th1_guess[:, None], np.zeros((n, 2))])
qvel.guess = np.zeros((n, n_v))
# Time runs from 0 to ``total_time``, bounded within [0, total_time].
time = ox.Time(initial=0.0, final=total_time, min=0.0, max=total_time)

# SCP weights — presumably the proximal term, the objective, and the
# virtual-control penalty respectively (confirm against openscvx docs);
# lam_cost = 0 puts no weight on an objective term.
problem = Problem(
    dynamics=dynamics,
    states=states,
    controls=controls,
    time=time,
    constraints=constraints,
    N=n,
    byof=byof,
    algorithm=dict(
        lam_prox=1.0,
        lam_cost=0.0,
        lam_vc=40.0,
        autotuner=ox.ConstantProximalWeight(),
    ),
    float_dtype="float64",
)
# ── Forward kinematics helpers ─────────────────────────────────────────────────
def fk_joints(q: np.ndarray) -> tuple[np.ndarray, ...]:
    """Compute world-frame positions of the cartpole chain (all in the XZ plane).

    ``q`` is ``[cart_x, θ₁, θ₂, θ₃]``; θ₂ and θ₃ are relative to the previous
    link, and an absolute angle of 0 points a link straight up (+z).

    Returns:
        ``(cart, h1, h2, h3, tip)`` as 3-vectors. ``h1`` is the same array object
        as ``cart`` because hinge 1 sits at the cart centre.
    """
    cart = np.array([float(q[0]), 0.0, 0.0])
    # Absolute link orientations: each relative hinge angle adds onto its parent's.
    link_angles = np.cumsum([float(q[1]), float(q[2]), float(q[3])])
    chain = [cart]
    for length, angle in zip((L1, L2, L3), link_angles):
        offset = np.array([length * np.sin(angle), 0.0, length * np.cos(angle)])
        chain.append(chain[-1] + offset)
    _, h2, h3, tip = chain
    return cart, cart, h2, h3, tip
def qpos_from_V_multishot(
V: np.ndarray,
*,
n_q: int,
n_v: int,
n_u: int,
t_nodes: np.ndarray,
) -> tuple[np.ndarray | None, np.ndarray | None]:
"""Unpack generalized coordinates from the SCP multi-shoot matrix ``V``.
``V`` has shape ``((N - 1) * i4, n_substeps)`` where each segment occupies ``i4``
consecutive rows (state + sensitivities). The first ``n_x = n_q + n_v`` rows of
each segment block are the propagated state at that substep.
Returns ``(qpos, t_sample)`` with one row per substep (duplicate boundary samples
between segments are skipped), or ``(None, None)`` if ``V`` cannot be decoded.
"""
if V.size == 0:
return None, None
n_x = n_q + n_v
i4 = n_x + n_x * n_x + 2 * n_x * n_u
n_rows, n_sub = V.shape
if i4 <= 0 or n_rows % i4 != 0 or n_sub < 1:
return None, None
n_seg = n_rows // i4
if n_seg != len(t_nodes) - 1:
return None, None
q_rows: list[np.ndarray] = []
t_rows: list[float] = []
for seg in range(n_seg):
t0 = float(t_nodes[seg])
t1 = float(t_nodes[seg + 1])
j0 = 0 if seg == 0 else 1 # skip duplicated state at segment joints
for j in range(j0, n_sub):
alpha = j / (n_sub - 1) if n_sub > 1 else 0.0
t_s = (1.0 - alpha) * t0 + alpha * t1
row0 = seg * i4
x_vec = np.asarray(V[row0 : row0 + n_x, j], dtype=np.float64).ravel()
q_rows.append(x_vec[:n_q])
t_rows.append(t_s)
if not q_rows:
return None, None
return np.stack(q_rows, axis=0), np.asarray(t_rows, dtype=np.float64)
def _consecutive_segments(points: np.ndarray) -> np.ndarray:
    """Convert an ``(M, 3)`` polyline into ``(M - 1, 2, 3)`` [start, end] segment pairs."""
    pts = np.asarray(points, dtype=np.float32)
    return np.stack([pts[:-1], pts[1:]], axis=1)


def _cart_and_tip_paths(fk_list: list) -> tuple[np.ndarray, np.ndarray]:
    """Extract float32 ``(M, 3)`` cart (on the rail) and tip positions from ``fk_joints`` tuples."""
    cart = np.array([[float(fk[0][0]), 0.0, 0.0] for fk in fk_list], dtype=np.float32)
    tip = np.array([fk[4] for fk in fk_list], dtype=np.float32)
    return cart, tip


def visualize(results) -> None:
    """Animate the triple-link cartpole in a Viser 3D scene.

    Poses come from the multi-shoot integrated path decoded (via
    ``qpos_from_V_multishot``) from the last entry of
    ``results.discretization_history``. If that is unavailable or cannot be
    decoded, nodal ``qpos`` is linearly interpolated onto
    ``results.trajectory["time"]`` instead.

    The scene shows the rail, an upright-tip target marker, static cart/tip paths,
    an animated rig, a velocity-coloured tip trail, and sidebar plots of the joint
    angles and the normalised cart control. This call blocks forever serving the UI.
    """
    import plotly.graph_objects as go
    import viser

    from openscvx.plotting.viser import (
        add_animated_trail,
        add_animation_controls,
        compute_velocity_colors,
    )
    from openscvx.plotting.viser.plotly_integration import add_animated_plotly_vline

    # ── Extract trajectory data ────────────────────────────────────────────────
    t_vec = results.trajectory["time"].flatten()  # playback clock
    u_traj = results.trajectory["ctrl"]
    q_nodes = results.nodes["qpos"]
    t_nodes = results.nodes.get("time", None)
    if t_nodes is None:
        t_nodes = np.linspace(float(t_vec[0]), float(t_vec[-1]), len(q_nodes))
    else:
        t_nodes = np.asarray(t_nodes).flatten()
    N = len(t_vec)

    # Multi-shoot integrated trajectory from ``V`` (last SCP discretization matrix).
    # Explicit ``is not None`` / ``len`` checks (instead of ``or []``) also work if
    # the history happens to be an ndarray.
    dh = getattr(results, "discretization_history", None)
    V_multishot = dh[-1] if dh is not None and len(dh) > 0 else None
    q_ms_v, t_ms_v = None, None
    if V_multishot is not None:
        q_ms_v, t_ms_v = qpos_from_V_multishot(
            np.asarray(V_multishot, dtype=np.float64),
            n_q=n_q,
            n_v=n_v,
            n_u=n_u,
            t_nodes=t_nodes,
        )

    if q_ms_v is not None and t_ms_v is not None:
        q_anim, t_anim = q_ms_v, t_ms_v
    else:
        # Fallback: interpolate nodal qpos onto the post-process time grid.
        q_anim = np.column_stack(
            [np.interp(t_vec, t_nodes, q_nodes[:, j]) for j in range(q_nodes.shape[1])]
        )
        t_anim = t_vec
    fk_anim = [fk_joints(q) for q in q_anim]

    # Nearest animation sample for every playback frame, computed once and shared
    # by the tip trail and the per-frame rig update.
    frame_to_sample = [int(np.argmin(np.abs(t_anim - float(t)))) for t in t_vec]

    # Tip trail samples (pose at each playback time)
    tip_pos = np.zeros((N, 3))
    for i, k in enumerate(frame_to_sample):
        tip_pos[i] = fk_anim[k][4]

    # ── Viser server ───────────────────────────────────────────────────────────
    server = viser.ViserServer()
    server.scene.set_up_direction("+z")

    # Static rail
    rail = np.array([[[-4.0, 0.0, 0.0], [4.0, 0.0, 0.0]]], dtype=np.float32)
    server.scene.add_line_segments(
        "/rail",
        points=rail,
        colors=np.array([80, 80, 80], dtype=np.uint8),
        line_width=4.0,
    )

    # Upright target marker at the top of the fully extended stack
    upright_tip = np.array([0.0, 0.0, L1 + L2 + L3])
    server.scene.add_icosphere(
        "/target",
        radius=0.05,
        color=(50, 220, 100),
        position=tuple(float(v) for v in upright_tip),
    )

    # Static paths: dense multi-shoot samples when available, else chords between nodes.
    n_nodes = len(q_nodes)
    if q_ms_v is not None and len(q_ms_v) >= 2:
        cart_path, tip_path = _cart_and_tip_paths(fk_anim)  # fk_anim is FK of q_ms_v here
    elif n_nodes >= 2:
        cart_path, tip_path = _cart_and_tip_paths([fk_joints(q) for q in q_nodes])
    else:
        cart_path = tip_path = None

    if cart_path is not None:
        server.scene.add_line_segments(
            "/multishot/cart_path",
            points=_consecutive_segments(cart_path),
            colors=np.array([90, 90, 230], dtype=np.uint8),
            line_width=3.5,
        )
        server.scene.add_line_segments(
            "/multishot/tip_path",
            points=_consecutive_segments(tip_path),
            colors=np.array([230, 170, 40], dtype=np.uint8),
            line_width=3.5,
        )

    # ── Animated multishot rig (aligned to same playback clock) ───────────────
    cart0, h1_0, h2_0, h3_0, _ = fk_anim[0]
    ms_cart_handle = server.scene.add_box(
        "/multishot/cart",
        dimensions=(0.42, 0.24, 0.16),
        position=tuple(float(v) for v in cart0),
        color=(55, 150, 255),
    )

    def _multishot_link_segments(i: int) -> np.ndarray:
        """Return the three link segments ``[[h1, h2], [h2, h3], [h3, tip]]`` for sample ``i``."""
        _, h1, h2, h3, tip = fk_anim[i]
        return np.array([[h1, h2], [h2, h3], [h3, tip]], dtype=np.float32)

    ms_link_colors = np.array(
        [
            [[120, 190, 255], [120, 190, 255]],
            [[95, 175, 245], [95, 175, 245]],
            [[70, 160, 235], [70, 160, 235]],
        ],
        dtype=np.uint8,
    )
    ms_link_handle = server.scene.add_line_segments(
        "/multishot/links",
        points=_multishot_link_segments(0),
        colors=ms_link_colors,
        line_width=5.0,
    )

    # Each joint marker starts at its own hinge (previously all three started at h1).
    ms_joint_handles = []
    for jname, pos in zip(("/multishot/j1", "/multishot/j2", "/multishot/j3"), (h1_0, h2_0, h3_0)):
        ms_joint_handles.append(
            server.scene.add_icosphere(
                jname,
                radius=0.03,
                color=(80, 170, 245),
                position=tuple(float(v) for v in pos),
            )
        )

    # ── Animated tip trail ─────────────────────────────────────────────────────
    tip_colors = compute_velocity_colors(tip_pos)
    _, update_trail = add_animated_trail(server, tip_pos, tip_colors, point_size=0.02)

    # ── Sidebar: joint angles ──────────────────────────────────────────────────
    fig_angles = go.Figure()
    angle_names = ["θ₁ (link 1)", "θ₂ (link 2)", "θ₃ (link 3)"]
    angle_colors = ["royalblue", "darkorange", "green"]
    for k in range(3):
        fig_angles.add_trace(
            go.Scatter(
                x=t_anim.tolist(),
                y=np.rad2deg(q_anim[:, k + 1]).tolist(),
                mode="lines",
                name=angle_names[k],
                line={"color": angle_colors[k], "width": 2},
            )
        )
    fig_angles.add_hline(y=0, line_dash="dash", line_color="gray", annotation_text="Upright")
    fig_angles.update_layout(
        title="Joint angles",
        xaxis_title="Time (s)",
        yaxis_title="Angle (deg)",
        legend={"orientation": "h"},
        margin={"l": 40, "r": 10, "t": 40, "b": 40},
    )

    # ── Sidebar: control force ─────────────────────────────────────────────────
    fig_ctrl = go.Figure()
    fig_ctrl.add_trace(
        go.Scatter(
            x=t_vec.tolist(),
            y=u_traj[:, 0].tolist(),
            mode="lines",
            name="Cart force",
            line={"color": "crimson", "width": 2},
        )
    )
    fig_ctrl.add_hline(y=0, line_dash="dash", line_color="gray")
    fig_ctrl.update_layout(
        title="Cart control (normalised)",
        xaxis_title="Time (s)",
        yaxis_title="u",
        margin={"l": 40, "r": 10, "t": 40, "b": 40},
    )

    with server.gui.add_folder("Plots"):
        _, update_angles = add_animated_plotly_vline(server, fig_angles, t_vec, folder_name=None)
        _, update_ctrl = add_animated_plotly_vline(server, fig_ctrl, t_vec, folder_name=None)

    def update_multishot_scene(frame_idx: int) -> None:
        """Move the rig to the animation sample nearest to playback frame ``frame_idx``."""
        k = frame_to_sample[frame_idx]
        cart, h1, h2, h3, _ = fk_anim[k]
        ms_cart_handle.position = (float(cart[0]), 0.0, 0.0)
        ms_link_handle.points = _multishot_link_segments(k)
        for handle, pos in zip(ms_joint_handles, (h1, h2, h3)):
            handle.position = tuple(float(v) for v in pos)

    # ── Animation controls ─────────────────────────────────────────────────────
    add_animation_controls(
        server,
        t_vec,
        [
            update_multishot_scene,
            update_trail,
            update_angles,
            update_ctrl,
        ],
    )

    print("Viser running — open http://localhost:8080 in your browser.")
    server.sleep_forever()
if __name__ == "__main__":
    # Banner with the problem dimensions and link geometry.
    for line in (
        "Triple-link cartpole swing-up — MuJoCo MJX + OpenSCvx",
        "=" * 60,
        f"nq={n_q}, nv={n_v}, nu={n_u}, N={n}",
        f"Links: L1={L1} m, L2={L2} m, L3={L3} m (total {L1 + L2 + L3} m)",
        "",
    ):
        print(line)

    # Build, solve, and post-process the SCP problem.
    problem.initialize()
    problem.solve()
    results = problem.post_process()

    # Report the state at the last node.
    final_q = results.nodes["qpos"][-1]
    final_qd = results.nodes["qvel"][-1]
    print()
    print(f"Final joint angles [deg]: {np.rad2deg(final_q[1:])}")
    print(f"Final cart position: {final_q[0]:.4f} m")
    print(f"Final joint rates [rad/s]:{final_qd[1:]}")

    from openscvx.plotting import plot_controls, plot_states

    for plot_fn in (plot_states, plot_controls):
        plot_fn(results).show()
    visualize(results)