Files
Gym_GPU/rl_game/get_up/amp/amp_motion.py

187 lines
6.5 KiB
Python

from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
import yaml
# Side-less joint names whose keyframe value is negated when it is copied onto the
# matching "Right_"-prefixed joint (see _map_motors_to_joint_targets). Presumably these
# are the axes that are mirrored between the left and right limbs — confirm per robot.
RIGHT_JOINT_SIGN_FLIP = {
    "Shoulder_Roll", "Elbow_Yaw", "Hip_Roll", "Hip_Yaw", "Ankle_Roll",
}

# Robot joint name -> name to look up among the keyframe motor names instead.
JOINT_NAME_ALIAS = {"AAHead_yaw": "Head_yaw"}
def _safe_load_yaml(path: Path) -> dict[str, Any]:
    """Parse the YAML file at *path* and return its top-level mapping.

    Any falsy document (empty file, ``null``, empty list, ...) is treated as an
    empty mapping.

    Raises:
        ValueError: if the document's top level is a non-empty non-mapping value.
    """
    # Parse from the open stream (not a string) so YAML error marks name the file.
    with path.open("r", encoding="utf-8") as stream:
        loaded = yaml.safe_load(stream)
    document = loaded if loaded else {}
    if isinstance(document, dict):
        return document
    raise ValueError(f"Invalid keyframe yaml: {path}")
def _build_interpolated_motor_table(
keyframes: list[dict[str, Any]],
sample_dt: float,
) -> tuple[torch.Tensor, list[str], torch.Tensor]:
if len(keyframes) == 0:
raise ValueError("No keyframes in yaml")
durations: list[float] = []
motor_names: set[str] = set()
for frame in keyframes:
durations.append(float(frame.get("delta", 0.1)))
motors = frame.get("motor_positions", {})
if isinstance(motors, dict):
motor_names.update(motors.keys())
names = sorted(motor_names)
if len(names) == 0:
raise ValueError("No motor_positions found in keyframes")
frame_count = len(keyframes)
motor_count = len(names)
times = torch.zeros(frame_count, dtype=torch.float32)
values = torch.full((frame_count, motor_count), float("nan"), dtype=torch.float32)
elapsed = 0.0
name_to_idx = {name: idx for idx, name in enumerate(names)}
for i, frame in enumerate(keyframes):
elapsed += max(float(frame.get("delta", 0.1)), 1e-4)
times[i] = elapsed
motors = frame.get("motor_positions", {})
if not isinstance(motors, dict):
continue
for motor_name, motor_deg in motors.items():
if motor_name not in name_to_idx:
continue
values[i, name_to_idx[motor_name]] = float(motor_deg) * (torch.pi / 180.0)
first_valid = torch.nan_to_num(values[0], nan=0.0)
values[0] = first_valid
for i in range(1, frame_count):
cur = values[i]
prev = values[i - 1]
values[i] = torch.where(torch.isnan(cur), prev, cur)
sample_dt = max(float(sample_dt), 1e-3)
sample_times = torch.arange(0.0, float(times[-1]) + 1e-6, sample_dt, dtype=torch.float32)
sample_times[0] = max(sample_times[0], 1e-6)
sample_times = torch.clamp(sample_times, min=times[0], max=times[-1])
upper = torch.searchsorted(times, sample_times, right=True)
upper = torch.clamp(upper, min=1, max=frame_count - 1)
lower = upper - 1
t0 = times.index_select(0, lower)
t1 = times.index_select(0, upper)
v0 = values.index_select(0, lower)
v1 = values.index_select(0, upper)
alpha = ((sample_times - t0) / (t1 - t0 + 1e-6)).unsqueeze(-1)
interpolated = v0 + alpha * (v1 - v0)
return interpolated, names, sample_times
def _map_motors_to_joint_targets(
    motor_table: torch.Tensor,
    motor_names: list[str],
    joint_names: list[str],
) -> torch.Tensor:
    """Reorder motor-table columns into the robot's joint order.

    For each joint, its name is first resolved through ``JOINT_NAME_ALIAS``. The
    resolved name is looked up among ``motor_names``; if absent, the same name with
    every ``"Left_"``/``"Right_"`` substring removed is tried. ``"Right_"``-prefixed
    joints whose side-less name is in ``RIGHT_JOINT_SIGN_FLIP`` are negated. Joints
    with no matching motor stay at zero.

    Returns:
        float32 tensor of shape ``(motor_table.shape[0], len(joint_names))``.
    """
    column_of = {name: col for col, name in enumerate(motor_names)}
    targets = torch.zeros((motor_table.shape[0], len(joint_names)), dtype=torch.float32)
    for out_col, joint in enumerate(joint_names):
        resolved = JOINT_NAME_ALIAS.get(joint, joint)
        side_free = resolved.replace("Left_", "").replace("Right_", "")
        # Exact (side-specific) name wins; otherwise fall back to the shared name.
        candidates = [name for name in (resolved, side_free) if name in column_of]
        if not candidates:
            continue
        mirrored = resolved.startswith("Right_") and side_free in RIGHT_JOINT_SIGN_FLIP
        sign = -1.0 if mirrored else 1.0
        targets[:, out_col] = sign * motor_table[:, column_of[candidates[0]]]
    return targets
def _features_from_keyframes(
    keyframe_yaml_path: str,
    joint_names: list[str],
    sample_dt: float,
    repeat_count: int,
) -> torch.Tensor:
    """Turn one keyframe yaml file into a per-frame feature tensor.

    Each row is laid out as:

    - joint positions (``len(joint_names)`` columns);
    - joint velocities (``len(joint_names)`` columns), computed as a backward
      finite difference with the first row copied from the second (all zero
      for a single frame);
    - three zero columns (``root_lin`` placeholder);
    - three zero columns (``root_ang`` placeholder);
    - a constant projected-gravity vector ``(0, 0, -1)``.

    The whole sequence is tiled ``max(int(repeat_count), 1)`` times along dim 0.

    Raises:
        ValueError: if the yaml's ``keyframes`` entry is not a list; the helpers
            raise their own errors for unreadable or empty keyframe data.
    """
    resolved = Path(keyframe_yaml_path).expanduser().resolve()
    keyframes = _safe_load_yaml(resolved).get("keyframes", [])
    if not isinstance(keyframes, list):
        raise ValueError(f"Invalid keyframe list in {keyframe_yaml_path}")

    motor_table, motor_names, _sample_times = _build_interpolated_motor_table(
        keyframes, sample_dt=sample_dt
    )
    joint_pos = _map_motors_to_joint_targets(motor_table, motor_names, joint_names)
    frame_total = joint_pos.shape[0]

    joint_vel = torch.zeros_like(joint_pos)
    if frame_total > 1:
        step = (joint_pos[1:] - joint_pos[:-1]) / max(sample_dt, 1e-3)
        # Row 0 has no predecessor; reuse the first difference.
        joint_vel = torch.cat([step[:1], step], dim=0)

    # root_lin (3) + root_ang (3) + projected gravity (3): zeros except gravity z = -1.
    tail = torch.zeros((frame_total, 9), dtype=torch.float32)
    tail[:, 8] = -1.0
    features = torch.cat([joint_pos, joint_vel, tail], dim=-1)

    tiles = max(int(repeat_count), 1)
    return features.repeat(tiles, 1) if tiles > 1 else features
def build_amp_expert_features_from_getup_keyframes(
    *,
    front_yaml_path: str,
    back_yaml_path: str,
    joint_names: list[str],
    output_path: str,
    sample_dt: float = 0.04,
    repeat_count: int = 16,
) -> tuple[str, tuple[int, int]]:
    """Generate AMP expert feature tensor from front/back get-up keyframe yaml files.

    The front-motion features are stacked above the back-motion features (dim 0).
    The result is saved with ``torch.save`` as a dict holding ``expert_features``
    and a ``meta`` dict. Parent directories of ``output_path`` are created if
    needed.

    Returns:
        ``(resolved_output_path, (num_rows, feature_dim))``.
    """

    def _resolved(raw: str) -> Path:
        return Path(raw).expanduser().resolve()

    # Front first, then back: the row order of the saved tensor depends on this.
    per_motion = [
        _features_from_keyframes(
            keyframe_yaml_path=yaml_path,
            joint_names=joint_names,
            sample_dt=sample_dt,
            repeat_count=repeat_count,
        )
        for yaml_path in (front_yaml_path, back_yaml_path)
    ]
    expert_features = torch.cat(per_motion, dim=0).contiguous()
    row_count = int(expert_features.shape[0])
    feature_dim = int(expert_features.shape[-1])

    destination = _resolved(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    bundle = {
        "expert_features": expert_features,
        "meta": {
            "front_yaml_path": str(_resolved(front_yaml_path)),
            "back_yaml_path": str(_resolved(back_yaml_path)),
            "sample_dt": float(sample_dt),
            "repeat_count": int(repeat_count),
            "feature_dim": feature_dim,
        },
    }
    torch.save(bundle, str(destination))
    return str(destination), (row_count, int(expert_features.shape[1]))