Skip to content

6-DoF PDG Realtime Base

6-DoF powered descent guidance base problem for realtime OpenSCvx demos.

Adapted from the SCvxGEN repository by Abhi Kamath, https://scvxgen.mintlify.app/introduction.

File: examples/realtime/base_problems/6DoF_pdg_realtime_base.py

import contextlib
import io
import os
import queue
import sys
import time as pytime

import jax

jax.config.update("jax_enable_x64", True)

import numpy as np

# Append the repository root (three directory levels above this file's
# directory) to sys.path so that ``examples.plotting_viser`` (imported below)
# resolves when this script is run directly from a checkout.
current_dir = os.path.dirname(os.path.abspath(__file__))
repo_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
sys.path.append(repo_root_dir)

import openscvx as ox
from examples.plotting_viser import (
    create_animated_plotting_server,
    create_scp_animated_plotting_server,
)
from openscvx import Problem
from openscvx.utils import printing as _openscvx_printing


def _silence_openscvx_console_printing() -> None:
    """Replace OpenSCvx's banner and summary printers with no-op callables.

    Patches attributes of ``openscvx.utils.printing`` in place: ``intro``
    becomes a zero-argument no-op, and ``print_problem_summary`` /
    ``print_results_summary`` become no-ops accepting any arguments.
    """

    def _quiet_intro():
        return None

    def _quiet_summary(*_args, **_kwargs):
        return None

    _openscvx_printing.intro = _quiet_intro
    for attr_name in ("print_problem_summary", "print_results_summary"):
        setattr(_openscvx_printing, attr_name, _quiet_summary)


def _scp_intermediate_drain_only(problem: Problem):
    def _run(print_queue, params, columns) -> None:
        hz = 30.0
        while True:
            t_start = pytime.time()
            try:
                data = print_queue.get(timeout=1.0 / hz)
                problem._scp_last_emit = data
            except queue.Empty:
                pass
            pytime.sleep(max(0.0, 1.0 / hz - (pytime.time() - t_start)))

    return _run


_gpq_patched = False  # becomes True once the plotting hook below is installed


def _patch_get_print_queue_data_for_pdg() -> None:
    """Make ``examples.plotting_viser.get_print_queue_data`` prefer ``problem._scp_last_emit``.

    The wrapper returns ``dis_time``, ``prob_stat`` and ``cost`` from the most
    recent SCP emit stored on the problem (defaults ``0.0``, ``"--"``, ``0.0``).
    When no emit is stored (missing or ``None``), it defers to the original
    function. The module-level ``_gpq_patched`` flag makes repeated calls no-ops.
    """
    global _gpq_patched
    if not _gpq_patched:
        import examples.plotting_viser as plotting_viser

        fallback = plotting_viser.get_print_queue_data
        emit_defaults = (("dis_time", 0.0), ("prob_stat", "--"), ("cost", 0.0))

        def _latest_emit_or_fallback(problem):
            latest = getattr(problem, "_scp_last_emit", None)
            if latest is None:
                return fallback(problem)
            return {key: latest.get(key, default) for key, default in emit_defaults}

        plotting_viser.get_print_queue_data = _latest_emit_or_fallback
        _gpq_patched = True


def install_pdg_scp_console_style(problem: Problem) -> None:
    """Replace OpenSCvx's per-iteration console table with a silent status hook.

    Enables ``problem.settings.dev.printing`` and swaps OpenSCvx's
    ``printing.intermediate`` for a worker that only records the newest
    queue item on ``problem``. It then installs the
    ``get_print_queue_data`` patch that reads that record.
    """
    drain_worker = _scp_intermediate_drain_only(problem)
    problem.settings.dev.printing = True
    _openscvx_printing.intermediate = drain_worker
    _patch_get_print_queue_data_for_pdg()


def initialize_problem_quiet(problem: Problem) -> None:
    """Run ``problem.initialize()`` while discarding everything it writes to stdout."""
    discarded_output = io.StringIO()
    with contextlib.redirect_stdout(discarded_output):
        problem.initialize()


def print_scp_table_header_once(problem: Problem) -> None:
    """Print OpenSCvx's SCP table header for ``problem._columns`` when it is set.

    Nothing here tracks prior calls: each call prints the header if
    ``problem._columns`` exists and is not ``None``.
    """
    columns = getattr(problem, "_columns", None)
    if columns is None:
        return
    _openscvx_printing.header(columns)


# Mass is state index 0, so the three position components sit at indices 1..3.
POSITION_STATE_SLICE = slice(1, 1 + 3)

# Mute OpenSCvx's intro banner and problem/results summaries at import time.
_silence_openscvx_console_printing()


def model_vec_to_viser_xyz(v: np.ndarray) -> np.ndarray:
    """Reorder model-frame 3-vectors into Viser's (x, y, z) convention.

    The model's (z, y, x) component order is reversed along the last axis, so
    the map is its own inverse. Inputs are converted to ``float64``. Arrays
    that are empty, 0-d, or whose last axis is not length 3 are returned
    (as ``float64``) without reordering.
    """
    arr = np.asarray(v, dtype=np.float64)
    reorderable = arr.size > 0 and arr.ndim >= 1 and arr.shape[-1] == 3
    if not reorderable:
        return arr
    # Negative-stride view -> fresh contiguous copy; never aliases the input.
    return np.ascontiguousarray(arr[..., ::-1])


def remap_optimization_results_for_viser_xyz(
    results, position_slice: slice = POSITION_STATE_SLICE
) -> None:
    """After ``post_process()``, remap stored trajectories and SCP history for Viser axes.

    Every remapped 3-vector goes through :func:`model_vec_to_viser_xyz`. The
    following are modified in place:

    * ``results.trajectory["position"]`` and ``["velocity"]`` when present.
    * The ``position_slice`` columns of every SCP iterate in ``results.X`` and
      of ``results.x_full`` (when present).
    * The position rows of every segment in ``results.discretization_history``.
      NOTE(review): this assumes each history entry stacks, per segment,
      ``i4 = n_x + n_x**2 + 2*n_x*n_u`` rows with the state block first.
      Confirm against the OpenSCvx discretization layout.

    Args:
        results: Post-processed results object (attribute and mapping access
            as used below). Modified in place.
        position_slice: State-vector columns holding position. The
            discretization history is only remapped when this slice resolves
            to exactly three indices within the state dimension.
    """
    traj = results.trajectory
    for key in ("position", "velocity"):
        if traj.get(key) is not None:
            traj[key] = model_vec_to_viser_xyz(np.asarray(traj[key]))

    def _remap_state_rows(X) -> np.ndarray:
        # ``np.array`` always copies. ``np.asarray(..., copy=True)`` would
        # need NumPy >= 2.0 for the ``copy`` keyword.
        X = np.array(X, dtype=np.float64)
        X[:, position_slice] = model_vec_to_viser_xyz(X[:, position_slice])
        return X

    for i, X in enumerate(results.X):
        results.X[i] = _remap_state_rows(X)

    if getattr(results, "x_full", None) is not None:
        results.x_full = _remap_state_rows(results.x_full)

    if not results.discretization_history:
        return

    n_x = int(results.X[-1].shape[1]) if results.X else 0
    n_u = int(results.U[-1].shape[1]) if results.U else 0
    if n_x <= 0:
        return

    # Resolve the slice against the state dimension. This makes ``None``
    # start/step work and clips indices that would fall outside the state
    # rows of a segment.
    idx = np.arange(*position_slice.indices(n_x))
    if idx.size != 3:
        return

    i4 = n_x + n_x * n_x + 2 * n_x * n_u
    for k, V_k in enumerate(results.discretization_history):
        V = np.array(V_k, dtype=np.float64)
        n_segments = V.shape[0] // i4
        # rows[s, j] is the row of position component j in segment s.
        rows = np.arange(n_segments)[:, None] * i4 + idx[None, :]
        # V[rows] has shape (n_segments, 3, ...). Move the component axis
        # last for the remap, then move it back.
        block = np.moveaxis(V[rows], 1, -1)
        V[rows] = np.moveaxis(model_vec_to_viser_xyz(block), -1, 1)
        results.discretization_history[k] = V


# Number of discretization nodes (passed as ``N`` to ``Problem`` below).
n = 5

# Runtime-updatable constants (see problem.parameters["name"] after building the problem)
# NOTE: the lookup key is the first ox.Parameter argument, e.g. ``l_arm`` is keyed "l".
# Angle parameters (del_max, theta_max, gamma) are in degrees; the constraints
# below convert them with ``* np.pi / 180.0``.
gI = ox.Parameter("gI", value=1.0)  # gravity magnitude; applied along -x in "velocity" dynamics
l_arm = ox.Parameter("l", value=0.25)  # thrust moment-arm length (see r_arm)
# Diagonal of the inertia matrix used in the rotational (Euler) dynamics.
J_diag = ox.Parameter(
    "J_diag",
    shape=(3,),
    value=np.array([0.168 * 2e-2, 0.168, 0.168]),
)
J_mat = ox.Diag(J_diag)
J_inv_mat = ox.Inv(ox.Diag(J_diag))
g0 = ox.Parameter("g0", value=1.0)  # with Isp, sets mass flow 1 / (Isp * g0) per unit thrust
Isp = ox.Parameter("Isp", value=30.0)
m_dry = ox.Parameter("m_dry", value=1.0)  # lower bound on mass (path constraint)
v_max = ox.Parameter("v_max", value=3.0)  # speed bound (box and norm constraints)
w_max = ox.Parameter("w_max", value=0.3752)  # angular-rate bound (box and norm constraints)
del_max = ox.Parameter("del_max", value=20.0)  # max thrust gimbal angle [deg]
theta_max = ox.Parameter("theta_max", value=75.0)  # max tilt angle [deg]
T_min = ox.Parameter("T_min", value=1.5)  # thrust magnitude lower bound
T_max = ox.Parameter("T_max", value=6.5)  # thrust magnitude upper bound
gamma = ox.Parameter("gamma", value=75.0)  # glide-slope angle [deg]
beta = ox.Parameter("beta", value=0.01)  # constant mass-depletion term in "mass" dynamics
c_ax = ox.Parameter("c_ax", value=0.5)  # axial aero coefficient (CA[0, 0])
c_ayz = ox.Parameter("c_ayz", value=1.0)  # lateral aero coefficient (CA[1, 1], CA[2, 2])
S_a = ox.Parameter("S_a", value=0.5)  # aero scaling factor, presumably reference area
rho = ox.Parameter("rho", value=1.0)  # density factor in the aerodynamic force
l_p = ox.Parameter("l_p", value=0.05)  # center-of-pressure offset along body x (see r_cp)
# Position equality constraint at node 0 (all three components).
initial_position = ox.Parameter("initial_position", shape=(3,), value=np.array([7.5, 4.5, 2.5]))
# Position equality constraint at the last node on components 1:3 only.
final_position = ox.Parameter("final_position", shape=(2,), value=np.array([0.0, 0.0]))

# Concat (not Stack): JAX lowering stacks scalars as (3,1); Diag needs a length-3 vector (3,).
CA = ox.Diag(ox.Concat(c_ax, c_ayz, c_ayz))
# Lever arms used for torques: thrust acts at -l_arm along body x, aero at +l_p.
r_arm = ox.Concat(-l_arm, 0.0, 0.0)
r_cp = ox.Concat(l_p, 0.0, 0.0)

# Mass: starts full at 2.0; the final value is wrapped in ox.Maximize
# (1.5 presumably serves as its guess).
mass = ox.State("mass", shape=(1,))
mass.max = [2.0]
mass.min = [1.0]
mass.initial = [2.0]
mass.final = [ox.Maximize(1.5)]

# Position: box of half-width 10. The start point is Free here, with values
# seeded from initial_position; only the x component is fixed at the end.
position = ox.State("position", shape=(3,))
position.max = [10.0] * 3
position.min = [-10.0] * 3
position.initial = [ox.Free(float(coord)) for coord in initial_position.value]
position.final = [0.0, ox.Free(0.0), ox.Free(0.0)]

# Velocity: per-component box from the v_max parameter's current value.
velocity = ox.State("velocity", shape=(3,))
velocity.max = [v_max.value] * 3
velocity.min = [-v_max.value] * 3
velocity.initial = [-0.5, -2.8, 0.0]
velocity.final = [-0.1, 0.0, 0.0]

# Attitude quaternion (scalar last): free start seeded at identity, identity at the end.
attitude = ox.State("attitude", shape=(4,))
attitude.max = [1.0] * 4
attitude.min = [-1.0] * 4
attitude.initial = [ox.Free(component) for component in (0.0, 0.0, 0.0, 1.0)]
attitude.final = [0.0, 0.0, 0.0, 1.0]

# Body angular velocity: per-component box from the w_max parameter's current value.
angular_velocity = ox.State("angular_velocity", shape=(3,))
angular_velocity.max = [w_max.value] * 3
angular_velocity.min = [-w_max.value] * 3
angular_velocity.initial = [1e-8, 0.0, 0.0]
angular_velocity.final = [1e-8, 0.0, 0.0]

# Body-frame thrust control with a per-component box from T_max.
thrust = ox.Control("thrust", shape=(3,))
thrust.max = [T_max.value] * 3
thrust.min = [-T_max.value] * 3
# Guess: purely axial thrust ramping linearly from gI * initial mass to
# gI * dry mass across the n nodes.
thrust.guess = np.linspace(
    np.array([gI.value * mass.initial[0], 0, 0]),
    np.array([gI.value * m_dry.value, 0, 0]),
    num=n,
).reshape(-1, 3)

# Extract quaternion components (scalar-last: q4 is the scalar part, consistent
# with the identity attitude [0, 0, 0, 1] used for attitude.final)
q1 = attitude[0]
q2 = attitude[1]
q3 = attitude[2]
q4 = attitude[3]

# Direction cosine matrix (DCM) from quaternion
# The Block is the scalar-last quaternion rotation matrix (body -> inertial).
# After the transpose, CBI maps inertial -> body: the aero force uses
# ``CBI @ velocity``, and the dynamics use ``CBI.T`` to rotate body-frame forces
# back to the inertial frame.
CBI = ox.Block(
    [
        [q4**2 + q1**2 - q2**2 - q3**2, 2 * (q1 * q2 - q4 * q3), 2 * (q4 * q2 + q1 * q3)],
        [2 * (q4 * q3 + q1 * q2), q4**2 - q1**2 + q2**2 - q3**2, 2 * (q2 * q3 - q4 * q1)],
        [2 * (q1 * q3 - q4 * q2), 2 * (q4 * q1 + q2 * q3), q4**2 - q1**2 - q2**2 + q3**2],
    ]
).T  # Transpose to get inertial to body frame


def cross(a, b):
    """Return the symbolic 3-D cross product ``a x b`` as an ``ox.Concat`` of its components."""
    ax, ay, az = a[0], a[1], a[2]
    bx, by, bz = b[0], b[1], b[2]
    return ox.Concat(ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx)


# Body angular-rate components
w1 = angular_velocity[0]
w2 = angular_velocity[1]
w3 = angular_velocity[2]

# Quaternion kinematics written per component (scalar-last quaternion):
# vector part 0.5 * (q4 * w + q_vec x w), scalar part -0.5 * (q_vec . w).
q1_dot = 0.5 * (w1 * q4 - w2 * q3 + w3 * q2)
q2_dot = 0.5 * (w1 * q3 - w3 * q1 + w2 * q4)
q3_dot = 0.5 * (w2 * q1 - w1 * q2 + w3 * q4)
q4_dot = -0.5 * (w1 * q1 + w2 * q2 + w3 * q3)

attitude_dot = ox.Concat(q1_dot, q2_dot, q3_dot, q4_dot)

# Aerodynamic force in body frame (computed symbolically)
# -0.5 * rho * |v| * S_a * CA @ (CBI @ v): opposes the body-frame velocity, with
# per-axis coefficients from the diagonal matrix CA.
A = -0.5 * rho * ox.linalg.Norm(velocity) * S_a * CA @ CBI @ velocity

dynamics = {
    # Mass flow proportional to thrust magnitude, plus a constant term beta.
    "mass": -(1 / (Isp * g0)) * ox.linalg.Norm(thrust) - beta,
    "position": velocity,
    # Body-frame thrust + aero rotated to inertial, divided by mass, plus gravity along -x.
    "velocity": CBI.T @ (thrust + A) / mass[0] + ox.Concat(-gI, 0.0, 0.0),
    "attitude": attitude_dot,
    # Euler's rotational equation: J^-1 (thrust torque + aero torque - w x (J w)).
    "angular_velocity": J_inv_mat
    @ (cross(r_arm, thrust) + cross(r_cp, A) - cross(angular_velocity, (J_mat @ angular_velocity))),
}


states = [mass, position, velocity, attitude, angular_velocity]
controls = [thrust]

# Box bounds on every state, enforced continuously via CTCS.
constraint_exprs = []
for state in states:
    constraint_exprs.extend([ox.ctcs(state <= state.max), ox.ctcs(state.min <= state)])

# Boundary Constraints
# Initial position constraint
constraint_exprs.append((position == initial_position).convex().at([0]))

# Terminal position constraint
constraint_exprs.append((position[1:3] == final_position).convex().at([n - 1]))


# Path constraints, each written as ``scale * (g(x)) <= 0`` (or ``>= 0``).
# The leading factor only rescales the constraint function, so it must
# multiply the WHOLE expression, as in the dry-mass constraint.
# Without the parentheses, precedence applied it to the first term only. That
# loosened the bounds: speed became sqrt(10) * v_max, the gimbal cone ~85 deg
# instead of del_max, and the thrust band [sqrt(0.1) * T_min, sqrt(10) * T_max].
# The speed and max-thrust limits then sat outside the box bounds and never bound.

# Dry-mass floor.
constraint_exprs.append(ox.ctcs(1.0 * (mass - m_dry) >= 0))
# Glide slope: lateral distance ||r[1:]|| at most tan(gamma) * r[0] (x is "up").
constraint_exprs.append(
    ox.ctcs(
        0.1 * (ox.linalg.Norm(position[1:]) - ox.Tan(gamma * np.pi / 180.0) * position[0])
        <= 0
    )
)
# Speed limit ||v|| <= v_max.
constraint_exprs.append(ox.ctcs(0.1 * (ox.linalg.Norm(velocity) ** 2 - v_max**2) <= 0))
# Tilt limit: body-x / inertial-x angle at most theta_max (cos(theta_max) <= 1 - 2 (q2^2 + q3^2)).
constraint_exprs.append(
    ox.ctcs(1.0 * (ox.Cos(theta_max * np.pi / 180.0) - 1.0 + 2.0 * (q2**2 + q3**2)) <= 0)
)
# Angular-rate limit ||w|| <= w_max.
constraint_exprs.append(
    ox.ctcs(1.0 * (ox.linalg.Norm(angular_velocity) ** 2 - w_max**2) <= 0)
)
# Gimbal cone: ||T|| cos(del_max) <= T_x.
constraint_exprs.append(
    ox.ctcs(
        0.1 * (ox.linalg.Norm(thrust) - thrust[0] / ox.Cos(del_max * np.pi / 180.0)) <= 0
    )
)
# Thrust magnitude band T_min <= ||T|| <= T_max.
constraint_exprs.append(ox.ctcs(0.1 * (ox.linalg.Norm(thrust) ** 2 - T_max**2) <= 0))
constraint_exprs.append(ox.ctcs(0.1 * (T_min**2 - ox.linalg.Norm(thrust) ** 2) <= 0))

# Nominal final time (must match free-final guess). Old API used
# time_dilation_factor_min/max as factors times this value for absolute bounds.
# NOTE(review): ``max=10.0`` equals ``t_final_guess``, while
# ``time_dilation_max`` allows 2x that value. Confirm whether the time bound
# should track ``time_dilation_max`` (e.g. ``2.0 * t_final_guess``).
# NOTE: the module-level name ``time`` shadows the stdlib module; this file
# imports the stdlib module as ``pytime`` instead.
t_final_guess = 10.0
time = ox.Time(
    initial=0.0,
    final=ox.Free(t_final_guess),
    min=0.0,
    max=10.0,
    time_dilation_min=0.2 * t_final_guess,
    time_dilation_max=2.0 * t_final_guess,
)

# Assemble the OpenSCvx problem. The "algorithm" entries are SCP weights and
# tolerances (names suggest cost / virtual-control / proximal weights and
# trust-region / virtual-control tolerances; confirm against OpenSCvx docs).
problem = Problem(
    N=n,
    states=states,
    controls=controls,
    dynamics=dynamics,
    constraints=constraint_exprs,
    time=time,
    float_dtype="float64",
    algorithm={
        "autotuner": ox.ConstantProximalWeight(),
        "lam_cost": 1e-2,
        "lam_vc": 1e1,
        "lam_prox": 1e0,
        "ep_tr": 5e-3,
        "ep_vc": 1e-6,
    },
)

# Tolerances forwarded to the solver (presumably the dynamics integrator).
problem.solver.solver_args = {"abstol": 1e-7, "reltol": 1e-7}

# Suppress the per-iteration console table; the latest SCP status is instead
# recorded on ``problem`` for the Viser plotting helpers.
install_pdg_scp_console_style(problem)

# Alias for receding-horizon scripts that follow the 3DoF ``total_time`` naming.
total_time = t_final_guess

# Extra entries merged into the post-processed result in ``__main__`` (empty here).
plotting_dict: dict = {}


if __name__ == "__main__":
    # Solve once, post-process, convert to Viser axes, then serve two viewers.
    initialize_problem_quiet(problem)
    problem.solve()  # only the post-processed results below are plotted
    result = problem.post_process()
    result.update(plotting_dict)
    remap_optimization_results_for_viser_xyz(result)

    # PDG trajectory viewer and SCP-iteration animation viewer.
    traj_server = create_animated_plotting_server(result, thrust_key="thrust")
    scp_server = create_scp_animated_plotting_server(result, frame_duration_ms=50.0)

    # Block forever so the process (and both servers) keep running.
    traj_server.sleep_forever()