LET Transfer
Low-Energy Transfer (LET) setup in the Sun-Earth CR3BP with impulsive burns at departure and arrival.
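Modeling choices:
- Sun-Earth CR3BP rotating-frame dynamics (x shifted so Earth is at x=0)
- Impulsive delta-v at departure (first node) and at the final node (arrival burn)
- Fixed initial state; final position and velocity are free but constrained so the spacecraft ends on (approximately) the Moon's orbital radius, tangent to that circle, with the Moon's orbital speed
- Free final time on a non-uniform time grid (`uniform_time_grid=False`)
- Objective: minimize total impulsive delta-v magnitude (implemented as maximizing the remaining fuel state)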
File: examples/spacecraft/let_transfer.py
import os
import shutil
import sys
import time as pytime
import urllib.request
from pathlib import Path
import jax
import jax.numpy as jnp
import numpy as np
import plotly.graph_objects as go
# Make the repository root (two directory levels above this script) importable so
# the example runs against a source checkout without installing openscvx.
current_dir = os.path.dirname(os.path.abspath(__file__))
grandparent_dir = str(Path(current_dir).parent.parent)
sys.path.append(grandparent_dir)
import openscvx as ox
from openscvx import Problem
from openscvx.algorithms import OptimizationResults
from openscvx.integrators import solve_ivp_diffrax
from openscvx.plotting import plot_projections_2d, plot_states
from openscvx.symbolic.lower import lower_to_jax
from openscvx.symbolic.lowerers.jax.logic import set_default_float_dtype
# Floating-point precision shared by the Problem and the JAX lowerer. It must be
# applied before any lowering/integration runs.
LET_FLOAT_DTYPE = "float64"
# Epoch at which SPICE geometry (Earth-Sun and Earth-Moon distances) is sampled.
REFERENCE_DATE = "26 December 2025"

# --- viser animation settings -------------------------------------------------
ENABLE_VISER_ANIMATION = True  # rotating-frame viewer
ENABLE_VISER_INERTIAL_ANIMATION = True  # Sun-centered inertial viewer
VISER_VISUAL_SCALE = 250.0  # scene units the largest trajectory extent maps to
# Playback resampling: the resampler divides this "speed" by the FPS to get the
# desired time step in days, and caps the dense point count.
VISER_MIN_SPEED_FOR_SAMPLING = 0.01
VISER_TARGET_FPS = 60.0
VISER_MAX_RESAMPLED_POINTS = 120_000
VISER_ROTATING_PORT, VISER_INERTIAL_PORT = 8080, 8081
VISER_REQUEST_SHARE_URLS = True
# Local cache directory for SPICE kernels, created next to this script on demand.
KERNEL_DIR = Path(current_dir) / "ker"
# Kernel file name -> NAIF generic-kernels download URL.
KERNEL_URLS = dict(
    [
        ("naif0012.tls", "https://naif.jpl.nasa.gov/pub/naif/generic_kernels/lsk/naif0012.tls"),
        ("de440.bsp", "https://naif.jpl.nasa.gov/pub/naif/generic_kernels/spk/planets/de440.bsp"),
        ("pck00011.tpc", "https://naif.jpl.nasa.gov/pub/naif/generic_kernels/pck/pck00011.tpc"),
        ("gm_de440.tpc", "https://naif.jpl.nasa.gov/pub/naif/generic_kernels/pck/gm_de440.tpc"),
    ]
)
# Load order used by the SPICE loader (dict insertion order).
KERNEL_FILENAMES = tuple(KERNEL_URLS)
def _download_kernel(url: str, destination: Path) -> None:
    """Download one SPICE kernel from ``url`` into ``destination`` atomically.

    The payload is streamed into a sibling ``<name>.part`` file first. That file is
    only moved onto ``destination`` (``Path.replace``) after the copy completes, so
    an interrupted download never leaves a truncated file under the final name.
    This function does not remove the ``.part`` file if the copy fails.
    """
    partial_path = destination.with_suffix(f"{destination.suffix}.part")
    with urllib.request.urlopen(url, timeout=120) as response:
        with partial_path.open("wb") as sink:
            shutil.copyfileobj(response, sink)
    partial_path.replace(destination)
def _ensure_spice_kernels(kernel_dir: Path) -> None:
    """Make sure every kernel in ``KERNEL_FILENAMES`` exists in ``kernel_dir``.

    Missing files are downloaded from ``KERNEL_URLS``. A failed download has its
    ``.part`` temp file removed, and the error is collected rather than raised
    immediately, so every missing kernel is attempted.

    Raises:
        RuntimeError: if one or more downloads failed; the message lists each one.
    """
    kernel_dir.mkdir(parents=True, exist_ok=True)
    failures = []
    for name in KERNEL_FILENAMES:
        target = kernel_dir / name
        if target.is_file():
            continue  # already cached
        try:
            _download_kernel(KERNEL_URLS[name], target)
        except Exception as exc:
            partial = target.with_suffix(target.suffix + ".part")
            if partial.exists():
                partial.unlink()
            failures.append(f"{name}: {exc}")
    if failures:
        raise RuntimeError("Failed to download SPICE kernels: " + "; ".join(failures))
def _load_spice_problem_data(reference_date: str) -> dict:
    """Return gravitational constants and characteristic distances from SPICE.

    Ensures the kernels are cached in ``KERNEL_DIR``, clears the global SPICE
    kernel pool, loads ``KERNEL_FILENAMES``, and evaluates geometry at
    ``reference_date``.

    Returns:
        dict with ``mu_earth`` / ``mu_sun`` (GM values), ``r_earth`` (first entry
        of Earth's RADII), ``d_earth_sun`` / ``d_earth_moon`` (distances between
        SSB-relative positions in ECLIPJ2000 with no aberration correction),
        ``kernel_dir`` and ``reference_date``.
    """
    import spiceypy as spice

    _ensure_spice_kernels(KERNEL_DIR)
    spice.kclear()  # start from an empty kernel pool
    for name in KERNEL_FILENAMES:
        spice.furnsh(str(KERNEL_DIR / name))

    epoch = spice.str2et(reference_date)

    def _first_body_value(body: str, item: str, max_n: int) -> float:
        return spice.bodvrd(body, item, max_n)[1][0]

    def _ssb_position(body: str):
        return spice.spkezr(body, epoch, "ECLIPJ2000", "NONE", "SSB")[0][:3]

    gm_earth = _first_body_value("Earth", "GM", 1)
    gm_sun = _first_body_value("Sun", "GM", 1)
    earth_radius = _first_body_value("Earth", "RADII", 3)
    earth_xyz = _ssb_position("Earth")
    sun_xyz = _ssb_position("Sun")
    moon_xyz = _ssb_position("Moon")

    return dict(
        mu_earth=float(gm_earth),
        mu_sun=float(gm_sun),
        r_earth=float(earth_radius),
        d_earth_sun=float(np.linalg.norm(earth_xyz - sun_xyz)),
        d_earth_moon=float(np.linalg.norm(earth_xyz - moon_xyz)),
        kernel_dir=str(KERNEL_DIR),
        reference_date=reference_date,
    )
def _normalized_node_grid(n: int, mode: str) -> np.ndarray:
    """Return ``n`` node locations in [0, 1] for the given distribution ``mode``.

    ``mode`` is matched case-insensitively after stripping whitespace:
    ``"uniform"`` gives evenly spaced points; ``"cosine"`` maps them through
    ``0.5 * (1 - cos(pi * s))``, which clusters nodes near both endpoints.

    Raises:
        ValueError: for any other mode.
    """
    base = np.linspace(0.0, 1.0, n)
    transforms = {
        "uniform": lambda s: s,
        "cosine": lambda s: 0.5 * (1.0 - np.cos(np.pi * s)),
    }
    transform = transforms.get(mode.strip().lower())
    if transform is None:
        raise ValueError(
            f"Unknown NODE_DISTRIBUTION_MODE={mode!r}. Expected 'uniform' or 'cosine'."
        )
    return transform(base)
def _configure_jax_float_dtype(
    float_dtype: str,
    *,
    update_jax_enable_x64: bool = False,
) -> None:
    """Propagate ``float_dtype`` to the symbolic lowerer, optionally toggling JAX x64.

    When ``update_jax_enable_x64`` is true, the global ``jax_enable_x64`` flag is
    set on for "float64"/"f64"/"double" (case-insensitive) and off otherwise.
    The lowerer default dtype is always updated.
    """
    normalized = float_dtype.lower()
    if update_jax_enable_x64:
        jax.config.update("jax_enable_x64", normalized in {"float64", "f64", "double"})
    set_default_float_dtype(float_dtype)
def _add_moon_orbit_overlay(fig, earth_pos: np.ndarray, moon_radius: float) -> None:
    """Overlay the Moon's circular orbit on the XY, XZ and YZ projection subplots.

    The orbit is a 361-point circle of radius ``moon_radius`` centered on
    ``earth_pos`` in the XY plane, with z fixed to 0. It is added to subplots
    (1, 1), (1, 2) and (2, 1). Only the first trace shows a legend entry; all
    three share the ``"moon_orbit"`` legend group.
    """
    angles = np.linspace(0.0, 2.0 * np.pi, 361)
    orbit_x = earth_pos[0] + moon_radius * np.cos(angles)
    orbit_y = earth_pos[1] + moon_radius * np.sin(angles)
    orbit_z = np.zeros_like(angles)
    line_style = {"color": "rgba(255, 255, 255, 0.55)", "width": 1.5, "dash": "dash"}

    # (horizontal, vertical, row, col) per projection panel.
    panels = (
        (orbit_x, orbit_y, 1, 1),  # XY plane
        (orbit_x, orbit_z, 1, 2),  # XZ plane (orbit lies in z=0)
        (orbit_y, orbit_z, 2, 1),  # YZ plane (orbit lies in z=0)
    )
    for panel_idx, (horizontal, vertical, row, col) in enumerate(panels):
        fig.add_trace(
            go.Scatter(
                x=horizontal,
                y=vertical,
                mode="lines",
                line=line_style,
                name="Moon orbit",
                legendgroup="moon_orbit",
                showlegend=panel_idx == 0,
            ),
            row=row,
            col=col,
        )
def _set_projection_axis_labels_km(fig) -> None:
    """Label the XY, XZ and YZ projection subplots' axes in kilometres."""
    axis_titles = (
        (1, 1, "X (km)", "Y (km)"),
        (1, 2, "X (km)", "Z (km)"),
        (2, 1, "Y (km)", "Z (km)"),
    )
    for row, col, x_title, y_title in axis_titles:
        fig.update_xaxes(title_text=x_title, row=row, col=col)
        fig.update_yaxes(title_text=y_title, row=row, col=col)
def _set_projection_speed_colorbar_kms(fig) -> None:
    """Retitle every trace colorbar in ``fig`` to show speed in km/s.

    Traces without a ``marker`` attribute, or whose marker has no colorbar, are
    left untouched.
    """
    for trace in fig.data:
        marker = getattr(trace, "marker", None)
        if getattr(marker, "colorbar", None) is None:
            continue
        marker.colorbar.title = "‖velocity‖ (km/s)"
def _create_viser_server_compat(ox_viser, pos: np.ndarray, show_grid: bool, port: int):
    """Create a viser server across both create_server() API variants.

    First tries ``create_server(..., port=port)``. If that call raises a
    ``TypeError`` about an unexpected ``port`` keyword, it retries without the
    keyword while temporarily setting the ``_VISER_PORT_OVERRIDE`` environment
    variable to ``port``. The variable's previous value (or absence) is restored
    afterwards. Any other ``TypeError`` is re-raised.
    """
    try:
        return ox_viser.create_server(pos, show_grid=show_grid, port=port)
    except TypeError as err:
        if "unexpected keyword argument 'port'" not in str(err):
            raise

    env_key = "_VISER_PORT_OVERRIDE"
    saved_value = os.environ.get(env_key)
    os.environ[env_key] = str(port)
    try:
        return ox_viser.create_server(pos, show_grid=show_grid)
    finally:
        if saved_value is not None:
            os.environ[env_key] = saved_value
        else:
            os.environ.pop(env_key, None)
def _server_local_url(server, fallback_port: int) -> str:
    """Return an ``http://host:port`` URL for a viser server handle.

    The wildcard bind address ``0.0.0.0`` is shown as ``localhost``. If the
    handle cannot report its host/port (including ``server is None``), the URL
    falls back to ``localhost`` on ``fallback_port``.
    """
    try:
        host, port = str(server.get_host()), int(server.get_port())
    except Exception:
        return f"http://localhost:{fallback_port}"
    display_host = "localhost" if host == "0.0.0.0" else host
    return f"http://{display_host}:{port}"
def _create_let_viser_server(
    trajectory: np.ndarray,
    traj_time_days: np.ndarray,
    earth_pos: np.ndarray,
    sun_pos: np.ndarray,
    moon_radius: float,
    moon_rate_rad_per_day: float,
    guess_trajectory: np.ndarray | None = None,
    port: int = VISER_ROTATING_PORT,
):
    """Create a viser server for LET trajectory playback in the rotating frame.

    Args:
        trajectory: State history; columns 0:3 are positions and 3:6 velocities
            (same normalized units as ``earth_pos``/``sun_pos``/``moon_radius``).
        traj_time_days: Sample times in days, one per trajectory row.
        earth_pos: Earth position in the plotted frame.
        sun_pos: Sun position; it is pulled inward if it would lie farther than
            1.8 * ``VISER_VISUAL_SCALE`` from the origin once rescaled.
        moon_radius: Radius of the Moon's circular orbit about Earth.
        moon_rate_rad_per_day: Angular rate at which the drawn Moon moves.
        guess_trajectory: Optional initial-guess states, drawn as a faint ghost.
        port: Requested server port.

    Returns:
        The viser server handle, or ``None`` if anything fails (the error is
        printed, not raised).
    """
    try:
        from openscvx.plotting import viser as ox_viser

        pos = np.asarray(trajectory[:, :3], dtype=np.float64)
        vel = np.asarray(trajectory[:, 3:6], dtype=np.float64)
        # Autoscale scene so tiny normalized CR3BP coordinates remain visible.
        scale = max(float(moon_radius), float(np.linalg.norm(pos, axis=1).max()), 1e-9)
        pos_vis = (pos / scale) * VISER_VISUAL_SCALE
        earth_vis = (np.asarray(earth_pos, dtype=np.float64) / scale) * VISER_VISUAL_SCALE
        moon_radius_vis = (float(moon_radius) / scale) * VISER_VISUAL_SCALE
        colors = ox_viser.compute_velocity_colors(vel, fallback_length=pos.shape[0])
        server = _create_viser_server_compat(
            ox_viser=ox_viser, pos=pos_vis, show_grid=True, port=port
        )
        ox_viser.add_circular_orbit(
            server,
            radius=moon_radius_vis,
            name="moon_orbit",
            center=earth_vis,
            color=(135, 135, 135),
            line_width=1.6,
        )
        # Body sizes are cosmetic: proportional to the drawn Moon orbit, with a floor.
        earth_radius = max(0.03 * moon_radius_vis, 0.15)
        spacecraft_radius = 0.6 * earth_radius
        server.scene.add_icosphere(
            "/bodies/earth",
            radius=earth_radius,
            color=(80, 160, 255),
            position=earth_vis,
        )
        # The true Sun position is far outside the scene; clamp its distance so it
        # stays within 1.8 * VISER_VISUAL_SCALE while keeping its direction.
        sun_vis_real = (np.asarray(sun_pos, dtype=np.float64) / scale) * VISER_VISUAL_SCALE
        sun_vis_norm = float(np.linalg.norm(sun_vis_real))
        if sun_vis_norm > 1.8 * VISER_VISUAL_SCALE:
            sun_vis = sun_vis_real * ((1.8 * VISER_VISUAL_SCALE) / sun_vis_norm)
        else:
            sun_vis = sun_vis_real
        server.scene.add_icosphere(
            "/bodies/sun",
            radius=1.15 * earth_radius,
            color=(255, 210, 70),
            position=sun_vis,
        )
        # Phase the Moon so that at final time it rendezvous with the terminal trajectory point.
        traj_time_days = np.asarray(traj_time_days, dtype=np.float64).flatten()
        rel_final = pos_vis[-1] - earth_vis
        rel_final_xy = np.array([rel_final[0], rel_final[1]], dtype=np.float64)
        if np.linalg.norm(rel_final_xy) > 1e-12:
            theta_final = float(np.arctan2(rel_final_xy[1], rel_final_xy[0]))
        else:
            # Degenerate terminal point (on the z axis): pick -y as the arrival angle.
            theta_final = -0.5 * np.pi
        # Back-propagate the Moon's angle at a constant rate from the final epoch.
        theta_0 = theta_final - moon_rate_rad_per_day * float(traj_time_days[-1])
        theta = theta_0 + moon_rate_rad_per_day * traj_time_days
        moon_positions = earth_vis.reshape(1, 3) + moon_radius_vis * np.column_stack(
            [np.cos(theta), np.sin(theta), np.zeros_like(theta)]
        )
        moon_handle = server.scene.add_icosphere(
            "/bodies/moon",
            radius=0.6 * earth_radius,
            color=(220, 220, 220),
            position=moon_positions[0],
        )
        if guess_trajectory is not None:
            # Initial guess: uniform purple, very faint, same scaling as the solution.
            guess_pos = np.asarray(guess_trajectory[:, :3], dtype=np.float64)
            guess_pos_vis = (guess_pos / scale) * VISER_VISUAL_SCALE
            guess_colors = np.broadcast_to(
                np.array([190, 150, 255], dtype=np.uint8), (guess_pos_vis.shape[0], 3)
            ).copy()
            ox_viser.add_ghost_trajectory(
                server, guess_pos_vis, guess_colors, opacity=0.08, point_size=0.20
            )
        # Full solution path as a static ghost, plus an animated trail and marker.
        ox_viser.add_ghost_trajectory(server, pos_vis, colors, opacity=0.25, point_size=0.25)
        _, update_trail = ox_viser.add_animated_trail(server, pos_vis, colors, point_size=0.45)
        _, update_marker = ox_viser.add_position_marker(
            server, pos_vis, radius=spacecraft_radius, color=(255, 160, 90)
        )

        def update_moon(frame_idx: int) -> None:
            # Per-frame callback: move the Moon sphere to its precomputed position.
            moon_handle.position = moon_positions[frame_idx]

        ox_viser.add_animation_controls(
            server,
            np.asarray(traj_time_days, dtype=np.float64),
            [update_trail, update_marker, update_moon],
            loop=True,
            folder_name="LET Animation",
        )
        return server
    except Exception as exc:
        # Visualization is optional: report and continue without a server.
        print(f"Viser animation unavailable: {exc}")
        return None
def _rotate_about_z(vectors: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """Rotate each row of an (N, 3) array about +Z by the matching angle in ``theta``.

    ``theta`` (radians) is flattened and broadcast against the rows, so a
    single angle applies to every vector. The z component is copied unchanged.
    """
    vec = np.asarray(vectors, dtype=np.float64)
    angles = np.asarray(theta, dtype=np.float64).flatten()
    cos_a, sin_a = np.cos(angles), np.sin(angles)
    x_in, y_in = vec[:, 0], vec[:, 1]
    rotated = np.empty_like(vec)
    rotated[:, 0] = cos_a * x_in - sin_a * y_in
    rotated[:, 1] = sin_a * x_in + cos_a * y_in
    rotated[:, 2] = vec[:, 2]
    return rotated
def _hohmann_transfer_metrics(
    mu_central_km3_s2: float,
    r1_km: float,
    r2_km: float,
) -> dict:
    """Compute the two burns of a Hohmann transfer between circular orbits.

    Args:
        mu_central_km3_s2: Gravitational parameter of the central body.
        r1_km: Radius of the departure circular orbit.
        r2_km: Radius of the arrival circular orbit.

    Returns:
        dict with ``dv1_km_s`` (departure burn), ``dv2_km_s`` (arrival burn) and
        ``total_dv_km_s`` (their sum), all as plain floats.

    Raises:
        ValueError: if either radius is non-positive.
    """
    if any(radius <= 0.0 for radius in (r1_km, r2_km)):
        raise ValueError(f"Invalid Hohmann radii: r1={r1_km}, r2={r2_km}.")
    mu = mu_central_km3_s2
    semi_major = 0.5 * (r1_km + r2_km)

    def transfer_speed(radius: float) -> float:
        # Vis-viva speed on the transfer ellipse at the given radius.
        return np.sqrt(mu * (2.0 / radius - 1.0 / semi_major))

    departure_burn = abs(transfer_speed(r1_km) - np.sqrt(mu / r1_km))
    arrival_burn = abs(np.sqrt(mu / r2_km) - transfer_speed(r2_km))
    return {
        "dv1_km_s": float(departure_burn),
        "dv2_km_s": float(arrival_burn),
        "total_dv_km_s": float(departure_burn + arrival_burn),
    }
def _create_let_viser_server_inertial(
    trajectory: np.ndarray,
    traj_time_days: np.ndarray,
    r_ref_km: float,
    d_earth_sun_km: float,
    d_earth_moon_km: float,
    moon_rate_rad_per_day: float,
    kappa_val: float,
    guess_trajectory: np.ndarray | None = None,
    guess_time_days: np.ndarray | None = None,
    port: int = VISER_INERTIAL_PORT,
):
    """Create a Sun-centered inertial-frame viser animation.

    Rotating-frame states (Earth-centered, normalized by ``r_ref_km``) are mapped
    to a Sun-centered frame. Earth moves on a circle of radius
    ``d_earth_sun_km / r_ref_km`` at angle ``kappa_val * tau``, and the
    spacecraft's Earth-relative position is rotated by the same angle. The Moon
    is drawn on a circle about Earth at ``moon_rate_rad_per_day``, phased to
    coincide with the spacecraft at the final time.

    NOTE(review): ``tau`` is computed from the module-level globals ``d_2_sec``
    and ``t_ref`` rather than from parameters, so this function only works
    after the module-level problem constants have been defined.

    Returns:
        The viser server handle, or ``None`` if anything fails (the error is
        printed, not raised).
    """
    try:
        from openscvx.plotting import viser as ox_viser

        traj_time_days = np.asarray(traj_time_days, dtype=np.float64).flatten()
        pos_rot = np.asarray(trajectory[:, :3], dtype=np.float64)
        vel_rot = np.asarray(trajectory[:, 3:6], dtype=np.float64)
        # Days -> normalized time -> rotation angle of the Sun-Earth frame.
        tau = traj_time_days * d_2_sec / t_ref
        theta = kappa_val * tau
        rho_es_local = d_earth_sun_km / r_ref_km
        rho_em_local = d_earth_moon_km / r_ref_km
        earth_pos = np.column_stack(
            [rho_es_local * np.cos(theta), rho_es_local * np.sin(theta), np.zeros_like(theta)]
        )
        sat_rel_inertial = _rotate_about_z(pos_rot, theta)
        sat_pos = earth_pos + sat_rel_inertial
        # Phase the Moon so it sits at the spacecraft's Earth-relative angle at the end.
        rel_final = sat_pos[-1] - earth_pos[-1]
        rel_final_xy = np.array([rel_final[0], rel_final[1]], dtype=np.float64)
        if np.linalg.norm(rel_final_xy) > 1e-12:
            phi_final = float(np.arctan2(rel_final_xy[1], rel_final_xy[0]))
        else:
            phi_final = 0.0
        phi_0 = phi_final - moon_rate_rad_per_day * float(traj_time_days[-1])
        phi = phi_0 + moon_rate_rad_per_day * traj_time_days
        moon_pos = earth_pos + rho_em_local * np.column_stack(
            [np.cos(phi), np.sin(phi), np.zeros_like(phi)]
        )
        # Uniform scene scaling so the farthest plotted point maps to VISER_VISUAL_SCALE.
        scene_max = max(
            float(np.linalg.norm(sat_pos, axis=1).max()),
            float(np.linalg.norm(earth_pos, axis=1).max()),
            float(np.linalg.norm(moon_pos, axis=1).max()),
            1e-9,
        )
        pos_vis = (sat_pos / scene_max) * VISER_VISUAL_SCALE
        earth_vis = (earth_pos / scene_max) * VISER_VISUAL_SCALE
        moon_vis = (moon_pos / scene_max) * VISER_VISUAL_SCALE
        sun_vis = np.zeros(3, dtype=np.float64)
        # Colors are derived from the rotating-frame velocities.
        colors = ox_viser.compute_velocity_colors(vel_rot, fallback_length=pos_vis.shape[0])
        server = _create_viser_server_compat(
            ox_viser=ox_viser, pos=pos_vis, show_grid=True, port=port
        )
        # Cosmetic (not to scale) body sizes.
        sun_radius = max(0.04 * VISER_VISUAL_SCALE, 0.2)
        earth_radius = max(0.0018 * VISER_VISUAL_SCALE, 0.035)
        moon_radius = 0.45 * earth_radius
        spacecraft_radius = 0.22 * earth_radius
        server.scene.add_icosphere(
            "/bodies/sun",
            radius=sun_radius,
            color=(255, 210, 70),
            position=sun_vis,
        )
        earth_handle = server.scene.add_icosphere(
            "/bodies/earth",
            radius=earth_radius,
            color=(80, 160, 255),
            position=earth_vis[0],
        )
        moon_handle = server.scene.add_icosphere(
            "/bodies/moon",
            radius=moon_radius,
            color=(220, 220, 220),
            position=moon_vis[0],
        )
        # Full Earth orbit circle (721 samples) as a faint ghost trail.
        earth_orbit_pos = np.column_stack(
            [
                rho_es_local * np.cos(np.linspace(0.0, 2.0 * np.pi, 721)),
                rho_es_local * np.sin(np.linspace(0.0, 2.0 * np.pi, 721)),
                np.zeros(721),
            ]
        )
        earth_orbit_vis = (earth_orbit_pos / scene_max) * VISER_VISUAL_SCALE
        earth_orbit_colors = np.broadcast_to(
            np.array([90, 140, 255], dtype=np.uint8), (earth_orbit_vis.shape[0], 3)
        ).copy()
        ox_viser.add_ghost_trajectory(
            server, earth_orbit_vis, earth_orbit_colors, opacity=0.14, point_size=0.20
        )
        # Moon path over the transfer window.
        moon_orbit_colors = np.broadcast_to(
            np.array([175, 175, 175], dtype=np.uint8), (moon_vis.shape[0], 3)
        ).copy()
        ox_viser.add_ghost_trajectory(
            server, moon_vis, moon_orbit_colors, opacity=0.04, point_size=0.10
        )
        if guess_trajectory is not None:
            guess_pos_rot = np.asarray(guess_trajectory[:, :3], dtype=np.float64)
            if guess_time_days is None:
                # Without explicit times, spread guess samples evenly over the solution horizon.
                guess_time_days = np.linspace(0.0, traj_time_days[-1], guess_pos_rot.shape[0])
            guess_tau = np.asarray(guess_time_days, dtype=np.float64).flatten() * d_2_sec / t_ref
            guess_theta = kappa_val * guess_tau
            guess_earth = np.column_stack(
                [
                    rho_es_local * np.cos(guess_theta),
                    rho_es_local * np.sin(guess_theta),
                    np.zeros_like(guess_theta),
                ]
            )
            guess_sat = guess_earth + _rotate_about_z(guess_pos_rot, guess_theta)
            guess_vis = (guess_sat / scene_max) * VISER_VISUAL_SCALE
            guess_colors = np.broadcast_to(
                np.array([180, 145, 250], dtype=np.uint8), (guess_vis.shape[0], 3)
            ).copy()
            ox_viser.add_ghost_trajectory(
                server, guess_vis, guess_colors, opacity=0.07, point_size=0.20
            )
        ox_viser.add_ghost_trajectory(server, pos_vis, colors, opacity=0.16, point_size=0.12)
        _, update_trail = ox_viser.add_animated_trail(server, pos_vis, colors, point_size=0.18)
        _, update_marker = ox_viser.add_position_marker(
            server, pos_vis, radius=spacecraft_radius, color=(255, 160, 90)
        )
        # Per-client camera state: the camera follows Earth while keeping the view
        # direction captured at initialization; the distance tracks the user's zoom.
        camera_view_dir: dict[int, np.ndarray] = {}
        camera_view_dist: dict[int, float] = {}

        def _initialize_camera_tracking(client, earth_target: np.ndarray) -> bool:
            """Record the client's view direction/distance and aim at Earth.

            Returns False (without recording) when the client has not reported
            a camera yet (``update_timestamp <= 0``) or on any error.
            """
            try:
                if float(client.camera.update_timestamp) <= 0.0:
                    return False
                rel = np.asarray(client.camera.position) - np.asarray(client.camera.look_at)
                rel_norm = float(np.linalg.norm(rel))
                if rel_norm < 1e-6:
                    # Degenerate camera (position == look_at): use a default offset.
                    rel = np.array(
                        [0.0, -0.20 * VISER_VISUAL_SCALE, 0.08 * VISER_VISUAL_SCALE],
                        dtype=np.float64,
                    )
                    rel_norm = float(np.linalg.norm(rel))
                camera_view_dir[client.client_id] = rel / rel_norm
                camera_view_dist[client.client_id] = rel_norm
                client.camera.position = earth_target + camera_view_dir[client.client_id] * rel_norm
                client.camera.look_at = earth_target
                return True
            except Exception:
                return False

        @server.on_client_connect
        def _on_client_connect(client) -> None:
            _initialize_camera_tracking(client, earth_vis[0])

        @server.on_client_disconnect
        def _on_client_disconnect(client) -> None:
            camera_view_dir.pop(client.client_id, None)
            camera_view_dist.pop(client.client_id, None)

        def update_earth(frame_idx: int) -> None:
            """Move Earth to its frame position and re-aim every tracked camera at it."""
            earth_handle.position = earth_vis[frame_idx]
            earth_target = earth_vis[frame_idx]
            for client_id, client in server.get_clients().items():
                try:
                    if client_id not in camera_view_dir:
                        # Late initialization already positions the camera this frame.
                        if not _initialize_camera_tracking(client, earth_target):
                            continue
                        continue
                    current_rel = np.asarray(client.camera.position) - np.asarray(
                        client.camera.look_at
                    )
                    current_dist = float(np.linalg.norm(current_rel))
                    if current_dist > 1e-6:
                        # Respect user zoom: keep the latest distance, not the direction.
                        camera_view_dist[client_id] = current_dist
                    client.camera.position = (
                        earth_target + camera_view_dir[client_id] * camera_view_dist[client_id]
                    )
                    client.camera.look_at = earth_target
                except Exception as exc:
                    print(
                        (
                            f"[LET visualization] Failed to update camera for client "
                            f"{client_id}: {exc}"
                        ),
                        file=sys.stderr,
                    )
                    continue

        def update_moon(frame_idx: int) -> None:
            moon_handle.position = moon_vis[frame_idx]

        ox_viser.add_animation_controls(
            server,
            traj_time_days,
            [update_trail, update_marker, update_earth, update_moon],
            loop=True,
            folder_name="LET Inertial (Sun-Centered)",
        )
        return server
    except Exception as exc:
        # Visualization is optional: report and continue without a server.
        print(f"Sun-centered inertial viser animation unavailable: {exc}")
        return None
def _resample_trajectory_for_viser(
    trajectory: np.ndarray,
    traj_time_days: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Densify a trajectory by linear interpolation for smoother viser playback.

    The target time step is ``VISER_MIN_SPEED_FOR_SAMPLING / VISER_TARGET_FPS``
    days. The resampled count never drops below the original sample count and
    is capped at ``VISER_MAX_RESAMPLED_POINTS``. When no densification is needed
    (fewer than two samples, or the target count does not exceed the original),
    the float64 inputs are returned unchanged.

    Returns:
        ``(states, times_days)``; states keep one column per original column.
    """
    times = np.asarray(traj_time_days, dtype=np.float64).flatten()
    states = np.asarray(trajectory, dtype=np.float64)
    n_original = states.shape[0]
    if times.size < 2 or n_original < 2:
        return states, times

    frame_dt_days = VISER_MIN_SPEED_FOR_SAMPLING / VISER_TARGET_FPS
    span_days = float(times[-1] - times[0])
    n_wanted = int(np.ceil(span_days / max(frame_dt_days, 1e-12))) + 1
    n_wanted = min(max(n_wanted, n_original), VISER_MAX_RESAMPLED_POINTS)
    if n_wanted <= n_original:
        return states, times

    dense_times = np.linspace(times[0], times[-1], n_wanted)
    dense_states = np.column_stack(
        [np.interp(dense_times, times, column) for column in states.T]
    )
    return dense_states, dense_times
# NOTE(review): this runs at import time and may download SPICE kernels, even
# though the problem object itself is created lazily further below.
spice_data = _load_spice_problem_data(REFERENCE_DATE)
# Unpack the SPICE-derived constants used throughout the module.
mu_earth, mu_sun, r_earth, d_earth_sun, d_earth_moon = (
    spice_data[key]
    for key in ("mu_earth", "mu_sun", "r_earth", "d_earth_sun", "d_earth_moon")
)
spice_source = f"SPICE ({spice_data['kernel_dir']})"
# Sun-Earth normalized CR3BP parameters
# Mass parameter: Earth's fraction of the combined Sun+Earth GM.
mu = mu_earth / (mu_earth + mu_sun)
d_2_sec = 86400.0  # seconds per day
# Two-body mean motion (rad/s) of the Sun-Earth pair at the SPICE separation.
n_system = np.sqrt((mu_earth + mu_sun) / d_earth_sun**3)
canonical_t_ref = 1 / n_system  # NOTE(review): not referenced elsewhere in this file.
# Default time unit: a 365-day year divided by 2*pi (seconds). Because this is not
# exactly 1 / n_system, kappa below is close to, but not exactly, 1.
legacy_t_ref = d_2_sec * 365.0 / (2.0 * np.pi)
# Optional nondimensional reference scales:
# - If omitted, defaults preserve the original scaling
#   (length = Earth-Sun distance, time = legacy_t_ref).
REF_LENGTH_KM = None
REF_TIME_S = None
ref_length_km = REF_LENGTH_KM
ref_time_s = REF_TIME_S
use_legacy_dynamics = ref_length_km is None and ref_time_s is None  # NOTE(review): unused here.
r_ref = float(d_earth_sun if ref_length_km is None else ref_length_km)
t_ref = float(legacy_t_ref if ref_time_s is None else ref_time_s)
if r_ref <= 0.0 or t_ref <= 0.0:
    raise ValueError(f"Invalid reference scales: L={r_ref}, T={t_ref}. Both must be positive.")
v_ref = r_ref / t_ref  # velocity unit
kappa = n_system * t_ref  # frame rotation rate in normalized time units
rho_es = d_earth_sun / r_ref  # Earth-Sun distance in length units (Sun sits at x = -rho_es)
# Converts GM values (as returned by SPICE) into normalized-acceleration units.
grav_scale = t_ref**2 / r_ref**3
# Mission setup
h_earth = 1500.0  # departure altitude above Earth's SPICE equatorial radius (km)
# Moon's circular two-body speed about Earth, normalized.
# NOTE(review): this is an inertial speed; the terminal constraint compares it to a
# rotating-frame velocity - confirm that is intended.
v_moon = np.sqrt(mu_earth / d_earth_moon) / v_ref
r_0 = r_earth + h_earth  # departure radius (km)
# Earth-centered rotating coordinates: Earth at x=0, Sun at x=-1.
pos_earth_rot = np.array([0.0, 0.0, 0.0])
# Optional in-plane rotation of the departure geometry (0 deg = start on +x).
rot_deg = 0.0
rot_rad = np.deg2rad(rot_deg)
rot_mat = np.array(
    [[np.cos(rot_rad), -np.sin(rot_rad), 0], [np.sin(rot_rad), np.cos(rot_rad), 0], [0, 0, 1]]
)
pos_0 = pos_earth_rot + rot_mat @ np.array([r_0 / r_ref, 0.0, 0.0])
# Departure speed along +y: 7.8 * sqrt(2) * 0.9085 (about 10.02, presumably km/s),
# then normalized. NOTE(review): the origin of these factors is not documented.
vel_0 = rot_mat @ np.array([0.0, 7.8 * np.sqrt(2.0) * 0.9085 / v_ref, 0.0])
x0_seed = np.concatenate([pos_0, vel_0])  # [x, y, z, vx, vy, vz] seed state
# Nominal terminal point on the Moon's orbit (-y side) and a prograde circular
# velocity there; both are used only as values of the ox.Free terminal guesses.
pos_f = pos_earth_rot + np.array([0.0, -d_earth_moon / r_ref, 0.0])
vel_f = np.array([np.sqrt(mu_earth / d_earth_moon) / v_ref, 0.0, 0.0])
t_f_guess_days = 78.0
t_f_guess = t_f_guess_days * d_2_sec / t_ref  # normalized final-time guess
# Initial impulse guess.
# Implied departure burn: seed velocity minus local circular velocity.
v_circular = np.sqrt(mu_earth / r_0) / v_ref
v_circular_vect = rot_mat @ np.array([0, v_circular, 0])
delta_v0_guess = vel_0 - v_circular_vect
n_nodes = 45
integration_tol = 1e-10  # rtol/atol for guess propagation and discretization
integration_max_steps = 3000  # NOTE(review): not referenced elsewhere in this file.
# Guess-node distribution toggle:
# - "uniform": evenly spaced nodes in [0, 1]
# - "cosine": denser near interval endpoints
NODE_DISTRIBUTION_MODE = "cosine"
def _is_float64_dtype(float_dtype: str) -> bool:
    """Return True when ``float_dtype`` names double precision (case-insensitive)."""
    double_aliases = {"float64", "f64", "double"}
    return float_dtype.lower() in double_aliases
def _build_let_problem_bundle(float_dtype: str = LET_FLOAT_DTYPE) -> dict:
    """Construct the LET problem and precomputed JAX artifacts for one dtype.

    Builds the symbolic Sun-Earth CR3BP dynamics (Earth-centered rotating frame),
    an impulsive delta-v control active at the first and last node, and a
    ballistic initial guess propagated from ``x0_seed``. It then sets boundary
    conditions, bounds/scaling and terminal Moon-orbit constraints, and creates
    the ``Problem`` object.

    Args:
        float_dtype: Precision passed to the lowerer and ``Problem``. "float64",
            "f64" or "double" (case-insensitive) propagate the guess in float64;
            anything else uses float32.

    Returns:
        dict with ``"problem"`` (configured ``Problem``), ``"cr3bp_rhs"`` (lowered
        JAX RHS of the 6-state ``[position, velocity]`` system), ``"nodal_guess"``
        (``n_nodes`` x 6 state guess, normalized) and ``"time_guess"``
        (``n_nodes`` x 1 node-time guess, normalized).
    """
    _configure_jax_float_dtype(float_dtype, update_jax_enable_x64=False)
    jax_float_dtype = jnp.float64 if _is_float64_dtype(float_dtype) else jnp.float32
    # Build symbolic CR3BP model once and reuse it for optimization and propagation.
    position = ox.State("position", shape=(3,))
    velocity = ox.State("velocity", shape=(3,))
    fuel = ox.State("fuel", shape=(1,))
    # Assign slices for standalone lowering/evaluation on [x, y, z, vx, vy, vz].
    position._slice = slice(0, 3)
    velocity._slice = slice(3, 6)
    x_e = position[0]
    y_e = position[1]
    z_e = position[2]
    # In this shifted frame: Earth is at x=0 and Sun is at x=-1.
    # For general reference distance L, Sun is at x = -d_earth_sun / L = -rho_es.
    sun_dx = x_e + rho_es
    earth_dx = x_e
    d_sun = ox.Sqrt(sun_dx**2 + y_e**2 + z_e**2)
    d_earth = ox.Sqrt(earth_dx**2 + y_e**2 + z_e**2)
    # Rotating-frame accelerations: Coriolis (2*kappa*v) + centrifugal about the
    # barycenter (located at x = -rho_es * (1 - mu) in this shifted frame) + gravity.
    ax = (
        2.0 * kappa * velocity[1]
        + kappa**2 * (x_e + rho_es * (1.0 - mu))
        - grav_scale * (mu_sun * sun_dx / d_sun**3 + mu_earth * earth_dx / d_earth**3)
    )
    ay = (
        -2.0 * kappa * velocity[0]
        + kappa**2 * y_e
        - grav_scale * (mu_sun * y_e / d_sun**3 + mu_earth * y_e / d_earth**3)
    )
    az = -grav_scale * (mu_sun * z_e / d_sun**3 + mu_earth * z_e / d_earth**3)
    velocity_dot = ox.Concat(ax, ay, az)
    dynamics = {
        "position": velocity,
        "velocity": velocity_dot,
        "fuel": 0.0,  # fuel only changes through the discrete impulse map
    }
    # Impulsive burns only at departure (node 0) and arrival (last node).
    delta_v = ox.Control(
        "delta_v",
        shape=(3,),
        parameterization="impulsive",
        nodes=[0, n_nodes - 1],
    )
    # Small component offset inside the norm (kept from the original formulation).
    eps_impulse = 1e-12
    dynamics_discrete = {
        "position": position,
        "velocity": velocity + delta_v,
        "fuel": fuel - ox.linalg.Norm(delta_v + eps_impulse),
    }
    cr3bp_rhs = lower_to_jax(ox.Concat(velocity, velocity_dot))
    # Dense propagation for an initialization trajectory.
    guess_dense = np.asarray(
        solve_ivp_diffrax(
            lambda t, x: cr3bp_rhs(x, jnp.zeros((0,), dtype=x.dtype), 0, {}),
            tau_final=t_f_guess,
            y_0=jnp.asarray(x0_seed, dtype=jax_float_dtype),
            args=(),
            tau_0=0.0,
            num_substeps=3000,
            solver_name="Dopri8",
            rtol=integration_tol,
            atol=integration_tol,
        ),
        dtype=float,
    )
    # Sample the dense ballistic guess at the node grid (nearest dense sample);
    # no impulse offset is applied to node 0 here.
    s_uniform = np.linspace(0.0, 1.0, n_nodes)
    node_grid = _normalized_node_grid(n_nodes, NODE_DISTRIBUTION_MODE)
    node_idx = np.round((guess_dense.shape[0] - 1) * node_grid).astype(int)
    nodal_guess = guess_dense[node_idx].copy()
    # Boundary conditions
    position.initial = pos_0
    velocity.initial = vel_0
    fuel.initial = np.array([1.0])
    position.final = [
        ox.Free(float(pos_f[0])),
        ox.Free(float(pos_f[1])),
        ox.Free(float(pos_f[2])),
    ]
    velocity.final = [
        ox.Free(float(vel_f[0])),
        ox.Free(float(vel_f[1])),
        ox.Free(float(vel_f[2])),
    ]
    # Maximizing remaining fuel is equivalent to minimizing the summed impulse norms.
    fuel.final = [("maximize", 0.95)]
    # Guesses
    position.guess = nodal_guess[:, :3]
    velocity.guess = nodal_guess[:, 3:6]
    fuel.guess = np.ones((n_nodes, 1))
    delta_v.min = -np.ones(3)
    delta_v.max = np.ones(3)
    delta_v_guess = np.zeros((n_nodes, 3))
    delta_v.guess = delta_v_guess
    time_guess = (t_f_guess * node_grid).reshape(-1, 1)
    time = ox.Time(
        initial=0.0,
        final=ox.Free(float(t_f_guess)),
        min=0.0,
        max=3.0 * t_f_guess,
        guess=time_guess,
        time_dilation_min=0.0001 * t_f_guess,
        time_dilation_max=3.0 * t_f_guess,
        uniform_time_grid=False,
    )
    # dt/dtau consistent with the (possibly non-uniform) node-time guess.
    dtdtau_guess = np.gradient(time_guess[:, 0], s_uniform)
    dtdtau_guess = np.clip(dtdtau_guess, 0.01 * t_f_guess, 3.0 * t_f_guess)
    time.time_dilation_guess = dtdtau_guess.reshape(-1, 1)
    # Scaling
    position.scaling_max = jnp.array([0.01, 0.01, 0.01])
    position.scaling_min = -jnp.array([0.01, 0.01, 0.01])
    velocity.scaling_max = jnp.array([0.5, 0.5, 0.5])
    velocity.scaling_min = -jnp.array([0.5, 0.5, 0.5])
    fuel.scaling_min = jnp.array([0.95])
    fuel.scaling_max = jnp.array([1.00])
    delta_v.scaling_min = velocity.scaling_min
    delta_v.scaling_max = velocity.scaling_max
    # Bounds (reuse the scaling ranges as hard box bounds)
    position.min = position.scaling_min
    position.max = position.scaling_max
    velocity.min = velocity.scaling_min
    velocity.max = velocity.scaling_max
    fuel.min = fuel.scaling_min
    fuel.max = fuel.scaling_max
    states = [position, velocity, fuel]
    controls = [delta_v]
    discretizer = {
        "ode_solver": "Dopri8",
        "diffrax_kwargs": {"atol": integration_tol, "rtol": integration_tol},
    }
    algorithm = {
        "k_max": 150,
        "lam_prox": 5e-2,
        "lam_vc": 3e1,
        "lam_vb": 2e-1,
        "lam_cost": 0.5,
        "ep_tr": 1e-9,
        "ep_vc": 1e-6,
        "autotuner": ox.AugmentedLagrangian(),
    }
    constraints = []
    # Enforce final distance from Earth in normalized Sun-Earth rotating frame:
    # (1 - eps_radius) * R <= ||r_f|| <= R, i.e. a thin band just inside the Moon's
    # orbit radius. (A lower bound of (1 + eps_radius) * R would contradict the
    # convex upper bound and make the terminal set empty.)
    final_radius_target = d_earth_moon / r_ref
    eps_radius = 1e-4
    constraints += [
        (ox.linalg.Norm(position - pos_earth_rot) <= final_radius_target)
        .at([n_nodes - 1])
        .convex(),
    ]
    constraints += [
        (
            ox.linalg.Norm(position - pos_earth_rot) >= (1 - eps_radius) * final_radius_target
        ).at([n_nodes - 1]),
    ]
    # Final orbit tangency: radius and velocity orthogonal at terminal node.
    constraints += [
        (ox.Sum((position - pos_earth_rot) * velocity) >= 0.0).at([n_nodes - 1]),
    ]
    constraints += [
        (ox.Sum((position - pos_earth_rot) * velocity) <= 0.0).at([n_nodes - 1]),
    ]
    # Final speed magnitude: velocity should match the moon one
    # NOTE(review): v_moon is an inertial circular speed while `velocity` is in the
    # rotating frame; confirm this approximation is intended.
    constraints += [
        (ox.linalg.Norm(velocity) - v_moon >= 0.0).at([n_nodes - 1]),
    ]
    constraints += [
        (ox.linalg.Norm(velocity) - v_moon <= 0.0).at([n_nodes - 1]).convex(),
    ]
    problem = Problem(
        dynamics=dynamics,
        dynamics_discrete=dynamics_discrete,
        states=states,
        controls=controls,
        time=time,
        constraints=constraints,
        N=n_nodes,
        discretizer=discretizer,
        algorithm=algorithm,
        float_dtype=float_dtype,
        solver={"cvx_solver": "CLARABEL", "solver_args": {}},
    )
    # Keep post-process propagation tolerances aligned with discretization.
    problem.settings.prp.solver = "Dopri8"
    problem.settings.prp.atol = integration_tol
    problem.settings.prp.rtol = integration_tol
    problem.settings.prp.dt = 1e-4
    return {
        "problem": problem,
        "cr3bp_rhs": cr3bp_rhs,
        "nodal_guess": nodal_guess,
        "time_guess": time_guess,
    }
class _LazyLETProblem:
    """Lazy proxy so test discovery does not instantiate the LET problem at import time.

    Ordinary attribute access is forwarded to the ``"problem"`` entry of the
    bundle from ``_build_let_problem_bundle``. The bundle is built on the first
    forwarded access and cached on the instance.

    NOTE(review): module import still loads SPICE data eagerly, so only the
    problem construction itself is deferred.
    """

    _float_dtype = LET_FLOAT_DTYPE

    def __init__(self) -> None:
        self._bundle: dict | None = None

    def _ensure_bundle(self) -> dict:
        """Build the problem bundle on first use and return the cached copy."""
        if self._bundle is None:
            self._bundle = _build_let_problem_bundle(float_dtype=self._float_dtype)
        return self._bundle

    def __getattr__(self, name: str):
        # __getattr__ only runs when normal lookup fails. Never forward the proxy's
        # own cache slot or dunder names. Otherwise:
        #  * an instance created without __init__ (copy.copy, pickle) would recurse
        #    forever looking up `_bundle`;
        #  * protocol probes such as hasattr(obj, "__setstate__") or "__array__"
        #    would silently trigger the expensive problem build.
        if name == "_bundle" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)
        return getattr(self._ensure_bundle()["problem"], name)


problem = _LazyLETProblem()
if __name__ == "__main__":
    # Script entry point: build the problem in float64, visualize the ballistic
    # guess, solve, report delta-v metrics, and optionally serve viser animations.
    _configure_jax_float_dtype(LET_FLOAT_DTYPE, update_jax_enable_x64=True)
    let_bundle = _build_let_problem_bundle(float_dtype=LET_FLOAT_DTYPE)
    problem = let_bundle["problem"]
    cr3bp_rhs = let_bundle["cr3bp_rhs"]
    nodal_guess = let_bundle["nodal_guess"]
    time_guess = let_bundle["time_guess"]
    # Baseline for comparison: Earth-centered Hohmann transfer to the Moon's radius.
    hohmann_metrics = _hohmann_transfer_metrics(
        mu_central_km3_s2=mu_earth,
        r1_km=r_0,
        r2_km=d_earth_moon,
    )
    # Ballistic propagation of the seed state (the guess trajectory).
    x0_guess_post = x0_seed.copy()
    traj_guess = np.asarray(
        solve_ivp_diffrax(
            lambda t, x: cr3bp_rhs(x, jnp.zeros((0,), dtype=x.dtype), 0, {}),
            tau_final=t_f_guess,
            y_0=jnp.asarray(x0_guess_post, dtype=jnp.float64),
            args=(),
            tau_0=0.0,
            num_substeps=3000,
            solver_name="Dopri8",
            rtol=integration_tol,
            atol=integration_tol,
        ),
        dtype=float,
    )
    # Wrap the guess in an OptimizationResults so the standard plotters accept it
    # (positions/velocities converted to km and km/s).
    guess_results = OptimizationResults(converged=True, t_final=float(t_f_guess))
    guess_results.trajectory = {
        "time": np.linspace(0.0, t_f_guess, traj_guess.shape[0]).reshape(-1, 1),
        "position": traj_guess[:, :3] * r_ref,
        "velocity": traj_guess[:, 3:6] * v_ref,
    }
    guess_results.nodes = {
        "time": time_guess,
        "position": nodal_guess[:, :3] * r_ref,
        "velocity": nodal_guess[:, 3:6] * v_ref,
    }
    fig_guess = plot_projections_2d(guess_results, velocity_var_name="velocity")
    _set_projection_axis_labels_km(fig_guess)
    _set_projection_speed_colorbar_kms(fig_guess)
    fig_guess.update_layout(title="LET Initial Guess - XY, XZ, YZ Projections (km)")
    fig_guess.show()
    problem.initialize()
    results = problem.solve()
    results = problem.post_process()
    fig_states = plot_states(results, ["position", "velocity", "fuel"], cols=3)
    fig_states.update_layout(title_text="LET Solution - State Evolution")
    fig_states.update_xaxes(title_text="Time (normalized)")
    fig_states.show()
    # Re-propagate the optimized solution from the post-burn initial state.
    t_f_opt = float(np.asarray(results.nodes["time"][-1]).squeeze())
    dv0_opt = np.asarray(results.nodes["delta_v"][0], dtype=float)
    dvf_opt = np.asarray(results.nodes["delta_v"][-1], dtype=float)
    x0_opt_pre = np.concatenate([results.nodes["position"][0], results.nodes["velocity"][0]])
    x0_opt_post = x0_opt_pre.copy()
    x0_opt_post[3:6] += dv0_opt
    traj_solution = np.asarray(
        solve_ivp_diffrax(
            lambda t, x: cr3bp_rhs(x, jnp.zeros((0,), dtype=x.dtype), 0, {}),
            tau_final=t_f_opt,
            y_0=jnp.asarray(x0_opt_post, dtype=jnp.float64),
            args=(),
            tau_0=0.0,
            num_substeps=3000,
            solver_name="Dopri8",
            rtol=integration_tol,
            atol=integration_tol,
        ),
        dtype=float,
    )
    solution_results = OptimizationResults(converged=bool(results.converged), t_final=t_f_opt)
    solution_results.trajectory = {
        "time": np.linspace(0.0, t_f_opt, traj_solution.shape[0]).reshape(-1, 1),
        "position": traj_solution[:, :3] * r_ref,
        "velocity": traj_solution[:, 3:6] * v_ref,
    }
    solution_results.nodes = {
        "time": results.nodes["time"],
        "position": np.asarray(results.nodes["position"], dtype=float) * r_ref,
        "velocity": np.asarray(results.nodes["velocity"], dtype=float) * v_ref,
    }
    fig_solution = plot_projections_2d(solution_results, velocity_var_name="velocity")
    _set_projection_axis_labels_km(fig_solution)
    _set_projection_speed_colorbar_kms(fig_solution)
    fig_solution.update_layout(title="LET Solution - XY, XZ, YZ Projections (km)")
    _add_moon_orbit_overlay(
        fig_solution,
        earth_pos=pos_earth_rot * r_ref,
        moon_radius=d_earth_moon,
    )
    fig_solution.show()
    # Terminal radius check: distance from Earth minus the Moon's orbit radius.
    final_pos = np.asarray(traj_solution[-1, :3], dtype=float)
    final_radius_vec = final_pos - pos_earth_rot
    final_distance_km = float(np.linalg.norm(final_radius_vec)) * r_ref
    final_distance_error_km = final_distance_km - d_earth_moon
    dv0_guess_norm = float(np.linalg.norm(delta_v0_guess))
    dv0_opt_norm = float(np.linalg.norm(dv0_opt))
    dvf_opt_norm = float(np.linalg.norm(dvf_opt))
    total_dv_opt_norm = dv0_opt_norm + dvf_opt_norm
    total_dv_with_guess_norm = dv0_guess_norm + total_dv_opt_norm
    print(f"Converged: {bool(results.converged)}")
    print(f"Final time (days): {t_f_opt * t_ref / d_2_sec:.6f}")
    print(f"||delta_v0_guess|| (km/s): {dv0_guess_norm * v_ref:.9f}")
    print(f"Initial delta-v (km/s): {dv0_opt * v_ref}")
    print(f"||Initial delta-v|| (km/s): {dv0_opt_norm * v_ref:.9f}")
    print(f"Final delta-v (km/s): {dvf_opt * v_ref}")
    print(f"||Final delta-v|| (km/s): {dvf_opt_norm * v_ref:.9f}")
    print(f"Total ||delta-v|| (solution only, km/s): {total_dv_opt_norm * v_ref:.9f}")
    print(f"Total ||delta-v|| (+delta_v0_guess, km/s): {total_dv_with_guess_norm * v_ref:.9f}")
    # The value is |r_f| - (Moon orbit radius), not a distance to the Moon itself.
    print(f"Final radial distance error vs. Moon orbit radius (km): {final_distance_error_km:.6f}")
    print(f"Hohmann dv1 (km/s): {hohmann_metrics['dv1_km_s']:.9f}")
    print(f"Hohmann dv2 (km/s): {hohmann_metrics['dv2_km_s']:.9f}")
    print(f"Hohmann total delta-v (km/s): {hohmann_metrics['total_dv_km_s']:.9f}")
    if ENABLE_VISER_ANIMATION:
        traj_time_days = np.linspace(0.0, t_f_opt * t_ref / d_2_sec, traj_solution.shape[0])
        traj_solution_vis, traj_time_days_vis = _resample_trajectory_for_viser(
            traj_solution, traj_time_days
        )
        traj_guess_time_days = np.linspace(0.0, t_f_guess * t_ref / d_2_sec, traj_guess.shape[0])
        traj_guess_vis, traj_guess_time_days_vis = _resample_trajectory_for_viser(
            traj_guess, traj_guess_time_days
        )
        # Inertial (sidereal) lunar angular rate about Earth, rad/day.
        moon_rate_rad_per_day = np.sqrt(mu_earth / d_earth_moon**3) * d_2_sec
        # The rotating viewer is drawn in the Sun-Earth rotating frame, which itself
        # turns at n_system (rad/s). There the Moon advances at the synodic rate.
        moon_rate_rot_rad_per_day = moon_rate_rad_per_day - n_system * d_2_sec
        viser_server = _create_let_viser_server(
            trajectory=traj_solution_vis,
            traj_time_days=traj_time_days_vis,
            earth_pos=pos_earth_rot,
            sun_pos=np.array([-rho_es, 0.0, 0.0], dtype=float),
            moon_radius=d_earth_moon / r_ref,
            moon_rate_rad_per_day=moon_rate_rot_rad_per_day,
            guess_trajectory=traj_guess_vis,
            port=VISER_ROTATING_PORT,
        )
        inertial_server = None
        if ENABLE_VISER_INERTIAL_ANIMATION:
            inertial_server = _create_let_viser_server_inertial(
                trajectory=traj_solution_vis,
                traj_time_days=traj_time_days_vis,
                r_ref_km=r_ref,
                d_earth_sun_km=d_earth_sun,
                d_earth_moon_km=d_earth_moon,
                moon_rate_rad_per_day=moon_rate_rad_per_day,
                kappa_val=kappa,
                guess_trajectory=traj_guess_vis,
                guess_time_days=traj_guess_time_days_vis,
                port=VISER_INERTIAL_PORT,
            )
        if viser_server is not None or inertial_server is not None:
            rotating_url = _server_local_url(viser_server, VISER_ROTATING_PORT)
            inertial_url = _server_local_url(inertial_server, VISER_INERTIAL_PORT)
            print("Launching viser animation server(s) (Ctrl+C to exit)...")
            if viser_server is not None:
                print(f"Rotating frame viewer: {rotating_url}")
                if VISER_REQUEST_SHARE_URLS:
                    try:
                        rotating_share_url = viser_server.request_share_url(verbose=True)
                        if rotating_share_url is not None:
                            print(f"Rotating frame public URL: {rotating_share_url}")
                    except Exception as exc:
                        print(f"Rotating frame share URL unavailable: {exc}")
            if inertial_server is not None:
                print(f"Sun-centered inertial viewer: {inertial_url}")
                if VISER_REQUEST_SHARE_URLS:
                    try:
                        inertial_share_url = inertial_server.request_share_url(verbose=True)
                        if inertial_share_url is not None:
                            print(f"Sun-centered inertial public URL: {inertial_share_url}")
                    except Exception as exc:
                        print(f"Sun-centered inertial share URL unavailable: {exc}")
            # Keep the process alive for the background servers; exit cleanly on Ctrl+C.
            try:
                while True:
                    pytime.sleep(1.0)
            except KeyboardInterrupt:
                print("Stopping viser animation server(s).")