Triple Cartpole 3D MJX
3D triple-link cartpole swing-up using MuJoCo MJX dynamics.
Extends triple_cartpole_mjx.py to 3D:
- The cart slides on a horizontal plane (two slide joints, X and Y).
- Each pendulum link has a 2-DOF universal joint — two hinges with perpendicular axes (around the parent's X then Y) — giving spherical pendulum motion without the quaternion bookkeeping of a ball joint.
The resulting system has
nq = nv = 8 (cart_x, cart_y, α₁, β₁, α₂, β₂, α₃, β₃)
nu = 2 (cart force in X and Y)
The optimizer drives all three links from a near-hanging configuration (α₁=π with a small out-of-plane tilt β₁=π/64, every other angle 0, cart starting at (-0.5, 1.0)) to the unstable upright equilibrium (all angles 0) using only horizontal forces on the cart.
Link convention
Each link's geom extends along its own body's +Z axis. With both of a link's hinge angles equal to 0, that axis coincides with the parent's +Z axis. For the bottom link, whose parent (the cart) only translates, this means upright at angle 0; the link hangs straight down at α=π.
Both 2-DOF parameterisations (X-then-Y and Y-then-X intrinsic rotations) have a coordinate singularity when the second hinge angle reaches ±π/2. A swing-up that stays close to a vertical plane keeps every β near 0 (the only out-of-plane term is the small initial tilt β₁=π/64), far from ±π/2, so the parameterisation is well-conditioned for this problem.
State: qpos = [cart_x, cart_y, α₁, β₁, α₂, β₂, α₃, β₃], qvel = q̇
Control: ctrl = [F_x, F_y] (normalised to [-1, 1]; gear=60 → ±60 N per axis)
Link lengths: L₁=0.5 m, L₂=0.4 m, L₃=0.3 m.
Requires: pip install openscvx[mjx]
File: examples/mjx/triple_cartpole_3d_mjx.py
import os
import sys
import numpy as np
current_dir = os.path.dirname(os.path.abspath(__file__))
# This file lives in examples/mjx/, so two levels up is the repository root.
# Appending it to sys.path presumably lets the example import the in-tree
# `openscvx` package when run straight from a source checkout.
grandparent_dir = os.path.dirname(os.path.dirname(current_dir))
sys.path.append(grandparent_dir)

try:
    import mujoco
    import mujoco.mjx as mjx
except ImportError:
    # MJX is an optional extra: fail fast with an actionable install hint.
    print(
        "MuJoCo MJX is not installed. Install with: pip install openscvx[mjx]",
        file=sys.stderr,
    )
    sys.exit(1)
import openscvx as ox
from openscvx import ByofSpec, Problem
from openscvx.integrations import mjx_byof
# Link lengths (m). They are interpolated into the MJCF string below and are
# reused by fk_joints() and the visualisation.
L1, L2, L3 = 0.5, 0.4, 0.3 # link lengths (m)
# MJCF model (an f-string: {L1}, {L2}, {L3} are substituted here).
# Body tree: cart (slide_x, slide_y) -> link1 -> link2 -> link3. Every link
# body carries two hinges, about its local X axis and then its local Y axis,
# plus a capsule geom along its local +Z. The two motors act only on the cart
# slide joints. The slide joints are limited to [-4, 4] m and the motors are
# ctrllimited to [-1, 1] with gear=60 (per MuJoCo semantics, ctrl is clipped
# to ctrlrange before the gear is applied).
TRIPLE_CARTPOLE_3D_XML = f"""
<mujoco model="triple_cartpole_3d">
<option gravity="0 0 -9.81" timestep="0.005" integrator="Euler"/>
<worldbody>
<body name="cart" pos="0 0 0">
<!-- Cart slides freely on the horizontal plane -->
<joint name="slide_x" type="slide" axis="1 0 0" limited="true" range="-4 4"/>
<joint name="slide_y" type="slide" axis="0 1 0" limited="true" range="-4 4"/>
<geom name="cart_geom" type="box" size="0.25 0.25 0.1"
mass="2.0" rgba="0.35 0.35 0.75 1"/>
<!-- Link 1 — 2-DOF universal joint at cart centre -->
<body name="link1" pos="0 0 0">
<joint name="hinge1_x" type="hinge" axis="1 0 0" limited="false"/>
<joint name="hinge1_y" type="hinge" axis="0 1 0" limited="false"/>
<geom name="pole1" type="capsule" fromto="0 0 0 0 0 {L1}"
size="0.04" mass="0.5" rgba="0.85 0.3 0.3 1"/>
<!-- Link 2 — 2-DOF universal joint at tip of link 1 -->
<body name="link2" pos="0 0 {L1}">
<joint name="hinge2_x" type="hinge" axis="1 0 0" limited="false"/>
<joint name="hinge2_y" type="hinge" axis="0 1 0" limited="false"/>
<geom name="pole2" type="capsule" fromto="0 0 0 0 0 {L2}"
size="0.035" mass="0.4" rgba="0.3 0.8 0.3 1"/>
<!-- Link 3 — 2-DOF universal joint at tip of link 2 -->
<body name="link3" pos="0 0 {L2}">
<joint name="hinge3_x" type="hinge" axis="1 0 0" limited="false"/>
<joint name="hinge3_y" type="hinge" axis="0 1 0" limited="false"/>
<geom name="pole3" type="capsule" fromto="0 0 0 0 0 {L3}"
size="0.03" mass="0.3" rgba="0.3 0.3 0.85 1"/>
</body>
</body>
</body>
</body>
</worldbody>
<actuator>
<!-- Two actuators on the cart slide joints; gear scales ctrl ∈ [−1,1] to N -->
<motor joint="slide_x" name="cart_force_x" gear="60"
ctrlrange="-1 1" ctrllimited="true"/>
<motor joint="slide_y" name="cart_force_y" gear="60"
ctrlrange="-1 1" ctrllimited="true"/>
</actuator>
</mujoco>
"""
# Compile the MJCF on the CPU, then switch off contact handling. Only the
# contact disable bit is set here; per the MuJoCo docs, joint limits have their
# own mjDSBL_LIMIT bit, so the slide ranges presumably still apply.
# Finally, copy the model into an MJX model for the JAX-based dynamics.
mj_model = mujoco.MjModel.from_xml_string(TRIPLE_CARTPOLE_3D_XML)
mj_model.opt.disableflags |= mujoco.mjtDisableBit.mjDSBL_CONTACT
mjx_model = mjx.put_model(mj_model)
n_q = int(mjx_model.nq)  # 8: cart_xy + 3 × (αᵢ, βᵢ)
n_v = int(mjx_model.nv)  # 8 (nq == nv)
n_u = int(mjx_model.nu)  # 2: F_x, F_y
n = 60  # number of discretisation nodes (passed to Problem as N)
total_time = 4.0  # fixed final time (s)

# Convenience indices into qpos: cart position, then (α, β) for each link.
IDX_CART_X, IDX_CART_Y = 0, 1
IDX_A1, IDX_B1 = 2, 3
IDX_A2, IDX_B2 = 4, 5
IDX_A3, IDX_B3 = 6, 7

# ── State / control definitions ───────────────────────────────────────────────
qpos = ox.State("qpos", shape=(n_q,))
# Cart bounds are read straight from the slide joints' `range` in the MJCF
# (the joints are limited="true"). The planner therefore never drives the cart
# into the joint-limit region that MuJoCo enforces with a limit constraint.
# The hinge angles get ±2π, which leaves room for a full wrap.
_cart_range = np.array(
    [mj_model.joint(name).range for name in ("slide_x", "slide_y")]
)  # shape (2, 2): one (lo, hi) row per slide joint
_angle_bound = 2 * np.pi
qpos.min = np.concatenate([_cart_range[:, 0], np.full(n_q - 2, -_angle_bound)])
qpos.max = np.concatenate([_cart_range[:, 1], np.full(n_q - 2, _angle_bound)])
# Near-hanging start: cart at (-0.5, 1.0), link 1 hanging (α₁ = π) with a small
# out-of-plane tilt β₁ = π/64; every other angle is 0.
q_initial = np.array([-0.5, 1.0, np.pi, np.pi / 64, 0.0, 0.0, 0.0, 0.0])
qpos.initial = q_initial
# Upright: cart free, all angles 0.
qpos.final = [
    ox.Free(0.0),
    ox.Free(0.0),  # cart x, y free
    0.0,
    0.0,  # link 1 upright
    0.0,
    0.0,  # link 2 upright
    0.0,
    0.0,  # link 3 upright
]

qvel = ox.State("qvel", shape=(n_v,))
qvel.min = np.full(n_v, -12.0)
qvel.max = np.full(n_v, 12.0)
qvel.initial = np.zeros(n_v)
qvel.final = [0.0] * n_v

ctrl = ox.Control("ctrl", shape=(n_u,))
# The motors are ctrllimited in the MJCF, so MuJoCo clips ctrl to ctrlrange
# ([-1, 1]) before applying gear=60. Wider optimiser bounds would buy no extra
# force; they would only add a saturated, zero-gradient region. Match the
# actuators' own range instead.
ctrl.min = mj_model.actuator_ctrlrange[:, 0].copy()
ctrl.max = mj_model.actuator_ctrlrange[:, 1].copy()
ctrl.guess = np.zeros((n, n_u))

states = [qpos, qvel]
controls = [ctrl]

# ── Dynamics: position kinematics symbolically, velocity via MJX ──────────────
dynamics: dict = {"qpos": qvel}  # q̇ = qvel is valid because nq == nv
byof: ByofSpec = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}

# ── Constraints ───────────────────────────────────────────────────────────────
# No extra path constraints. Only the state/control min/max bounds above are
# imposed (presumably enforced by OpenSCvx via CTCS — confirm).
constraints = []

# ── Initial guess: straight line from the initial state to upright ───────────
# The cart is held at its initial position (its final position is free). α₁
# ramps π → 0 and β₁ ramps π/64 → 0, so node 0 of the guess matches qpos.initial.
q_upright = q_initial.copy()
q_upright[IDX_A1:] = 0.0
qpos_guess = np.linspace(q_initial, q_upright, n)  # shape (n, n_q)
qpos.guess = qpos_guess
qvel.guess = np.zeros((n, n_v))

time = ox.Time(
    initial=0.0,
    final=total_time,
    min=0.0,
    max=total_time,
)

problem = Problem(
    dynamics=dynamics,
    states=states,
    controls=controls,
    time=time,
    constraints=constraints,
    N=n,
    byof=byof,
    algorithm={
        "lam_prox": 1e-1,
        "lam_cost": 0e0,
        "lam_vc": 4e0,
        "autotuner": ox.ConstantProximalWeight(),
    },
    float_dtype="float64",
)
# ── Forward kinematics helpers ────────────────────────────────────────────────
def _R_x(a: float) -> np.ndarray:
c, s = np.cos(a), np.sin(a)
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]], dtype=np.float64)
def _R_y(b: float) -> np.ndarray:
c, s = np.cos(b), np.sin(b)
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]], dtype=np.float64)
def fk_joints(q: np.ndarray) -> tuple[np.ndarray, ...]:
    """Forward kinematics: world positions of the cart, the three joints and the tip.

    Orientations compose intrinsically down the chain. Link i's frame is

        R_i = R_x(α₁) R_y(β₁) … R_x(αᵢ) R_y(βᵢ),

    and link i (length Lᵢ) points along its body +Z axis, so the next point in
    the chain is ``previous + Lᵢ · (R_i @ ẑ)``.

    Args:
        q: generalized coordinates ordered [cart_x, cart_y, α₁, β₁, α₂, β₂,
            α₃, β₃] (read through the IDX_* constants).

    Returns:
        ``(cart, h1, h2, h3, tip)``, each a length-3 array. The cart sits at
        z = 0, and ``h1`` is the very same array as ``cart`` because the first
        universal joint is at the cart centre.
    """
    cart = np.array([float(q[IDX_CART_X]), float(q[IDX_CART_Y]), 0.0])
    up = np.array([0.0, 0.0, 1.0])
    chain = (
        (L1, IDX_A1, IDX_B1),
        (L2, IDX_A2, IDX_B2),
        (L3, IDX_A3, IDX_B3),
    )

    orientation = np.eye(3)
    points = [cart]
    for length, i_alpha, i_beta in chain:
        # Each link's two hinges rotate the frame inherited from its parent.
        orientation = orientation @ _R_x(float(q[i_alpha])) @ _R_y(float(q[i_beta]))
        points.append(points[-1] + length * (orientation @ up))

    _, h2, h3, tip = points
    return cart, cart, h2, h3, tip
def qpos_from_V_multishot(
V: np.ndarray,
*,
n_q: int,
n_v: int,
n_u: int,
t_nodes: np.ndarray,
) -> tuple[np.ndarray | None, np.ndarray | None]:
"""Unpack generalized coordinates from the SCP multi-shoot matrix ``V``.
Identical structure to the 2D version — generic in (n_q, n_v, n_u).
"""
if V.size == 0:
return None, None
n_x = n_q + n_v
i4 = n_x + n_x * n_x + 2 * n_x * n_u
n_rows, n_sub = V.shape
if i4 <= 0 or n_rows % i4 != 0 or n_sub < 1:
return None, None
n_seg = n_rows // i4
if n_seg != len(t_nodes) - 1:
return None, None
q_rows: list[np.ndarray] = []
t_rows: list[float] = []
for seg in range(n_seg):
t0 = float(t_nodes[seg])
t1 = float(t_nodes[seg + 1])
j0 = 0 if seg == 0 else 1
for j in range(j0, n_sub):
alpha = j / (n_sub - 1) if n_sub > 1 else 0.0
t_s = (1.0 - alpha) * t0 + alpha * t1
row0 = seg * i4
x_vec = np.asarray(V[row0 : row0 + n_x, j], dtype=np.float64).ravel()
q_rows.append(x_vec[:n_q])
t_rows.append(t_s)
if not q_rows:
return None, None
return np.stack(q_rows, axis=0), np.asarray(t_rows, dtype=np.float64)
def visualize(results) -> None:
    """Animate the 3D triple-link cartpole in a Viser scene.

    The animated rig follows the multi-shoot integrated path recovered from the
    last entry of ``results.discretization_history`` via
    ``qpos_from_V_multishot``. If that path is unavailable, the ``qpos`` node
    values are linearly interpolated onto the trajectory time grid instead.

    The scene contains a grid and ±4 m outline for the cart plane, an upright
    target marker, static cart/tip paths, an animated cart box with link
    segments and joint markers, and an animated tip trail. Sidebar Plotly
    plots show joint angles, controls and the cart XY path. This function
    blocks forever serving the Viser UI.

    Args:
        results: output of ``Problem.post_process()``. Reads
            ``trajectory["time"]``, ``trajectory["ctrl"]``, ``nodes["qpos"]``,
            optionally ``nodes["time"]``, and ``discretization_history``.
    """
    import plotly.graph_objects as go
    import viser
    from openscvx.plotting.viser import (
        add_animated_trail,
        add_animation_controls,
        compute_velocity_colors,
    )
    from openscvx.plotting.viser.plotly_integration import add_animated_plotly_vline

    def _segments(points) -> np.ndarray:
        """Turn an (M, 3) polyline into (M - 1, 2, 3) float32 segment endpoints."""
        pts = np.asarray(points, dtype=np.float32)
        return np.stack([pts[:-1], pts[1:]], axis=1)

    # ── Extract trajectory data ────────────────────────────────────────────────
    t_vec = results.trajectory["time"].flatten()  # playback clock
    u_traj = results.trajectory["ctrl"]
    q_nodes = results.nodes["qpos"]
    t_nodes = results.nodes.get("time", None)
    if t_nodes is None:
        t_nodes = np.linspace(float(t_vec[0]), float(t_vec[-1]), len(q_nodes))
    else:
        t_nodes = np.asarray(t_nodes).flatten()
    N = len(t_vec)
    n_nodes = len(q_nodes)

    # Multishot path from V (last SCP discretization matrix).
    _dh = getattr(results, "discretization_history", None) or []
    V_multishot = _dh[-1] if len(_dh) > 0 else None
    q_ms_v, t_ms_v = (
        qpos_from_V_multishot(
            np.asarray(V_multishot, dtype=np.float64),
            n_q=n_q,
            n_v=n_v,
            n_u=n_u,
            t_nodes=t_nodes,
        )
        if V_multishot is not None
        else (None, None)
    )

    # Samples that drive the rig and the angle plot: the multishot path when it
    # is available, otherwise node values interpolated onto the playback clock.
    if q_ms_v is not None and t_ms_v is not None:
        q_angle, t_angle = q_ms_v, t_ms_v
    else:
        q_angle = np.column_stack(
            [np.interp(t_vec, t_nodes, q_nodes[:, j]) for j in range(q_nodes.shape[1])]
        )
        t_angle = t_vec
    fk_multishot_anim = [fk_joints(q) for q in q_angle]
    t_multishot_lookup = t_angle

    def _nearest_sample(t: float) -> int:
        """Index of the animation sample closest in time to *t*."""
        idx = int(np.argmin(np.abs(t_multishot_lookup - t)))
        return int(np.clip(idx, 0, len(fk_multishot_anim) - 1))

    tip_pos = np.zeros((N, 3))
    for i, t in enumerate(t_vec):
        tip_pos[i] = fk_multishot_anim[_nearest_sample(float(t))][4]

    # ── Viser server ───────────────────────────────────────────────────────────
    server = viser.ViserServer()
    server.scene.set_up_direction("+z")

    # Cart-motion plane: a faint horizontal grid plus an outline of the ±4 m
    # slide range used by the MJCF slide joints.
    server.scene.add_grid(
        "/plane",
        width=10.0,
        height=10.0,
        cell_size=0.5,
        position=(0.0, 0.0, -0.105),
    )
    corners = np.array(
        [[-4.0, -4.0, 0.0], [4.0, -4.0, 0.0], [4.0, 4.0, 0.0], [-4.0, 4.0, 0.0]],
        dtype=np.float32,
    )
    server.scene.add_line_segments(
        "/plane/outline",
        points=_segments(np.vstack([corners, corners[:1]])),  # closed loop
        colors=np.array([110, 110, 110], dtype=np.uint8),
        line_width=2.0,
    )

    # Upright target marker
    upright_tip = np.array([0.0, 0.0, L1 + L2 + L3])
    server.scene.add_icosphere(
        "/target",
        radius=0.05,
        color=(50, 220, 100),
        position=tuple(float(v) for v in upright_tip),
    )

    # ── Static polylines (cart path on plane + tip path in 3D) ────────────────
    if q_ms_v is not None and len(q_ms_v) >= 2:
        # Same samples as the animation: reuse the FK instead of recomputing it.
        path_fk = fk_multishot_anim
    elif n_nodes >= 2:
        path_fk = [fk_joints(q) for q in q_nodes]
    else:
        path_fk = None
    if path_fk is not None:
        cart_path = np.array(
            [[float(fk[0][0]), float(fk[0][1]), 0.0] for fk in path_fk],
            dtype=np.float32,
        )
        tip_path = np.array([fk[4] for fk in path_fk], dtype=np.float32)
        server.scene.add_line_segments(
            "/multishot/cart_path",
            points=_segments(cart_path),
            colors=np.array([90, 90, 230], dtype=np.uint8),
            line_width=3.5,
        )
        server.scene.add_line_segments(
            "/multishot/tip_path",
            points=_segments(tip_path),
            colors=np.array([230, 170, 40], dtype=np.uint8),
            line_width=3.5,
        )

    # ── Multishot animated rig ─────────────────────────────────────────────────
    cart0, h1_0, h2_0, h3_0, _ = fk_multishot_anim[0]
    ms_cart_handle = server.scene.add_box(
        "/multishot/cart",
        dimensions=(0.42, 0.42, 0.16),
        position=tuple(float(v) for v in cart0),
        color=(55, 150, 255),
    )

    def _multishot_link_segments(i: int) -> np.ndarray:
        """The three link segments (h1→h2, h2→h3, h3→tip) of animation sample *i*."""
        _, h1, h2, h3, tip = fk_multishot_anim[i]
        return np.array([[h1, h2], [h2, h3], [h3, tip]], dtype=np.float32)

    ms_link_colors = np.array(
        [
            [[120, 190, 255], [120, 190, 255]],
            [[95, 175, 245], [95, 175, 245]],
            [[70, 160, 235], [70, 160, 235]],
        ],
        dtype=np.uint8,
    )
    ms_link_handle = server.scene.add_line_segments(
        "/multishot/links",
        points=_multishot_link_segments(0),
        colors=ms_link_colors,
        line_width=5.0,
    )
    # One marker per universal joint, each starting at its own frame-0 position
    # (previously all three started at h1).
    ms_joint_handles = [
        server.scene.add_icosphere(
            jname,
            radius=0.03,
            color=(80, 170, 245),
            position=tuple(float(v) for v in pos),
        )
        for jname, pos in zip(
            ("/multishot/j1", "/multishot/j2", "/multishot/j3"), (h1_0, h2_0, h3_0)
        )
    ]

    # ── Animated tip trail ────────────────────────────────────────────────────
    tip_colors = compute_velocity_colors(tip_pos)
    _, update_trail = add_animated_trail(server, tip_pos, tip_colors, point_size=0.02)

    # ── Sidebar: joint angles (all 6) ─────────────────────────────────────────
    fig_angles = go.Figure()
    angle_specs = [
        ("α₁ (link 1, X)", IDX_A1, "royalblue"),
        ("β₁ (link 1, Y)", IDX_B1, "deepskyblue"),
        ("α₂ (link 2, X)", IDX_A2, "darkorange"),
        ("β₂ (link 2, Y)", IDX_B2, "gold"),
        ("α₃ (link 3, X)", IDX_A3, "green"),
        ("β₃ (link 3, Y)", IDX_B3, "limegreen"),
    ]
    for name, idx, col in angle_specs:
        fig_angles.add_trace(
            go.Scatter(
                x=t_angle.tolist(),
                y=np.rad2deg(q_angle[:, idx]).tolist(),
                mode="lines",
                name=name,
                line={"color": col, "width": 2},
            )
        )
    fig_angles.add_hline(y=0, line_dash="dash", line_color="gray", annotation_text="Upright")
    fig_angles.update_layout(
        title="Joint angles",
        xaxis_title="Time (s)",
        yaxis_title="Angle (deg)",
        legend={"orientation": "h"},
        margin={"l": 40, "r": 10, "t": 40, "b": 40},
    )

    # ── Sidebar: cart position (XY plot) and control forces (2 traces) ───────
    fig_cart = go.Figure()
    fig_cart.add_trace(
        go.Scatter(
            x=q_angle[:, IDX_CART_X].tolist(),
            y=q_angle[:, IDX_CART_Y].tolist(),
            mode="lines",
            name="Cart trajectory (multi-shoot)",
            line={"color": "royalblue", "width": 2},
        )
    )
    fig_cart.update_layout(
        title="Cart path on plane",
        xaxis_title="x (m)",
        yaxis_title="y (m)",
        yaxis={"scaleanchor": "x", "scaleratio": 1.0},
        margin={"l": 40, "r": 10, "t": 40, "b": 40},
    )

    fig_ctrl = go.Figure()
    fig_ctrl.add_trace(
        go.Scatter(
            x=t_vec.tolist(),
            y=u_traj[:, 0].tolist(),
            mode="lines",
            name="F_x",
            line={"color": "crimson", "width": 2},
        )
    )
    fig_ctrl.add_trace(
        go.Scatter(
            x=t_vec.tolist(),
            y=u_traj[:, 1].tolist(),
            mode="lines",
            name="F_y",
            line={"color": "darkmagenta", "width": 2},
        )
    )
    fig_ctrl.add_hline(y=0, line_dash="dash", line_color="gray")
    fig_ctrl.update_layout(
        title="Cart controls (normalised)",
        xaxis_title="Time (s)",
        yaxis_title="u",
        legend={"orientation": "h"},
        margin={"l": 40, "r": 10, "t": 40, "b": 40},
    )

    with server.gui.add_folder("Plots"):
        _, update_angles = add_animated_plotly_vline(server, fig_angles, t_vec, folder_name=None)
        _, update_ctrl = add_animated_plotly_vline(server, fig_ctrl, t_vec, folder_name=None)
        # cart-XY plot is static (no time slider needed)
        server.gui.add_plotly(fig_cart)

    def update_multishot_scene(frame_idx: int) -> None:
        """Move the cart box, links and joint markers to playback frame *frame_idx*."""
        ms_i = _nearest_sample(float(t_vec[frame_idx]))
        cart, h1, h2, h3, _ = fk_multishot_anim[ms_i]
        ms_cart_handle.position = (float(cart[0]), float(cart[1]), 0.0)
        ms_link_handle.points = _multishot_link_segments(ms_i)
        for handle, pos in zip(ms_joint_handles, (h1, h2, h3)):
            handle.position = tuple(float(v) for v in pos)

    add_animation_controls(
        server,
        t_vec,
        [
            update_multishot_scene,
            update_trail,
            update_angles,
            update_ctrl,
        ],
    )

    print("Viser running — open http://localhost:8080 in your browser.")
    server.sleep_forever()
if __name__ == "__main__":
    # Banner describing the problem size and geometry.
    print("3D Triple-link cartpole swing-up — MuJoCo MJX + OpenSCvx")
    print("=" * 60)
    print(f"nq={n_q}, nv={n_v}, nu={n_u}, N={n}")
    print(f"Links: L1={L1} m, L2={L2} m, L3={L3} m (total {L1 + L2 + L3} m)")
    print()

    problem.initialize()
    problem.solve()
    results = problem.post_process()

    # Summarise the terminal node; the angles are converted to degrees in one pass.
    final_q = results.nodes["qpos"][-1]
    final_deg = np.rad2deg(final_q)
    print()
    print(f"Final cart position: ({final_q[IDX_CART_X]:.4f}, {final_q[IDX_CART_Y]:.4f}) m")
    print(f"Final α₁,β₁ [deg]: ({final_deg[IDX_A1]:.2f}, {final_deg[IDX_B1]:.2f})")
    print(f"Final α₂,β₂ [deg]: ({final_deg[IDX_A2]:.2f}, {final_deg[IDX_B2]:.2f})")
    print(f"Final α₃,β₃ [deg]: ({final_deg[IDX_A3]:.2f}, {final_deg[IDX_B3]:.2f})")

    from openscvx.plotting import plot_controls, plot_states

    plot_states(results).show()
    plot_controls(results).show()
    visualize(results)