
inverse_kinematics

IK-based trajectory initialization.

Generates joint-space trajectory guesses by interpolating task-space poses (position + orientation) between keyframes and solving inverse kinematics at each node via damped least-squares. Combines linspace (position), slerp (orientation), and IK into a single initialization call.
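The damped least-squares (DLS) update used at each node can be illustrated on a toy model. The sketch below is not openscvx code. It uses a hypothetical planar two-link arm (`fk_planar`) with a position-only error, and it assumes `damping` is added directly to the diagonal of J Jᵀ. How `ik_interpolation` scales `damping` internally is not shown on this page.

```python
import jax
import jax.numpy as jnp


def fk_planar(theta, l1=1.0, l2=1.0):
    """End-effector (x, y) of a hypothetical planar 2-link arm (toy model)."""
    x = l1 * jnp.cos(theta[0]) + l2 * jnp.cos(theta[0] + theta[1])
    y = l1 * jnp.sin(theta[0]) + l2 * jnp.sin(theta[0] + theta[1])
    return jnp.array([x, y])


def dls_ik(target, theta0, damping=1e-3, max_iter=200, tol=1e-6):
    """DLS IK: theta += J^T (J J^T + damping * I)^-1 e (assumed damping convention)."""
    jac = jax.jacfwd(fk_planar)

    def body(state):
        i, theta, _ = state
        e = target - fk_planar(theta)                 # task-space error
        J = jac(theta)                                # (task_dim, n_joints)
        A = J @ J.T + damping * jnp.eye(J.shape[0])   # damped normal matrix
        dtheta = J.T @ jnp.linalg.solve(A, e)
        # Carry the pre-update error norm; cond stops once it drops below tol.
        return i + 1, theta + dtheta, jnp.linalg.norm(e)

    def cond(state):
        i, _, err = state
        return (i < max_iter) & (err > tol)

    init = (jnp.array(0), theta0, jnp.array(jnp.inf, dtype=theta0.dtype))
    _, theta, _ = jax.lax.while_loop(cond, body, init)
    return theta


theta_true = jnp.array([0.5, 0.8])
target = fk_planar(theta_true)
theta = dls_ik(target, jnp.array([0.6, 0.6]))
```

Compared with a plain pseudoinverse step, the added damping term keeps the step bounded near singular configurations, at the cost of slower convergence when the damping is large.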

Two solve modes are available:

  • Parallel (default): Solves all nodes independently via jax.vmap, starting every node from the same angles_init seed. This keeps a bad local minimum at one node from propagating to the others, but neighboring nodes may converge to different IK solutions, giving a less coherent joint-space path.

  • Sequential: Solves nodes in order via jax.lax.scan, seeding each node with the previous node's solution. This produces smoother trajectories when the seed is good, but a bad early solution can propagate to all later nodes.
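The two modes differ only in how a seed reaches each node. In the toy sketch below, `solve_node` is a hypothetical stand-in for the per-node IK solve. It halves its error on each of a fixed number of steps, so the quality of the seed matters. The sketch shows both patterns: `jax.vmap` with a shared seed, and `jax.lax.scan` carrying the previous solution forward.

```python
import jax
import jax.numpy as jnp


def solve_node(target, seed, n_steps=5):
    """Hypothetical per-node solver: each step halves the error (x - target)."""
    def step(x, _):
        return x - 0.5 * (x - target), None
    x, _ = jax.lax.scan(step, seed, None, length=n_steps)
    return x


targets = jnp.linspace(0.0, 1.0, 5)  # one target per node
seed = jnp.array(0.0)

# Parallel: every node starts from the same seed.
parallel = jax.vmap(solve_node, in_axes=(0, None))(targets, seed)


# Sequential: each node is seeded by the previous node's solution.
def scan_step(prev, target):
    sol = solve_node(target, prev)
    return sol, sol


_, sequential = jax.lax.scan(scan_step, seed, targets)
```

Because the toy solver stops after five steps, the warm-started sequential run ends closer to each target than the parallel run. With a real IK solver, the same warm start can instead carry a poor solution branch forward. Passing `in_axes=(0, None)` maps over the targets while sharing one seed, which gives the same result as the explicit `jnp.broadcast_to` used in the source.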

Requires jaxlie: pip install openscvx[lie]

ik_interpolation(keyframes: Sequence[Pose], nodes: Sequence[int], screw_axes: np.ndarray, T_home: np.ndarray, *, angles_init: np.ndarray = None, angles_min: np.ndarray = None, angles_max: np.ndarray = None, sequential: bool = False, damping: float = 0.001, max_iter: int = 200, tol: float = 1e-06) -> np.ndarray

Generate joint-angle trajectory guess via task-space interpolation and IK.

Interpolates end-effector poses between keyframes (linspace for position, slerp for orientation) and solves inverse kinematics at each trajectory node.
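The task-space interpolation step can be sketched in plain NumPy. The openscvx `linspace` and `slerp` helpers are not shown on this page, so `interpolate_poses` and `slerp_wxyz` below are hypothetical equivalents. They assume piecewise interpolation between consecutive keyframes at their node indices, with slerp taking the shorter arc.

```python
import numpy as np


def slerp_wxyz(q0, q1, t):
    """Spherical linear interpolation between quaternions in [w, x, y, z] order."""
    q0 = q0 / np.linalg.norm(q0)
    q1 = q1 / np.linalg.norm(q1)
    dot = np.dot(q0, q1)
    if dot < 0.0:          # q and -q are the same rotation: take the shorter arc
        q1, dot = -q1, -dot
    if dot > 0.9995:       # nearly parallel: normalized linear interpolation
        q = q0 + t * (q1 - q0)
        return q / np.linalg.norm(q)
    omega = np.arccos(dot)
    return (np.sin((1 - t) * omega) * q0 + np.sin(t * omega) * q1) / np.sin(omega)


def interpolate_poses(keyframes, nodes):
    """Linear in position, slerp in orientation, per keyframe segment.

    Assumes nodes[0] == 0; any nodes before nodes[0] would stay zero.
    """
    N = nodes[-1] + 1
    p_traj = np.zeros((N, 3))
    q_traj = np.zeros((N, 4))
    segments = zip(keyframes[:-1], keyframes[1:], nodes[:-1], nodes[1:])
    for (p0, q0), (p1, q1), n0, n1 in segments:
        p0, p1 = np.asarray(p0, float), np.asarray(p1, float)
        q0, q1 = np.asarray(q0, float), np.asarray(q1, float)
        for k in range(n0, n1 + 1):
            t = (k - n0) / (n1 - n0)
            p_traj[k] = (1 - t) * p0 + t * p1
            q_traj[k] = slerp_wxyz(q0, q1, t)
    return p_traj, q_traj


keyframes = [
    ([0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]),
    ([1.0, 0.0, 0.0], [np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)]),  # 90 deg about z
]
p_traj, q_traj = interpolate_poses(keyframes, [0, 4])
```

At the midpoint node, this example gives a position of [0.5, 0, 0] and a 45° rotation about z.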

Parameters:

  • keyframes (Sequence[Pose], required): Sequence of (position, quaternion_wxyz) tuples. Each position is array-like with shape (3,) and each quaternion is array-like with shape (4,) in [w, x, y, z] order.

  • nodes (Sequence[int], required): Sequence of node indices where keyframes occur. Must be sorted in ascending order and have the same length as keyframes. The last node determines the output size (N = nodes[-1] + 1).

  • screw_axes (ndarray, required): (n_joints, 6) array of screw axes for the Product of Exponentials (PoE) forward kinematics.

  • T_home (ndarray, required): (4, 4) home configuration transform, i.e. the end-effector pose when all joint angles are zero.

  • angles_init (ndarray, default None): (n_joints,) initial joint-angle guess. In parallel mode this seeds every node; in sequential mode it seeds only the first node. If None, zeros are used.

  • angles_min (ndarray, default None): (n_joints,) optional lower joint limits. If None, the joints have no lower bound (-inf).

  • angles_max (ndarray, default None): (n_joints,) optional upper joint limits. If None, the joints have no upper bound (+inf).

  • sequential (bool, default False): If True, solve nodes sequentially, seeding each node with the previous solution. If False, solve all nodes in parallel from the same angles_init seed.

  • damping (float, default 0.001): Damping factor for the damped least-squares IK update.

  • max_iter (int, default 200): Maximum IK iterations per node.

  • tol (float, default 1e-06): IK convergence tolerance.
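The screw_axes and T_home parameters define the arm through the Product of Exponentials, T(θ) = exp([S₁]θ₁) ··· exp([Sₙ]θₙ) · T_home. Below is a minimal NumPy sketch of that forward map. It assumes space-frame screw axes with each row ordered as (ω, v) and ‖ω‖ = 1 for revolute joints. This page does not state which frame or ordering openscvx expects, and `exp_twist` and `poe_fk` are hypothetical helpers.

```python
import numpy as np


def skew(w):
    """3x3 skew-symmetric matrix [w] such that [w] @ u == np.cross(w, u)."""
    return np.array([[0.0, -w[2], w[1]], [w[2], 0.0, -w[0]], [-w[1], w[0], 0.0]])


def exp_twist(S, theta):
    """exp([S] theta) for S = (omega, v), assuming ||omega|| == 1 or omega == 0."""
    w, v = np.asarray(S[:3], float), np.asarray(S[3:], float)
    T = np.eye(4)
    if np.allclose(w, 0.0):           # prismatic joint: pure translation
        T[:3, 3] = v * theta
        return T
    W = skew(w)
    R = np.eye(3) + np.sin(theta) * W + (1 - np.cos(theta)) * W @ W   # Rodrigues
    G = np.eye(3) * theta + (1 - np.cos(theta)) * W + (theta - np.sin(theta)) * W @ W
    T[:3, :3] = R
    T[:3, 3] = G @ v
    return T


def poe_fk(screw_axes, T_home, angles):
    """Space-frame PoE: T = exp([S_1] q_1) ... exp([S_n] q_n) @ T_home."""
    T = np.eye(4)
    for S, q in zip(screw_axes, angles):
        T = T @ exp_twist(S, q)
    return T @ T_home


# Planar 2R arm with unit links: joints about +z at x = 0 and x = 1.
S = np.array([[0, 0, 1, 0, 0, 0], [0, 0, 1, 0, -1, 0]], dtype=float)
M = np.eye(4)
M[0, 3] = 2.0
T_elbow = poe_fk(S, M, [0.0, np.pi / 2])   # end effector at (1, 1, 0)
```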

Returns:

  • ndarray: array of shape (N, n_joints) containing the joint angles at each node, where N = nodes[-1] + 1.

Example

Initialize a 7-DOF arm trajectory reaching for a target:

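import openscvx as ox

# Assumes `angle` (the joint-angle state), the arm's `screw_axes` (7, 6),
# and `T_home` (4, 4) are already defined.
angle.guess = ox.init.ik_interpolation(
    keyframes=[
        ([0.7, 0, 0.34], [1, 0, 0, 0]),  # home pose
        ([0.3, 0.3, 0.5], [1, 0, 0, 0]),  # target pose
    ],
    nodes=[0, 49],  # 50 nodes in total (N = nodes[-1] + 1)
    screw_axes=screw_axes,
    T_home=T_home,
)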
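The example leaves `screw_axes` and `T_home` undefined. The sketch below shows one way they might be built for a simple arm. It assumes space-frame screw axes with rows ordered as (ω, v), where v = −ω × q for a revolute joint through point q; this page does not confirm that convention. `revolute_screw_axis` is a hypothetical helper, and the planar 3R arm is chosen only to show the expected array shapes. Such an arm cannot track arbitrary 6-DOF poses.

```python
import numpy as np


def revolute_screw_axis(omega, point):
    """Space-frame screw axis (omega, v) of a revolute joint, with v = -omega x point."""
    omega = np.asarray(omega, dtype=float)
    omega = omega / np.linalg.norm(omega)          # unit rotation axis
    v = -np.cross(omega, np.asarray(point, dtype=float))
    return np.concatenate([omega, v])


# Hypothetical planar 3R arm: all joints rotate about +z, links of 0.4 m along +x.
link = 0.4
screw_axes = np.stack([
    revolute_screw_axis([0, 0, 1], [0.0, 0.0, 0.0]),
    revolute_screw_axis([0, 0, 1], [link, 0.0, 0.0]),
    revolute_screw_axis([0, 0, 1], [2 * link, 0.0, 0.0]),
])                                   # shape (n_joints, 6)

T_home = np.eye(4)                   # end-effector frame at zero joint angles
T_home[:3, 3] = [3 * link, 0.0, 0.0]
```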
Source code in openscvx/init/inverse_kinematics.py
def ik_interpolation(
    keyframes: Sequence[Pose],
    nodes: Sequence[int],
    screw_axes: np.ndarray,
    T_home: np.ndarray,
    *,
    angles_init: np.ndarray = None,
    angles_min: np.ndarray = None,
    angles_max: np.ndarray = None,
    sequential: bool = False,
    damping: float = 1e-3,
    max_iter: int = 200,
    tol: float = 1e-6,
) -> np.ndarray:
    """Generate joint-angle trajectory guess via task-space interpolation and IK.

    Interpolates end-effector poses between keyframes (linspace for position,
    slerp for orientation) and solves inverse kinematics at each trajectory
    node.

    Args:
        keyframes: Sequence of (position, quaternion_wxyz) tuples. Each position
            is array-like with shape (3,) and each quaternion is array-like with
            shape (4,) in [w, x, y, z] order.
        nodes: Sequence of node indices where keyframes occur. Must be sorted in
            ascending order and have the same length as keyframes. The last node
            determines the output size (N = nodes[-1] + 1).
        screw_axes: (n_joints, 6) array of screw axes for Product of Exponentials.
        T_home: (4, 4) home configuration transform.
        angles_init: (n_joints,) initial joint angle guess. In parallel mode
            this seeds every node; in sequential mode it seeds only the first
            node. Defaults to zeros.
        angles_min: (n_joints,) optional lower joint limits.
        angles_max: (n_joints,) optional upper joint limits.
        sequential: If True, solve nodes sequentially (each node seeded by the
            previous solution). If False (default), solve all nodes in parallel
            from the same ``angles_init`` seed.
        damping: Damping factor for least-squares IK.
        max_iter: Maximum IK iterations per node.
        tol: IK convergence tolerance.

    Returns:
        np.ndarray of shape (N, n_joints) containing joint angles at each node.

    Example:
        Initialize a 7-DOF arm trajectory reaching for a target::

            import openscvx as ox

            angle.guess = ox.init.ik_interpolation(
                keyframes=[
                    ([0.7, 0, 0.34], [1, 0, 0, 0]),  # home pose
                    ([0.3, 0.3, 0.5], [1, 0, 0, 0]),  # target pose
                ],
                nodes=[0, 49],
                screw_axes=screw_axes,
                T_home=T_home,
            )
    """
    positions = [np.asarray(kf[0], dtype=np.float64) for kf in keyframes]
    quaternions = [np.asarray(kf[1], dtype=np.float64) for kf in keyframes]

    for i, (p, quat) in enumerate(zip(positions, quaternions)):
        if p.shape != (3,):
            raise ValueError(f"Keyframe {i} position has shape {p.shape}, expected (3,)")
        if quat.shape != (4,):
            raise ValueError(f"Keyframe {i} quaternion has shape {quat.shape}, expected (4,)")

    # Interpolate task-space trajectory
    p_traj = jnp.array(linspace(keyframes=positions, nodes=nodes))  # (N, 3)
    quat_traj = jnp.array(slerp(keyframes=quaternions, nodes=nodes))  # (N, 4)
    R_traj = jax.vmap(_quat_wxyz_to_rotmat)(quat_traj)  # (N, 3, 3)

    n_joints = screw_axes.shape[0]
    screw_axes_j = jnp.array(screw_axes)
    T_home_j = jnp.array(T_home)
    angles_lo = jnp.array(angles_min) if angles_min is not None else jnp.full(n_joints, -jnp.inf)
    angles_hi = jnp.array(angles_max) if angles_max is not None else jnp.full(n_joints, jnp.inf)

    angles0 = jnp.array(angles_init) if angles_init is not None else jnp.zeros(n_joints)

    if sequential:
        # Solve nodes in order, seeding each from the previous solution
        def scan_step(prev_angles, node_data):
            p, R = node_data
            sol = _ik_loop_pose(
                screw_axes_j,
                T_home_j,
                p,
                R,
                prev_angles,
                angles_lo,
                angles_hi,
                damping,
                tol,
                max_iter,
            )
            return sol, sol

        _, result = jax.lax.scan(scan_step, angles0, (p_traj, R_traj))
    else:
        # Solve all nodes in parallel from the same seed
        N = nodes[-1] + 1
        angles0_all = jnp.broadcast_to(angles0, (N, n_joints))
        result = jax.vmap(
            lambda p, R, a0: _ik_loop_pose(
                screw_axes_j, T_home_j, p, R, a0, angles_lo, angles_hi, damping, tol, max_iter
            )
        )(p_traj, R_traj, angles0_all)

    return np.array(result)
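The source converts the slerped quaternions to rotation matrices with the private helper `_quat_wxyz_to_rotmat`, whose body is not shown on this page. A standard conversion for [w, x, y, z] quaternions, batched with `jax.vmap` as in the source, might look like the following. `quat_wxyz_to_rotmat` is an illustrative equivalent, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def quat_wxyz_to_rotmat(q):
    """3x3 rotation matrix from a quaternion in [w, x, y, z] order (normalized first)."""
    w, x, y, z = q / jnp.linalg.norm(q)
    return jnp.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])


quats = jnp.array([
    [1.0, 0.0, 0.0, 0.0],                                  # identity
    [jnp.cos(jnp.pi / 4), 0.0, 0.0, jnp.sin(jnp.pi / 4)],  # 90 deg about z
])
R_traj = jax.vmap(quat_wxyz_to_rotmat)(quats)  # shape (2, 3, 3)
```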