Skip to content

Triple Cartpole Game

Triple-cartpole balancing game.

Three progressively harder levels:

  • Level 1 — single-link inverted pendulum
  • Level 2 — double-link inverted pendulum
  • Level 3 — triple-link inverted pendulum

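Win condition for each level: keep every link's cumulative angle (its absolute angle from vertical) within the angle tolerance, continuously, for the required hold time. Both settings are adjustable under "Win condition" in the sidebar (defaults: 45° and 0.1 s). When you win, the simulation freezes for 3 seconds and shows your time before the next level loads.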

Each link's target region is drawn as a wedge anchored at the link's pivot, showing the range of directions the link must point into. The wedge follows the pivot in real time and is green while the link is within tolerance and red otherwise.

Interaction

  • Drag the red X-arrow on the cart to move the setpoint.
  • PD gains, setpoint speed, angle tolerance, and hold time are adjustable in the sidebar.
  • "Reset level" restarts the current level from the hanging position; "Reset game" starts over from level 1.

Usage::

python examples/mjx/triple_cartpole_game.py

File: examples/mjx/triple_cartpole_game.py

from __future__ import annotations

import os
import sys
import threading
import time

import numpy as np

current_dir = os.path.dirname(os.path.abspath(__file__))
grandparent_dir = os.path.dirname(os.path.dirname(current_dir))
sys.path.append(grandparent_dir)

try:
    import jax
    import jax.numpy as jnp
    import mujoco
    import mujoco.mjx as mjx
except ImportError:
    sys.exit("MuJoCo MJX / JAX not installed.  Run: pip install openscvx[mjx]")

# ── Constants ─────────────────────────────────────────────────────────────────
# Per-link tables, indexed from the link attached to the cart (0) outward.  They
# hold three entries each, so make_xml / fk_joints support at most three links.
LINK_LENGTHS = [0.5, 0.4, 0.3]  # pole length; also the offset of the next link's pivot
LINK_MASSES = [0.5, 0.4, 0.3]  # capsule ``mass`` in the MJCF model
LINK_RADII = [0.04, 0.035, 0.03]  # capsule radius (MJCF ``size``)
LINK_RGBA = ["0.85 0.3 0.3 1", "0.3 0.8 0.3 1", "0.3 0.3 0.85 1"]  # MJCF rgba strings
LINK_RGB = [(220, 77, 77), (77, 204, 77), (77, 77, 220)]  # 0-255 viewer colours (~LINK_RGBA)

# Cart / actuator.  These values are interpolated verbatim into the MJCF text.
RAIL_LIMIT = 8.0  # slider joint range is ±RAIL_LIMIT; setpoint and gizmo are clamped to it
GEAR = 60.0  # motor gear: the PD force is divided by GEAR and clipped to ctrl in [-1, 1]
DEFAULT_KP = 60.0  # initial PD proportional gain (sidebar "Kp (N/m)")
DEFAULT_KD = 15.0  # initial PD derivative gain (sidebar "Kd (N·s/m)")
DEFAULT_RATE = 4.0  # initial setpoint slew-rate limit (sidebar "Max setpoint speed (m/s)")

WIN_TOL_DEG = 45.0  # default for GUI slider (deg from vertical)
HOLD_SECS = 0.1  # default for GUI slider (continuous hold time)
WIN_PAUSE_S = 3.0  # length (s) of the win screen; shown as its countdown
N_LEVELS = 3  # level index k (0-based) simulates k + 1 links
LEVEL_NAMES = ["Single-link", "Double-link", "Triple-link"]  # display names, indexed by level

# ── Physics backend ───────────────────────────────────────────────────────────
# True  → MJX + RK4  (mjx.forward continuous dynamics, same as post_process)
# False → CPU MuJoCo mj_step (semi-implicit Euler, same as the original sim)
USE_MJX_PHYSICS: bool = False


# ── MuJoCo XML ────────────────────────────────────────────────────────────────


def make_xml(n_links: int) -> str:
    """Return an MJCF model string for a cart carrying an ``n_links``-link pole chain.

    The cart slides along world X on a joint limited to ``±RAIL_LIMIT`` and is
    driven by one motor (``cart_force``, gear ``GEAR``, ctrl limited to [-1, 1]).
    Link 1 hinges on the cart origin; every later link hinges at the tip of its
    parent.  All hinges turn about +Y without limits, and each link is a capsule
    along its local +Z whose length, radius, mass and colour come from
    ``LINK_LENGTHS``, ``LINK_RADII``, ``LINK_MASSES`` and ``LINK_RGBA``.
    """
    # MJCF expresses the chain by nesting: open each link body from the cart
    # outward, then close them again from the innermost link back out.
    openers: list[str] = []
    closers: list[str] = []
    for idx in range(n_links):
        indent = "  " * (idx + 3)
        # The first link sits on the cart; later links sit at their parent's tip.
        anchor = "0 0 0" if idx == 0 else f"0 0 {LINK_LENGTHS[idx - 1]}"
        openers.append(
            f'{indent}<body name="link{idx + 1}" pos="{anchor}">\n'
            f'{indent}  <joint name="hinge{idx + 1}" type="hinge"'
            f' axis="0 1 0" limited="false"/>\n'
            f'{indent}  <geom name="pole{idx + 1}" type="capsule"'
            f' fromto="0 0 0 0 0 {LINK_LENGTHS[idx]}"\n'
            f'{indent}        size="{LINK_RADII[idx]}" mass="{LINK_MASSES[idx]}"'
            f' rgba="{LINK_RGBA[idx]}"/>\n'
        )
        closers.append(f"{indent}</body>\n")
    chain = "".join(openers) + "".join(reversed(closers))

    return f"""<mujoco model="cartpole_game">
  <option gravity="0 0 -9.81" timestep="0.005" integrator="Euler"/>
  <worldbody>
    <body name="cart" pos="0 0 0">
      <joint name="slider" type="slide" axis="1 0 0"
             limited="true" range="-{RAIL_LIMIT} {RAIL_LIMIT}"/>
      <geom name="cart_geom" type="box" size="0.25 0.15 0.1"
            mass="2.0" rgba="0.35 0.35 0.75 1"/>
{chain}    </body>
  </worldbody>
  <actuator>
    <motor joint="slider" name="cart_force" gear="{GEAR}"
           ctrlrange="-1 1" ctrllimited="true"/>
  </actuator>
</mujoco>"""


# ── Forward kinematics ────────────────────────────────────────────────────────


def fk_joints(q: np.ndarray, n_links: int) -> list[np.ndarray]:
    """Forward kinematics of the cart-pole chain in the XZ plane.

    Returns ``n_links + 1`` float64 points ``[x, 0, z]`` (y is always 0): the
    cart origin ``(q[0], 0, 0)`` followed by the tip of each link.  Link ``i``
    has length ``LINK_LENGTHS[i]`` and points at its absolute angle
    ``q[1] + ... + q[i + 1]`` from +Z, positive angles tilting toward +X.
    """
    chain = [np.array([float(q[0]), 0.0, 0.0])]
    theta = 0.0
    for link in range(n_links):
        # Hinge angles are relative to the parent, so accumulate them.
        theta += float(q[link + 1])
        length = LINK_LENGTHS[link]
        offset = np.array([length * np.sin(theta), 0.0, length * np.cos(theta)])
        chain.append(chain[-1] + offset)
    return chain


def per_link_ok(qpos: np.ndarray, n_links: int, win_tol_deg: float) -> tuple[list[bool], float]:
    """Check each link's absolute angle from vertical against a tolerance.

    Link ``k``'s absolute (cumulative) angle is the sum of the hinge angles
    ``qpos[1] .. qpos[k + 1]``.  It is wrapped into [-π, π] so whole turns do
    not count, converted to degrees, and compared (inclusively) with
    ``win_tol_deg``.

    Returns:
        ``(flags, worst)``: one bool per link (True = within tolerance) and the
        largest wrapped deviation in degrees (0.0 when ``n_links`` is 0).
    """
    flags: list[bool] = []
    worst_deg = 0.0
    heading = 0.0
    for joint in range(1, n_links + 1):
        heading += float(qpos[joint])
        wrapped = (heading + np.pi) % (2.0 * np.pi) - np.pi
        deviation = abs(float(np.rad2deg(wrapped)))
        flags.append(deviation <= win_tol_deg)
        worst_deg = max(worst_deg, deviation)
    return flags, worst_deg


# ── MJX physics step (same continuous dynamics as post_process) ───────────────


def _make_mjx_step(n_links: int) -> tuple:
    """Build and warm up a JIT-compiled RK4 + MJX step function.

    ``mjx.forward`` evaluates the continuous generalized accelerations
    (``qacc``) at each of the four RK4 stages — the same dynamics that
    ``post_process`` hands to Diffrax Dopri8.  The control input is held
    constant across the four stages (zero-order hold over one step).

    Returns ``(step_fn, dt)``, where ``dt`` is the model timestep from
    :func:`make_xml`.  ``step_fn`` has already been compiled *and* executed
    once, so gameplay sees no JIT latency.

    ``step_fn(qpos, qvel, ctrl) -> (new_qpos, new_qvel)`` accepts and returns
    float32 JAX arrays; ``qpos``/``qvel`` have shape ``(n_links + 1,)`` and
    ``ctrl`` has shape ``(1,)``.
    """
    n_q = n_links + 1
    mj_m = mujoco.MjModel.from_xml_string(make_xml(n_links))
    mj_m.opt.disableflags |= mujoco.mjtDisableBit.mjDSBL_CONTACT
    mjx_m = mjx.put_model(mj_m)
    h = float(mj_m.opt.timestep)

    # Zero-initialised Data template, built once.  Each RK4 stage only swaps in
    # qpos/qvel/ctrl, so re-creating it inside every _qacc call is redundant.
    data_template = mjx.make_data(mjx_m)

    def _qacc(q: jax.Array, qd: jax.Array, ctrl: jax.Array) -> jax.Array:
        """Continuous generalized acceleration at state ``(q, qd)`` under ``ctrl``."""
        data = data_template.replace(qpos=q, qvel=qd, ctrl=ctrl)
        return mjx.forward(mjx_m, data).qacc

    @jax.jit
    def step(qpos: jax.Array, qvel: jax.Array, ctrl: jax.Array):
        # Standard RK4: y = [qpos, qvel],  y' = [qvel, qacc(qpos, qvel)].
        # qpos and qvel have equal length here (one slide + n hinge joints).
        k1_q = qvel
        k1_v = _qacc(qpos, qvel, ctrl)
        qp2 = qpos + 0.5 * h * k1_q
        qv2 = qvel + 0.5 * h * k1_v
        k2_q = qv2
        k2_v = _qacc(qp2, qv2, ctrl)
        qp3 = qpos + 0.5 * h * k2_q
        qv3 = qvel + 0.5 * h * k2_v
        k3_q = qv3
        k3_v = _qacc(qp3, qv3, ctrl)
        qp4 = qpos + h * k3_q
        qv4 = qvel + h * k3_v
        k4_q = qv4
        k4_v = _qacc(qp4, qv4, ctrl)

        new_qpos = qpos + (h / 6.0) * (k1_q + 2.0 * k2_q + 2.0 * k3_q + k4_q)
        new_qvel = qvel + (h / 6.0) * (k1_v + 2.0 * k2_v + 2.0 * k3_v + k4_v)

        # Hard rail stop.  Contacts are disabled, so the cart position is clamped
        # by hand as a backstop.  Only velocity pointing *into* the wall is
        # removed (inelastic stop); a cart already moving back toward the centre
        # keeps its velocity instead of being stopped dead.
        x_raw = new_qpos[0]
        v_raw = new_qvel[0]
        into_wall = (jnp.abs(x_raw) >= RAIL_LIMIT) & (x_raw * v_raw > 0.0)
        new_qpos = new_qpos.at[0].set(jnp.clip(x_raw, -RAIL_LIMIT, RAIL_LIMIT))
        new_qvel = new_qvel.at[0].set(jnp.where(into_wall, 0.0, v_raw))

        return new_qpos, new_qvel

    # Warm-up: trigger JIT compilation now so gameplay has no first-step lag.
    # JAX dispatch is asynchronous, so block until the first execution is done.
    zeros_q = jnp.zeros(n_q, dtype=jnp.float32)
    warm_qpos, _warm_qvel = step(zeros_q, zeros_q, jnp.zeros(1, dtype=jnp.float32))
    warm_qpos.block_until_ready()
    return step, h


# ── Simulator ─────────────────────────────────────────────────────────────────


class _Sim:
    """Real-time cartpole simulator.

    Supports two physics backends selected by :data:`USE_MJX_PHYSICS`:

    * **MJX + RK4** (``step_fn`` provided): uses ``mjx.forward`` continuous
      dynamics at each of the four RK4 stages — identical to what
      ``post_process`` hands to Diffrax Dopri8.  State is kept as JAX float32
      arrays; no CPU ``MjData`` is used at runtime.

    * **CPU MuJoCo** (``step_fn=None``): calls ``mujoco.mj_step`` with
      semi-implicit Euler, exactly as the original simulator.
    """

    def __init__(self, n_links: int, step_fn=None, dt: float | None = None) -> None:
        self.n_links = n_links
        self.n_q = n_links + 1
        self._step_fn = step_fn  # None → CPU MuJoCo path

        if step_fn is None:
            # CPU MuJoCo path: build model + data here; dt from XML.
            self._mj_model = mujoco.MjModel.from_xml_string(make_xml(n_links))
            self._mj_model.opt.disableflags |= mujoco.mjtDisableBit.mjDSBL_CONTACT
            self._mj_data = mujoco.MjData(self._mj_model)
            self._dt = float(self._mj_model.opt.timestep)
        else:
            self._mj_model = None
            self._mj_data = None
            self._dt = dt  # supplied by _make_mjx_step

        self.x_commanded: float = 0.0
        self.x_target: float = 0.0
        self.max_rate: float = DEFAULT_RATE
        self.kp: float = DEFAULT_KP
        self.kd: float = DEFAULT_KD
        self.paused: bool = False

        self._snap_lock = threading.Lock()
        q0 = np.zeros(self.n_q, dtype=np.float32)
        q0[1] = np.pi
        self._qpos_snap = q0.copy()
        self._qvel_snap = np.zeros(self.n_q, dtype=np.float32)

        if step_fn is None:
            # Initialise CPU MuJoCo data to the hanging state.
            self._mj_data.qpos[:] = q0
            self._mj_data.qvel[:] = 0.0
            mujoco.mj_forward(self._mj_model, self._mj_data)

        self._running = False
        self._thread: threading.Thread | None = None

    def get_state(self) -> tuple[np.ndarray, np.ndarray]:
        with self._snap_lock:
            return self._qpos_snap.copy(), self._qvel_snap.copy()

    def start(self) -> None:
        self._running = True
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        self._running = False
        if self._thread:
            self._thread.join(timeout=0.5)

    def _loop(self) -> None:
        if self._step_fn is not None:
            self._loop_mjx()
        else:
            self._loop_mujoco()

    def _loop_mjx(self) -> None:
        """Physics loop: MJX + RK4 (same continuous dynamics as post_process)."""
        dt = self._dt
        with self._snap_lock:
            qpos = jnp.array(self._qpos_snap, dtype=jnp.float32)
            qvel = jnp.array(self._qvel_snap, dtype=jnp.float32)

        while self._running:
            t0 = time.perf_counter()
            if not self.paused:
                dx_max = self.max_rate * dt
                x_cmd = float(np.clip(self.x_commanded, -RAIL_LIMIT, RAIL_LIMIT))
                self.x_target = self.x_target + float(
                    np.clip(x_cmd - self.x_target, -dx_max, dx_max)
                )

                x = float(qpos[0])
                xdot = float(qvel[0])
                u_N = self.kp * (self.x_target - x) - self.kd * xdot
                ctrl = jnp.array([float(np.clip(u_N / GEAR, -1.0, 1.0))], dtype=jnp.float32)

                qpos, qvel = self._step_fn(qpos, qvel, ctrl)

                with self._snap_lock:
                    self._qpos_snap = np.array(qpos)
                    self._qvel_snap = np.array(qvel)

            spare = dt - (time.perf_counter() - t0)
            if spare > 0:
                time.sleep(spare)

    def _loop_mujoco(self) -> None:
        """Physics loop: CPU MuJoCo mj_step (semi-implicit Euler)."""
        dt = self._dt
        while self._running:
            t0 = time.perf_counter()
            if not self.paused:
                dx_max = self.max_rate * dt
                x_cmd = float(np.clip(self.x_commanded, -RAIL_LIMIT, RAIL_LIMIT))
                self.x_target = self.x_target + float(
                    np.clip(x_cmd - self.x_target, -dx_max, dx_max)
                )

                x = float(self._mj_data.qpos[0])
                xdot = float(self._mj_data.qvel[0])
                u_N = self.kp * (self.x_target - x) - self.kd * xdot
                self._mj_data.ctrl[0] = float(np.clip(u_N / GEAR, -1.0, 1.0))
                mujoco.mj_step(self._mj_model, self._mj_data)

                with self._snap_lock:
                    self._qpos_snap = self._mj_data.qpos.copy()
                    self._qvel_snap = self._mj_data.qvel.copy()

            spare = dt - (time.perf_counter() - t0)
            if spare > 0:
                time.sleep(spare)


# ── Fireworks ─────────────────────────────────────────────────────────────────


def _run_fireworks(server: object, duration: float = 9.0) -> None:
    """Play a celebratory particle-fireworks animation in the viser scene.

    Meant to run on a daemon thread.  A new shell bursts every
    ``BURST_INTERVAL`` seconds at a random point above the rail.  Each shell
    throws ``N_PER_BURST`` particles of one palette colour outward.  The
    particles fall under a weak gravity and fade linearly to black over
    ``LIFETIME`` seconds.  All particles share one point cloud
    (``/fireworks``), which is removed, best effort, once ``duration``
    seconds of animation time have elapsed.

    Args:
        server: viser server; only ``server.scene.add_point_cloud`` is used.
        duration: Animation length in seconds, advanced in fixed 1/60 s frames.
    """
    rng = np.random.default_rng()

    # One base colour per shell, cycled in order; particles fade toward black.
    PALETTE = [
        (255, 80, 80),  # red
        (80, 255, 80),  # green
        (80, 130, 255),  # blue
        (255, 220, 50),  # gold
        (220, 80, 220),  # purple
        (80, 220, 220),  # cyan
        (255, 160, 40),  # orange
        (255, 255, 255),  # white
    ]

    GRAVITY = np.array([0.0, 0.0, -5.0], dtype=np.float32)
    DT = 1.0 / 60.0  # animation frame length (s)
    LIFETIME = 2.8  # particle lifetime (s)
    N_PER_BURST = 130
    BURST_INTERVAL = 0.45  # s between shells

    def _offscreen() -> tuple[np.ndarray, np.ndarray]:
        # One black point parked below the ground stands in for "no particles".
        return (
            np.array([[0.0, 0.0, -20.0]], dtype=np.float32),
            np.array([[0, 0, 0]], dtype=np.uint8),
        )

    def _shell(rgb: tuple[int, int, int]) -> tuple[np.ndarray, ...]:
        """Spawn one shell: (positions, velocities, float colours, lifetimes)."""
        cx = float(rng.uniform(-3.2, 3.2))
        cz = float(rng.uniform(0.6, 2.4))
        centre = np.array([cx, 0.0, cz], dtype=np.float32)
        # Directions uniform on the sphere (cos(polar) ~ U[-1, 1], azimuth ~
        # U[0, 2π)), with Y squashed so the burst stays mostly in the XZ plane.
        cos_polar = rng.uniform(-1.0, 1.0, N_PER_BURST).astype(np.float32)
        azimuth = rng.uniform(0.0, 2.0 * np.pi, N_PER_BURST).astype(np.float32)
        sin_polar = np.sqrt(np.maximum(1.0 - cos_polar * cos_polar, 0.0)).astype(np.float32)
        speed = rng.uniform(0.6, 3.0, N_PER_BURST).astype(np.float32)
        direction = np.stack(
            [
                sin_polar * np.cos(azimuth),
                sin_polar * 0.25 * np.sin(azimuth),
                cos_polar,
            ],
            axis=1,
        )
        return (
            np.tile(centre, (N_PER_BURST, 1)),
            (speed[:, None] * direction).astype(np.float32),
            np.tile(np.array(rgb, dtype=np.float32), (N_PER_BURST, 1)),
            np.full(N_PER_BURST, LIFETIME, dtype=np.float32),
        )

    start_points, start_colors = _offscreen()
    cloud = server.scene.add_point_cloud(
        "/fireworks",
        points=start_points,
        colors=start_colors,
        point_size=0.06,
    )

    # Live particle state; colours stay float so the fade is smooth.
    pos = np.zeros((0, 3), dtype=np.float32)
    vel = np.zeros((0, 3), dtype=np.float32)
    col = np.zeros((0, 3), dtype=np.float32)
    life = np.zeros(0, dtype=np.float32)

    clock = 0.0
    next_burst = 0.0
    shells_fired = 0

    while clock < duration:
        frame_start = time.perf_counter()

        if clock >= next_burst:
            next_burst += BURST_INTERVAL
            shells_fired += 1
            b_pos, b_vel, b_col, b_life = _shell(PALETTE[shells_fired % len(PALETTE)])
            pos = np.concatenate([pos, b_pos])
            vel = np.concatenate([vel, b_vel])
            col = np.concatenate([col, b_col])
            life = np.concatenate([life, b_life])

        if len(pos):
            vel += GRAVITY * DT
            pos += vel * DT
            life -= DT

            alive = life > 0.0
            if alive.any():
                fade = np.clip(life[alive] / LIFETIME, 0.0, 1.0)[:, None]
                shown = np.clip(col[alive] * fade, 0, 255).astype(np.uint8)
                cloud.points = pos[alive]
                cloud.colors = shown
            else:
                cloud.points, cloud.colors = _offscreen()

            pos, vel, col, life = pos[alive], vel[alive], col[alive], life[alive]

        clock += DT
        leftover = DT - (time.perf_counter() - frame_start)
        if leftover > 0:
            time.sleep(leftover)

    # Best effort: the scene or client may already be gone at this point.
    try:
        cloud.remove()
    except Exception:
        pass


# ── Cone geometry factory ─────────────────────────────────────────────────────


def _make_cone_mesh(link_idx: int, win_tol_deg: float) -> tuple[np.ndarray, np.ndarray]:
    """Return the tolerance wedge for link ``link_idx`` as a one-triangle mesh.

    The wedge is an isosceles triangle in the XZ plane with its apex at the
    origin (the link's pivot).  Its two equal sides have length
    ``LINK_LENGTHS[link_idx]`` and are tilted ±``win_tol_deg`` from +Z, so it
    opens upward.

    Returns:
        ``(vertices, faces)``: a (3, 3) float32 array ordered apex, left, right,
        and a (2, 3) uint32 array holding the triangle in both windings so it
        is visible from the +Y and −Y camera sides.
    """
    length = LINK_LENGTHS[link_idx]
    spread = np.deg2rad(win_tol_deg)
    dx = length * np.sin(spread)
    dz = length * np.cos(spread)
    vertices = np.array(
        [
            [0.0, 0.0, 0.0],
            [-dx, 0.0, dz],
            [dx, 0.0, dz],
        ],
        dtype=np.float32,
    )
    both_windings = np.array([[0, 1, 2], [0, 2, 1]], dtype=np.uint32)
    return vertices, both_windings


def _make_cone_lines(link_idx: int, win_tol_deg: float) -> np.ndarray:
    """Return the wedge's two boundary edges as a (2, 2, 3) float32 segment array.

    Both segments start at the origin (the link's pivot) and end
    ``LINK_LENGTHS[link_idx]`` away, tilted -``win_tol_deg`` (first) and
    +``win_tol_deg`` (second) from +Z in the XZ plane.
    """
    reach = LINK_LENGTHS[link_idx]
    theta = np.deg2rad(win_tol_deg)
    apex = np.zeros(3, dtype=np.float32)
    tips = np.array(
        [
            [-reach * np.sin(theta), 0.0, reach * np.cos(theta)],
            [reach * np.sin(theta), 0.0, reach * np.cos(theta)],
        ],
        dtype=np.float32,
    )
    return np.array([[apex, tip] for tip in tips], dtype=np.float32)


# ── Main game loop ────────────────────────────────────────────────────────────


def run() -> None:
    try:
        import viser
    except ImportError:
        sys.exit("viser not installed.  Run: pip install viser")

    # Pre-compile MJX step functions for all levels (only when MJX is active).
    _step_cache: dict[int, tuple] = {}  # n_links -> (step_fn, dt)
    if USE_MJX_PHYSICS:
        print("Compiling MJX physics (RK4 + mjx.forward) for all levels…")
        for _nl in range(1, N_LEVELS + 1):
            print(f"  Level {_nl} ({LEVEL_NAMES[_nl - 1]})…", end=" ", flush=True)
            _step_cache[_nl] = _make_mjx_step(_nl)
            print("done")
        print("Physics ready.  [MJX + RK4]\n")
    else:
        print("Physics backend: CPU MuJoCo mj_step (semi-implicit Euler)\n")

    server = viser.ViserServer()
    server.scene.set_up_direction("+z")

    # ── Game state ────────────────────────────────────────────────────────────
    gs: dict = {
        "level": 0,
        "status": "idle",
        "level_start": 0.0,
        "win_at": 0.0,
        "win_elapsed": 0.0,
        "balance_start": None,
        "advancing": False,
    }
    _sim: list[_Sim | None] = [None]
    _level_vis: list[dict | None] = [None]
    _handles: list[list] = [[]]  # all per-level scene handles
    _gen: list[int] = [0]  # incremented on each load; guards stale gizmo callbacks

    # ── Persistent scene ──────────────────────────────────────────────────────
    server.scene.add_grid(
        "/ground", width=10.0, height=4.0, cell_size=0.5, position=(0.0, 0.0, -0.115)
    )
    server.scene.add_box(
        "/rail",
        dimensions=(2 * RAIL_LIMIT + 0.1, 0.04, 0.015),
        position=(0.0, 0.0, 0.0),
        color=(130, 130, 130),
    )
    for sign in (-1, 1):
        server.scene.add_box(
            f"/rail/stop{'L' if sign < 0 else 'R'}",
            dimensions=(0.04, 0.12, 0.12),
            position=(sign * RAIL_LIMIT, 0.0, 0.0),
            color=(180, 60, 60),
        )

    # ── Persistent GUI ────────────────────────────────────────────────────────
    header_html = server.gui.add_html("")
    win_html = server.gui.add_html("", visible=False)
    progress_html = server.gui.add_html("")

    with server.gui.add_folder("Controls"):
        rate_sl = server.gui.add_slider(
            "Max setpoint speed (m/s)", min=0.1, max=8.0, step=0.1, initial_value=DEFAULT_RATE
        )
        btn_reset = server.gui.add_button("Reset level")
        btn_reset_game = server.gui.add_button("Reset game")

    with server.gui.add_folder("PD Gains"):
        kp_sl = server.gui.add_slider("Kp  (N/m)", min=0, max=300, step=1, initial_value=DEFAULT_KP)
        kd_sl = server.gui.add_slider(
            "Kd  (N·s/m)", min=0, max=80, step=0.5, initial_value=DEFAULT_KD
        )

    with server.gui.add_folder("Win condition"):
        win_tol_sl = server.gui.add_slider(
            "Angle tolerance (°)",
            min=1.0,
            max=90.0,
            step=0.5,
            initial_value=WIN_TOL_DEG,
        )
        hold_sl = server.gui.add_slider(
            "Hold time (s)",
            min=0.05,
            max=5.0,
            step=0.05,
            initial_value=HOLD_SECS,
        )

    with server.gui.add_folder("State"):
        state_md = server.gui.add_markdown("*Waiting…*")

    @rate_sl.on_update
    def _(_e):
        if _sim[0]:
            _sim[0].max_rate = float(rate_sl.value)

    @kp_sl.on_update
    def _(_e):
        if _sim[0]:
            _sim[0].kp = float(kp_sl.value)

    @kd_sl.on_update
    def _(_e):
        if _sim[0]:
            _sim[0].kd = float(kd_sl.value)

    @btn_reset.on_click
    def _(_e):
        gs["advancing"] = False
        win_html.visible = False
        _load_level(gs["level"])

    @btn_reset_game.on_click
    def _(_e):
        gs["advancing"] = False
        win_html.visible = False
        progress_html.content = ""
        _load_level(0)

    # ── HTML helpers ──────────────────────────────────────────────────────────

    def _header(level_idx: int, elapsed: float | None = None) -> str:
        stars = "★" * (level_idx + 1) + "☆" * (N_LEVELS - level_idx - 1)
        timer = (
            f'<div style="font-size:24px;font-weight:bold;color:#f5c518;">⏱ {elapsed:.1f}s</div>'
            if elapsed is not None
            else ""
        )
        return (
            f'<div style="text-align:center;padding:6px 0;">'
            f'<div style="font-size:20px;font-weight:bold;color:#e8e8e8;">'
            f"Level {level_idx + 1} / {N_LEVELS}</div>"
            f'<div style="font-size:13px;color:#aaa;margin-bottom:4px;">'
            f"{LEVEL_NAMES[level_idx]} pendulum</div>"
            f'<div style="font-size:18px;color:#f5c518;">{stars}</div>'
            f"{timer}</div>"
        )

    def _win_screen(level_idx: int, elapsed: float, countdown: float) -> str:
        if level_idx < N_LEVELS - 1:
            nxt = (
                f'<div style="color:#888;font-size:13px;margin-top:8px;">'
                f"Level {level_idx + 2} in {countdown:.0f}s…</div>"
            )
        else:
            nxt = (
                '<div style="color:#f5c518;font-size:14px;margin-top:8px;">'
                "🏆 All levels complete!</div>"
            )
        return (
            f'<div style="background:linear-gradient(135deg,#1a1a2e,#16213e);'
            f"border:2px solid #50fa7b;border-radius:12px;padding:20px;"
            f'text-align:center;margin:6px 0;">'
            f'<div style="font-size:36px;margin-bottom:6px;">🎉</div>'
            f'<div style="color:#50fa7b;font-size:20px;font-weight:bold;">'
            f"Level {level_idx + 1} Complete!</div>"
            f'<div style="color:#f8f8f2;font-size:16px;margin-top:6px;">'
            f"Balanced in <strong>{elapsed:.1f}s</strong></div>"
            f"{nxt}</div>"
        )

    def _all_done() -> str:
        return (
            '<div style="background:linear-gradient(135deg,#2d1b69,#1a1a2e);'
            "border:2px solid #f5c518;border-radius:12px;padding:24px;"
            'text-align:center;margin:6px 0;">'
            '<div style="font-size:48px;margin-bottom:8px;">🏆</div>'
            '<div style="color:#f5c518;font-size:22px;font-weight:bold;">'
            "You beat the game!</div>"
            '<div style="color:#aaa;font-size:14px;margin-top:8px;">'
            "All 3 levels balanced. Impressive.</div></div>"
        )

    def _progress_bar(hold: float, hold_secs: float) -> str:
        denom = max(float(hold_secs), 1e-6)
        frac = min(1.0, hold / denom)
        return (
            f'<div style="margin:6px 0;text-align:center;">'
            f'<div style="color:#50fa7b;font-size:12px;margin-bottom:4px;">'
            f"Balancing… {hold:.1f} / {hold_secs:.2f}s</div>"
            f'<div style="background:#333;border-radius:4px;height:8px;">'
            f'<div style="background:#50fa7b;width:{frac * 100:.0f}%;'
            f'height:8px;border-radius:4px;"></div></div></div>'
        )

    def _angle_warning(worst: float, win_tol_deg: float) -> str:
        return (
            f'<div style="text-align:center;color:#ff6b6b;font-size:12px;'
            f'margin:4px 0;">Worst angle: {worst:.1f}°'
            f" (need ≤{win_tol_deg:.1f}°)</div>"
        )

    def _refresh_cone_geometry(win_tol_deg: float) -> None:
        vis = _level_vis[0]
        if vis is None:
            return
        n = vis["n"]
        for i in range(n):
            verts, _faces = _make_cone_mesh(i, win_tol_deg)
            lpts = _make_cone_lines(i, win_tol_deg)
            vis["green_fill"][i].vertices = verts
            vis["red_fill"][i].vertices = verts
            vis["green_lines"][i].points = lpts
            vis["red_lines"][i].points = lpts

    @win_tol_sl.on_update
    def _(_e):
        _refresh_cone_geometry(float(win_tol_sl.value))

    # ── Level setup ───────────────────────────────────────────────────────────

    def _load_level(level_idx: int) -> None:
        # Block render loop during transition.
        _level_vis[0] = None
        # Bump generation so stale gizmo callbacks from the old level are ignored.
        _gen[0] += 1
        my_gen = _gen[0]

        # Remove all per-level scene objects from the previous level.
        for h in _handles[0]:
            try:
                h.remove()
            except Exception:
                pass
        _handles[0].clear()

        # Stop old sim.
        old = _sim[0]
        _sim[0] = None
        if old:
            old.paused = False
            old.stop()

        # Create new sim, using the pre-compiled MJX step function or CPU
        # MuJoCo depending on the USE_MJX_PHYSICS flag.
        n = level_idx + 1
        sim = _Sim(n, *_step_cache[n]) if USE_MJX_PHYSICS else _Sim(n)
        sim.kp = float(kp_sl.value)
        sim.kd = float(kd_sl.value)
        sim.max_rate = float(rate_sl.value)
        sim.start()
        _sim[0] = sim

        gs.update(
            {
                "level": level_idx,
                "status": "playing",
                "level_start": time.time(),
                "win_at": 0.0,
                "win_elapsed": 0.0,
                "balance_start": None,
                "advancing": False,
            }
        )
        header_html.content = _header(level_idx)
        win_html.visible = False
        progress_html.content = ""

        # ── Build scene ───────────────────────────────────────────────────────
        q0 = np.zeros(n + 1)
        q0[1] = np.pi
        pts = fk_joints(q0, n)

        def reg(h):
            _handles[0].append(h)
            return h

        cart_h = reg(
            server.scene.add_box(
                "/lvl/cart",
                dimensions=(0.5, 0.3, 0.2),
                position=tuple(float(v) for v in pts[0]),
                color=(90, 90, 190),
            )
        )

        # One scene object per link — avoids any buffer-resize issue in Viser
        # when the link count changes between levels.
        link_handles = [
            reg(
                server.scene.add_line_segments(
                    f"/lvl/link{i}",
                    points=np.array([[pts[i], pts[i + 1]]], dtype=np.float32),
                    colors=np.array([[LINK_RGB[i], LINK_RGB[i]]], dtype=np.uint8),
                    line_width=7.0,
                )
            )
            for i in range(n)
        ]

        jhandles = [
            reg(
                server.scene.add_icosphere(
                    f"/lvl/j{i}",
                    radius=0.05,
                    color=LINK_RGB[i],
                    position=tuple(float(v) for v in pts[i]),
                )
            )
            for i in range(n)
        ]

        tip_trail: list[np.ndarray] = []
        tip_cloud = reg(
            server.scene.add_point_cloud(
                "/lvl/tip",
                points=np.array([pts[-1]], dtype=np.float32),
                colors=np.array([[255, 200, 50]], dtype=np.uint8),
                point_size=0.025,
            )
        )

        # Use a generation-unique path so Viser always sends a fresh node to
        # the client on reset — prevents the old gizmo position from bleeding
        # into the new level.
        gizmo = reg(
            server.scene.add_transform_controls(
                f"/lvl/gizmo_{my_gen}",
                scale=0.9,
                active_axes=(True, False, False),
                disable_rotations=True,
                translation_limits=((-RAIL_LIMIT, RAIL_LIMIT), (0, 0), (0, 0)),
                position=(0.0, 0.0, 0.0),
            )
        )

        @gizmo.on_update
        def _drag(_e):
            if _gen[0] != my_gen:  # stale callback from a removed gizmo
                return
            x_new = float(np.clip(gizmo.position[0], -RAIL_LIMIT, RAIL_LIMIT))
            gizmo.position = (x_new, 0.0, 0.0)
            if _sim[0]:
                _sim[0].x_commanded = x_new

        # ── Tolerance cones ───────────────────────────────────────────────────
        # Each link i gets TWO mesh wedges (green = ok, red = not ok) plus
        # two boundary line pairs.  Visibility is toggled every render frame.
        wtol = float(win_tol_sl.value)
        verts_cone, faces_cone = zip(*[_make_cone_mesh(i, wtol) for i in range(n)])
        lines_cone = [_make_cone_lines(i, wtol) for i in range(n)]

        green_fill, red_fill = [], []
        green_lines, red_lines = [], []
        _col_lines = np.array(
            [[[c, c], [c, c]] for c in [(60, 200, 60), (60, 200, 60)]], dtype=np.uint8
        )
        _red_lines = np.array(
            [[[c, c], [c, c]] for c in [(200, 60, 60), (200, 60, 60)]], dtype=np.uint8
        )

        for i in range(n):
            base = tuple(float(v) for v in pts[i])
            lpts = lines_cone[i]
            g_lc = np.array(
                [[[60, 200, 60], [60, 200, 60]], [[60, 200, 60], [60, 200, 60]]], dtype=np.uint8
            )
            r_lc = np.array(
                [[[200, 60, 60], [200, 60, 60]], [[200, 60, 60], [200, 60, 60]]], dtype=np.uint8
            )

            green_fill.append(
                reg(
                    server.scene.add_mesh_simple(
                        f"/lvl/cone_g{i}",
                        vertices=verts_cone[i],
                        faces=faces_cone[i],
                        color=(60, 200, 60),
                        opacity=0.18,
                        side="double",
                        position=base,
                        visible=False,
                    )
                )
            )
            red_fill.append(
                reg(
                    server.scene.add_mesh_simple(
                        f"/lvl/cone_r{i}",
                        vertices=verts_cone[i],
                        faces=faces_cone[i],
                        color=(200, 60, 60),
                        opacity=0.18,
                        side="double",
                        position=base,
                        visible=True,
                    )
                )
            )
            green_lines.append(
                reg(
                    server.scene.add_line_segments(
                        f"/lvl/cline_g{i}",
                        points=lpts,
                        colors=g_lc,
                        line_width=1.5,
                        position=base,
                        visible=False,
                    )
                )
            )
            red_lines.append(
                reg(
                    server.scene.add_line_segments(
                        f"/lvl/cline_r{i}",
                        points=lpts,
                        colors=r_lc,
                        line_width=1.5,
                        position=base,
                        visible=True,
                    )
                )
            )

        _level_vis[0] = {
            "n": n,
            "cart_h": cart_h,
            "link_handles": link_handles,
            "jhandles": jhandles,
            "tip_trail": tip_trail,
            "tip_cloud": tip_cloud,
            "green_fill": green_fill,
            "red_fill": red_fill,
            "green_lines": green_lines,
            "red_lines": red_lines,
        }

    # ── Win / advance ─────────────────────────────────────────────────────────

    def _trigger_win(elapsed: float) -> None:
        """Switch the game into the "won" state for the current level.

        Records the win timestamp and the level's elapsed time in ``gs``,
        pauses the simulator when one is loaded, clears the hold-progress
        bar and shows the win overlay with a full ``WIN_PAUSE_S`` countdown.

        Args:
            elapsed: Seconds spent on the level, shown on the win screen.
        """
        won_state = {"status": "won", "win_at": time.time(), "win_elapsed": elapsed}
        gs.update(won_state)

        active_sim = _sim[0]
        if active_sim:
            # Freeze the simulation while the win screen is displayed.
            active_sim.paused = True

        progress_html.content = ""
        win_html.content = _win_screen(gs["level"], elapsed, WIN_PAUSE_S)
        win_html.visible = True

    def _advance_level() -> None:
        """Move on after a win: load the next level or finish the game.

        The render loop starts this on a background thread once the win
        countdown has expired. If another level exists, it is loaded.
        Otherwise, the game is marked ``"done"``, the header and progress
        widgets are cleared, the final screen is shown and the fireworks
        animation starts on its own daemon thread.
        """
        upcoming = gs["level"] + 1
        if upcoming < N_LEVELS:
            _load_level(upcoming)
            return

        # Last level cleared: show the final screen.
        gs["status"] = "done"
        header_html.content = ""
        progress_html.content = ""
        win_html.content = _all_done()
        win_html.visible = True
        fireworks = threading.Thread(target=_run_fireworks, args=(server,), daemon=True)
        fireworks.start()

    # ── Render / game loop ────────────────────────────────────────────────────
    # Target frame period of the render/game loop (60 frames per second);
    # the loop sleeps off whatever time is left in each frame.
    RENDER_DT = 1.0 / 60.0
    # Maximum number of tip positions kept in the tip-trail point cloud;
    # the oldest sample is dropped once this is exceeded.
    MAX_TRAIL = 500

    def _render_loop() -> None:
        """Render the scene and drive the game state machine, once per RENDER_DT.

        Intended to run forever on a daemon thread. Each frame reads the
        active level's visual handles (``_level_vis[0]``) and simulator
        (``_sim[0]``). While either is ``None``, it just sleeps one frame
        period. Otherwise it:

        * updates the cart, link segments, joint markers and tip trail from
          ``fk_joints`` of the current state;
        * moves each link's tolerance wedge to the link's base point and
          shows the green or red variant according to ``per_link_ok``;
        * refreshes the state readout markdown;
        * advances the game. While ``"playing"``, it tracks how long all
          links have stayed in band and calls ``_trigger_win`` once the
          hold time is reached. While ``"won"``, it counts down the win
          screen and then starts ``_advance_level`` on a worker thread.
        """

        def _xyz(p) -> tuple:
            # Scene positions are assigned as tuples of plain Python floats.
            return tuple(float(c) for c in p)

        def _draw_tip_trail(trail: list, cloud, tip) -> None:
            """Append *tip* to the bounded trail and recolour the point cloud."""
            trail.append(tip.copy())
            if len(trail) > MAX_TRAIL:
                trail.pop(0)
            # The trail cannot be empty here because we just appended to it.
            trail_xyz = np.array(trail, dtype=np.float32)
            count = len(trail_xyz)
            age = np.linspace(0.0, 1.0, count)  # 0 = oldest sample, 1 = newest
            rgb = np.zeros((count, 3), dtype=np.uint8)
            rgb[:, 0] = (255 * age).astype(np.uint8)
            rgb[:, 1] = (200 * (1.0 - age)).astype(np.uint8)
            rgb[:, 2] = 50
            cloud.points = trail_xyz
            cloud.colors = rgb

        def _draw_cones(level_vis: dict, joints, in_band, link_count: int) -> None:
            """Pin each wedge to its link base and toggle green/red visibility."""
            for idx in range(link_count):
                anchor = _xyz(joints[idx])
                inside = in_band[idx]
                wedge = (
                    level_vis["green_fill"][idx],
                    level_vis["red_fill"][idx],
                    level_vis["green_lines"][idx],
                    level_vis["red_lines"][idx],
                )
                for handle in wedge:
                    handle.position = anchor
                for handle, show in zip(wedge, (inside, not inside, inside, not inside)):
                    handle.visible = show

        def _tick_playing(in_band, worst_link, tol_deg: float, need_secs: float) -> None:
            """Update header/progress and fire the win once the hold completes."""
            level_time = time.time() - gs["level_start"]
            header_html.content = _header(gs["level"], level_time)
            if not all(in_band):
                # Any link out of band restarts the hold timer.
                gs["balance_start"] = None
                progress_html.content = _angle_warning(worst_link, tol_deg)
                return
            if gs["balance_start"] is None:
                gs["balance_start"] = time.time()
            held_for = time.time() - gs["balance_start"]
            progress_html.content = _progress_bar(held_for, need_secs)
            if held_for >= need_secs:
                _trigger_win(level_time)

        def _tick_won() -> None:
            """Count down the win overlay, then hand off to the next level."""
            remaining = WIN_PAUSE_S - (time.time() - gs["win_at"])
            if remaining > 0:
                win_html.content = _win_screen(gs["level"], gs["win_elapsed"], remaining)
            elif not gs["advancing"]:
                # Flag first so later frames don't spawn another advance thread.
                gs["advancing"] = True
                threading.Thread(target=_advance_level, daemon=True).start()

        while True:
            frame_start = time.perf_counter()
            vis, sim = _level_vis[0], _sim[0]
            if vis is None or sim is None:
                time.sleep(RENDER_DT)
                continue

            n = vis["n"]
            status = gs["status"]
            qpos, qvel = sim.get_state()
            pts = fk_joints(qpos, n)

            # ── 3-D scene ─────────────────────────────────────────────────────
            vis["cart_h"].position = (float(pts[0][0]), 0.0, 0.0)
            for idx, segment in enumerate(vis["link_handles"]):
                segment.points = np.array([[pts[idx], pts[idx + 1]]], dtype=np.float32)
            for idx, joint in enumerate(vis["jhandles"]):
                joint.position = _xyz(pts[idx])
            _draw_tip_trail(vis["tip_trail"], vis["tip_cloud"], pts[-1])

            # ── Tolerance cones ───────────────────────────────────────────────
            win_tol = float(win_tol_sl.value)
            hold_secs = float(hold_sl.value)
            link_ok, worst = per_link_ok(qpos, n, win_tol)
            _draw_cones(vis, pts, link_ok, n)

            # ── State readout ─────────────────────────────────────────────────
            angles_deg = np.rad2deg(qpos[1 : n + 1])
            readout = [f"**x:** {qpos[0]:.3f} m  ·  **ẋ:** {qvel[0]:.2f} m/s"]
            readout += [
                f"**θ{i + 1}:** {angles_deg[i]:.1f}°  ·  "
                f"**θ̇{i + 1}:** {np.rad2deg(qvel[i + 1]):.1f}°/s"
                for i in range(n)
            ]
            state_md.content = "\n\n".join(readout)

            # ── Game logic ────────────────────────────────────────────────────
            if status == "won":
                _tick_won()
            elif status == "playing":
                _tick_playing(link_ok, worst, win_tol, hold_secs)

            leftover = RENDER_DT - (time.perf_counter() - frame_start)
            if leftover > 0:
                time.sleep(leftover)

    # Start the render loop before any level exists. It idles (sleeping one
    # frame per iteration) while _level_vis[0] or _sim[0] is None, and
    # presumably begins drawing once _load_level(0) has populated them.
    threading.Thread(target=_render_loop, daemon=True).start()
    _load_level(0)

    print("Cartpole Balancing Game — open http://localhost:8080")
    print("Drag the red X-arrow to move the cart.")
    print(
        f"Hold links in the green wedges (default ±{WIN_TOL_DEG:.0f}°, "
        f"{HOLD_SECS:.2f}s — adjustable under Win condition)."
    )
    # Block the main thread. The render/advance/fireworks threads are
    # daemons and would be killed if the main thread returned.
    server.sleep_forever()


# Script entry point: run the game only when executed directly, not on import.
if __name__ == "__main__":
    run()