Add AMP get-up pipeline with sequence discriminator and git-sourced expert data
This commit is contained in:
186
rl_game/get_up/amp/amp_motion.py
Normal file
186
rl_game/get_up/amp/amp_motion.py
Normal file
@@ -0,0 +1,186 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
|
||||
RIGHT_JOINT_SIGN_FLIP = {
|
||||
"Shoulder_Roll",
|
||||
"Elbow_Yaw",
|
||||
"Hip_Roll",
|
||||
"Hip_Yaw",
|
||||
"Ankle_Roll",
|
||||
}
|
||||
|
||||
JOINT_NAME_ALIAS = {
|
||||
"AAHead_yaw": "Head_yaw",
|
||||
}
|
||||
|
||||
|
||||
def _safe_load_yaml(path: Path) -> dict[str, Any]:
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
payload = yaml.safe_load(f) or {}
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError(f"Invalid keyframe yaml: {path}")
|
||||
return payload
|
||||
|
||||
|
||||
def _build_interpolated_motor_table(
|
||||
keyframes: list[dict[str, Any]],
|
||||
sample_dt: float,
|
||||
) -> tuple[torch.Tensor, list[str], torch.Tensor]:
|
||||
if len(keyframes) == 0:
|
||||
raise ValueError("No keyframes in yaml")
|
||||
|
||||
durations: list[float] = []
|
||||
motor_names: set[str] = set()
|
||||
for frame in keyframes:
|
||||
durations.append(float(frame.get("delta", 0.1)))
|
||||
motors = frame.get("motor_positions", {})
|
||||
if isinstance(motors, dict):
|
||||
motor_names.update(motors.keys())
|
||||
names = sorted(motor_names)
|
||||
if len(names) == 0:
|
||||
raise ValueError("No motor_positions found in keyframes")
|
||||
|
||||
frame_count = len(keyframes)
|
||||
motor_count = len(names)
|
||||
times = torch.zeros(frame_count, dtype=torch.float32)
|
||||
values = torch.full((frame_count, motor_count), float("nan"), dtype=torch.float32)
|
||||
|
||||
elapsed = 0.0
|
||||
name_to_idx = {name: idx for idx, name in enumerate(names)}
|
||||
for i, frame in enumerate(keyframes):
|
||||
elapsed += max(float(frame.get("delta", 0.1)), 1e-4)
|
||||
times[i] = elapsed
|
||||
motors = frame.get("motor_positions", {})
|
||||
if not isinstance(motors, dict):
|
||||
continue
|
||||
for motor_name, motor_deg in motors.items():
|
||||
if motor_name not in name_to_idx:
|
||||
continue
|
||||
values[i, name_to_idx[motor_name]] = float(motor_deg) * (torch.pi / 180.0)
|
||||
|
||||
first_valid = torch.nan_to_num(values[0], nan=0.0)
|
||||
values[0] = first_valid
|
||||
for i in range(1, frame_count):
|
||||
cur = values[i]
|
||||
prev = values[i - 1]
|
||||
values[i] = torch.where(torch.isnan(cur), prev, cur)
|
||||
|
||||
sample_dt = max(float(sample_dt), 1e-3)
|
||||
sample_times = torch.arange(0.0, float(times[-1]) + 1e-6, sample_dt, dtype=torch.float32)
|
||||
sample_times[0] = max(sample_times[0], 1e-6)
|
||||
sample_times = torch.clamp(sample_times, min=times[0], max=times[-1])
|
||||
|
||||
upper = torch.searchsorted(times, sample_times, right=True)
|
||||
upper = torch.clamp(upper, min=1, max=frame_count - 1)
|
||||
lower = upper - 1
|
||||
|
||||
t0 = times.index_select(0, lower)
|
||||
t1 = times.index_select(0, upper)
|
||||
v0 = values.index_select(0, lower)
|
||||
v1 = values.index_select(0, upper)
|
||||
alpha = ((sample_times - t0) / (t1 - t0 + 1e-6)).unsqueeze(-1)
|
||||
interpolated = v0 + alpha * (v1 - v0)
|
||||
|
||||
return interpolated, names, sample_times
|
||||
|
||||
|
||||
def _map_motors_to_joint_targets(
|
||||
motor_table: torch.Tensor,
|
||||
motor_names: list[str],
|
||||
joint_names: list[str],
|
||||
) -> torch.Tensor:
|
||||
motor_to_idx = {name: idx for idx, name in enumerate(motor_names)}
|
||||
joint_targets = torch.zeros((motor_table.shape[0], len(joint_names)), dtype=torch.float32)
|
||||
|
||||
for joint_idx, joint_name in enumerate(joint_names):
|
||||
alias_name = JOINT_NAME_ALIAS.get(joint_name, joint_name)
|
||||
base_name = alias_name.replace("Left_", "").replace("Right_", "")
|
||||
source_name = None
|
||||
if alias_name in motor_to_idx:
|
||||
source_name = alias_name
|
||||
elif base_name in motor_to_idx:
|
||||
source_name = base_name
|
||||
if source_name is None:
|
||||
continue
|
||||
sign = 1.0
|
||||
if alias_name.startswith("Right_") and base_name in RIGHT_JOINT_SIGN_FLIP:
|
||||
sign = -1.0
|
||||
joint_targets[:, joint_idx] = sign * motor_table[:, motor_to_idx[source_name]]
|
||||
|
||||
return joint_targets
|
||||
|
||||
|
||||
def _features_from_keyframes(
|
||||
keyframe_yaml_path: str,
|
||||
joint_names: list[str],
|
||||
sample_dt: float,
|
||||
repeat_count: int,
|
||||
) -> torch.Tensor:
|
||||
payload = _safe_load_yaml(Path(keyframe_yaml_path).expanduser().resolve())
|
||||
keyframes = payload.get("keyframes", [])
|
||||
if not isinstance(keyframes, list):
|
||||
raise ValueError(f"Invalid keyframe list in {keyframe_yaml_path}")
|
||||
|
||||
motor_table, motor_names, _ = _build_interpolated_motor_table(keyframes, sample_dt=sample_dt)
|
||||
joint_pos = _map_motors_to_joint_targets(motor_table, motor_names, joint_names)
|
||||
joint_vel = torch.zeros_like(joint_pos)
|
||||
if joint_pos.shape[0] > 1:
|
||||
joint_vel[1:] = (joint_pos[1:] - joint_pos[:-1]) / max(sample_dt, 1e-3)
|
||||
joint_vel[0] = joint_vel[1]
|
||||
root_lin = torch.zeros((joint_pos.shape[0], 3), dtype=torch.float32)
|
||||
root_ang = torch.zeros((joint_pos.shape[0], 3), dtype=torch.float32)
|
||||
projected_gravity = torch.zeros((joint_pos.shape[0], 3), dtype=torch.float32)
|
||||
projected_gravity[:, 2] = -1.0
|
||||
features = torch.cat([joint_pos, joint_vel, root_lin, root_ang, projected_gravity], dim=-1)
|
||||
repeat_count = max(int(repeat_count), 1)
|
||||
if repeat_count > 1:
|
||||
features = features.repeat(repeat_count, 1)
|
||||
return features
|
||||
|
||||
|
||||
def build_amp_expert_features_from_getup_keyframes(
|
||||
*,
|
||||
front_yaml_path: str,
|
||||
back_yaml_path: str,
|
||||
joint_names: list[str],
|
||||
output_path: str,
|
||||
sample_dt: float = 0.04,
|
||||
repeat_count: int = 16,
|
||||
) -> tuple[str, tuple[int, int]]:
|
||||
"""Generate AMP expert feature tensor from front/back get-up keyframe yaml files."""
|
||||
front_features = _features_from_keyframes(
|
||||
keyframe_yaml_path=front_yaml_path,
|
||||
joint_names=joint_names,
|
||||
sample_dt=sample_dt,
|
||||
repeat_count=repeat_count,
|
||||
)
|
||||
back_features = _features_from_keyframes(
|
||||
keyframe_yaml_path=back_yaml_path,
|
||||
joint_names=joint_names,
|
||||
sample_dt=sample_dt,
|
||||
repeat_count=repeat_count,
|
||||
)
|
||||
expert_features = torch.cat([front_features, back_features], dim=0).contiguous()
|
||||
|
||||
out_path = Path(output_path).expanduser().resolve()
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(
|
||||
{
|
||||
"expert_features": expert_features,
|
||||
"meta": {
|
||||
"front_yaml_path": str(Path(front_yaml_path).expanduser().resolve()),
|
||||
"back_yaml_path": str(Path(back_yaml_path).expanduser().resolve()),
|
||||
"sample_dt": float(sample_dt),
|
||||
"repeat_count": int(repeat_count),
|
||||
"feature_dim": int(expert_features.shape[-1]),
|
||||
},
|
||||
},
|
||||
str(out_path),
|
||||
)
|
||||
return str(out_path), (int(expert_features.shape[0]), int(expert_features.shape[1]))
|
||||
Reference in New Issue
Block a user