from __future__ import annotations

from pathlib import Path
from typing import Any

import torch
import yaml

# Base joint names (side prefix removed) whose targets are negated for
# "Right_"-prefixed joints in `_map_motors_to_joint_targets`.
RIGHT_JOINT_SIGN_FLIP = {
    "Shoulder_Roll",
    "Elbow_Yaw",
    "Hip_Roll",
    "Hip_Yaw",
    "Ankle_Roll",
}

# Joint name -> motor name to look up in the keyframe "motor_positions" tables.
JOINT_NAME_ALIAS = {
    "AAHead_yaw": "Head_yaw",
}

# Side prefixes removed from a joint name to find a shared (un-sided) motor entry.
_SIDE_PREFIXES = ("Left_", "Right_")


def _safe_load_yaml(path: Path) -> dict[str, Any]:
    """Load the yaml document at ``path`` with ``yaml.safe_load``.

    An empty (or otherwise falsy) document yields ``{}``.

    Raises:
        ValueError: if the top level of the document is not a mapping.
    """
    with path.open("r", encoding="utf-8") as f:
        payload = yaml.safe_load(f) or {}
    if not isinstance(payload, dict):
        raise ValueError(f"Invalid keyframe yaml: {path}")
    return payload


def _build_interpolated_motor_table(
    keyframes: list[dict[str, Any]],
    sample_dt: float,
) -> tuple[torch.Tensor, list[str], torch.Tensor]:
    """Resample keyframe motor positions onto a uniform time grid.

    Each keyframe is a mapping with an optional ``"delta"`` and an optional
    ``"motor_positions"`` mapping of motor name -> angle in degrees.
    ``"delta"`` defaults to 0.1 and is floored at 1e-4. It is added to a
    running clock, so keyframe ``i`` sits at ``delta_0 + ... + delta_i``. Even
    the first keyframe therefore sits at ``delta_0``, not at 0.

    A motor missing from a keyframe keeps the previous keyframe's value. It is
    0 before the motor's first mention. Angles are converted to radians and
    linearly interpolated at times ``0, dt, 2*dt, ...`` up to the last keyframe
    time. Sample times earlier than the first keyframe are clamped onto it, so
    the table starts by holding the first pose. A single keyframe yields that
    pose at every sample.

    Args:
        keyframes: Parsed keyframe list.
        sample_dt: Grid spacing, floored at 1e-3.

    Returns:
        ``(interpolated, names, sample_times)``:
        ``interpolated`` is ``(num_samples, num_motors)`` float32 radians.
        ``names`` holds the motor names sorted alphabetically; column ``j`` of
        ``interpolated`` belongs to ``names[j]``. ``sample_times`` is
        ``(num_samples,)`` after clamping.

    Raises:
        ValueError: if ``keyframes`` is empty, if an entry is not a mapping,
            or if no keyframe provides any motor positions.
    """
    if len(keyframes) == 0:
        raise ValueError("No keyframes in yaml")
    for idx, frame in enumerate(keyframes):
        if not isinstance(frame, dict):
            raise ValueError(f"Invalid keyframe at index {idx}: expected a mapping")

    motor_names: set[str] = set()
    for frame in keyframes:
        motors = frame.get("motor_positions", {})
        if isinstance(motors, dict):
            motor_names.update(motors.keys())
    names = sorted(motor_names)
    if len(names) == 0:
        raise ValueError("No motor_positions found in keyframes")
    name_to_idx = {name: idx for idx, name in enumerate(names)}

    frame_count = len(keyframes)
    times = torch.zeros(frame_count, dtype=torch.float32)
    values = torch.full((frame_count, len(names)), float("nan"), dtype=torch.float32)
    deg_to_rad = torch.pi / 180.0
    elapsed = 0.0
    for i, frame in enumerate(keyframes):
        # Flooring delta keeps timestamps strictly increasing, which
        # searchsorted and the interpolation below rely on.
        elapsed += max(float(frame.get("delta", 0.1)), 1e-4)
        times[i] = elapsed
        motors = frame.get("motor_positions", {})
        if not isinstance(motors, dict):
            continue
        for motor_name, motor_deg in motors.items():
            values[i, name_to_idx[motor_name]] = float(motor_deg) * deg_to_rad

    # Forward-fill motors a keyframe does not mention (NaN placeholders).
    values[0] = torch.nan_to_num(values[0], nan=0.0)
    for i in range(1, frame_count):
        values[i] = torch.where(torch.isnan(values[i]), values[i - 1], values[i])

    step = max(float(sample_dt), 1e-3)
    sample_times = torch.arange(0.0, float(times[-1]) + 1e-6, step, dtype=torch.float32)
    sample_times = torch.clamp(sample_times, min=times[0], max=times[-1])

    if frame_count == 1:
        # No segment to interpolate over. The index clamp below would also
        # yield lower == -1 here and make the lookup fail, so hold the pose.
        return values.repeat(sample_times.shape[0], 1), names, sample_times

    # Index of the first keyframe strictly after each sample, kept within
    # [1, frame_count - 1] so (lower, upper) is always a valid segment.
    upper = torch.searchsorted(times, sample_times, right=True)
    upper = torch.clamp(upper, min=1, max=frame_count - 1)
    lower = upper - 1
    t0, t1 = times[lower], times[upper]
    v0, v1 = values[lower], values[upper]
    alpha = ((sample_times - t0) / (t1 - t0).clamp_min(1e-6)).unsqueeze(-1)
    interpolated = v0 + alpha * (v1 - v0)
    return interpolated, names, sample_times


def _strip_side_prefix(name: str) -> str:
    """Return ``name`` without a leading "Left_" or "Right_" prefix."""
    for prefix in _SIDE_PREFIXES:
        if name.startswith(prefix):
            return name[len(prefix):]
    return name


def _map_motors_to_joint_targets(
    motor_table: torch.Tensor,
    motor_names: list[str],
    joint_names: list[str],
) -> torch.Tensor:
    """Build per-joint targets from the columns of ``motor_table``.

    Each joint name is first passed through ``JOINT_NAME_ALIAS``. It is then
    matched against ``motor_names``: an exact (sided) match is preferred,
    otherwise the name with its "Left_"/"Right_" prefix removed is used.
    "Right_" joints whose base name is in ``RIGHT_JOINT_SIGN_FLIP`` are
    negated. Joints with no matching motor stay 0.

    Args:
        motor_table: ``(num_samples, len(motor_names))`` motor values.
        motor_names: Column names of ``motor_table``.
        joint_names: Desired output column order.

    Returns:
        ``(num_samples, len(joint_names))`` float32 tensor.
    """
    motor_to_idx = {name: idx for idx, name in enumerate(motor_names)}
    joint_targets = torch.zeros((motor_table.shape[0], len(joint_names)), dtype=torch.float32)
    for joint_idx, joint_name in enumerate(joint_names):
        alias_name = JOINT_NAME_ALIAS.get(joint_name, joint_name)
        base_name = _strip_side_prefix(alias_name)
        if alias_name in motor_to_idx:
            source_idx = motor_to_idx[alias_name]
        elif base_name in motor_to_idx:
            source_idx = motor_to_idx[base_name]
        else:
            continue  # no matching motor: target stays 0
        # NOTE(review): the flip also applies when an exact "Right_..." motor
        # entry exists, not only for the shared base entry -- confirm intended.
        is_flipped = alias_name.startswith("Right_") and base_name in RIGHT_JOINT_SIGN_FLIP
        sign = -1.0 if is_flipped else 1.0
        joint_targets[:, joint_idx] = sign * motor_table[:, source_idx]
    return joint_targets


def _features_from_keyframes(
    keyframe_yaml_path: str,
    joint_names: list[str],
    sample_dt: float,
    repeat_count: int,
) -> torch.Tensor:
    """Load one keyframe yaml file and convert it to a feature table.

    Row layout (width ``2 * len(joint_names) + 9``): joint positions,
    joint velocities, root linear velocity (3), root angular velocity (3),
    projected gravity (3). The whole sequence is tiled ``repeat_count`` times
    along the first axis (at least once).

    Raises:
        ValueError: if the yaml is not a mapping, if ``"keyframes"`` is not a
            list, or if the keyframes are invalid (see
            ``_build_interpolated_motor_table``).
    """
    payload = _safe_load_yaml(Path(keyframe_yaml_path).expanduser().resolve())
    keyframes = payload.get("keyframes", [])
    if not isinstance(keyframes, list):
        raise ValueError(f"Invalid keyframe list in {keyframe_yaml_path}")

    motor_table, motor_names, _ = _build_interpolated_motor_table(keyframes, sample_dt=sample_dt)
    joint_pos = _map_motors_to_joint_targets(motor_table, motor_names, joint_names)

    # Backward finite difference divided by the nominal (floored) sample spacing.
    # The first row copies the second.
    dt = max(float(sample_dt), 1e-3)
    joint_vel = torch.zeros_like(joint_pos)
    if joint_pos.shape[0] > 1:
        joint_vel[1:] = torch.diff(joint_pos, dim=0) / dt
        joint_vel[0] = joint_vel[1]

    num_samples = joint_pos.shape[0]
    # Keyframes carry no base motion: root velocities are zero and projected
    # gravity is the constant (0, 0, -1).
    # NOTE(review): these placeholders do not follow the base orientation
    # during a get-up -- confirm that is intended for the expert data.
    root_lin = torch.zeros((num_samples, 3), dtype=torch.float32)
    root_ang = torch.zeros((num_samples, 3), dtype=torch.float32)
    projected_gravity = torch.zeros((num_samples, 3), dtype=torch.float32)
    projected_gravity[:, 2] = -1.0

    features = torch.cat([joint_pos, joint_vel, root_lin, root_ang, projected_gravity], dim=-1)
    repeat_count = max(int(repeat_count), 1)
    if repeat_count > 1:
        features = features.repeat(repeat_count, 1)
    return features


def build_amp_expert_features_from_getup_keyframes(
    *,
    front_yaml_path: str,
    back_yaml_path: str,
    joint_names: list[str],
    output_path: str,
    sample_dt: float = 0.04,
    repeat_count: int = 16,
) -> tuple[str, tuple[int, int]]:
    """Generate AMP expert feature tensor from front/back get-up keyframe yaml files.

    Both files are resampled every ``sample_dt``, mapped onto ``joint_names``
    and turned into feature rows. The front rows come first, then the back
    rows. The result is written with ``torch.save`` as a dict with keys
    ``"expert_features"`` and ``"meta"``; parent directories are created.

    Feature layout per row (width ``2 * len(joint_names) + 9``): joint
    positions (radians), joint velocities, root linear velocity (zeros, 3),
    root angular velocity (zeros, 3), projected gravity ((0, 0, -1), 3).

    Args:
        front_yaml_path: Keyframe yaml for the front get-up.
        back_yaml_path: Keyframe yaml for the back get-up.
        joint_names: Joint order of the position/velocity columns.
        output_path: Destination file for ``torch.save``.
        sample_dt: Resampling spacing, floored at 1e-3.
        repeat_count: How many times each sequence is tiled, at least 1.

    Returns:
        ``(resolved_output_path, (num_rows, feature_dim))``.

    Raises:
        ValueError: for malformed keyframe files.
        yaml.YAMLError: if a file is not valid yaml.
        OSError: on file read/write failures.
    """
    front_path = Path(front_yaml_path).expanduser().resolve()
    back_path = Path(back_yaml_path).expanduser().resolve()
    # Record the values actually applied downstream, not the raw arguments.
    effective_dt = max(float(sample_dt), 1e-3)
    effective_repeat = max(int(repeat_count), 1)

    front_features = _features_from_keyframes(
        keyframe_yaml_path=front_yaml_path,
        joint_names=joint_names,
        sample_dt=sample_dt,
        repeat_count=repeat_count,
    )
    back_features = _features_from_keyframes(
        keyframe_yaml_path=back_yaml_path,
        joint_names=joint_names,
        sample_dt=sample_dt,
        repeat_count=repeat_count,
    )
    expert_features = torch.cat([front_features, back_features], dim=0).contiguous()

    out_path = Path(output_path).expanduser().resolve()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "expert_features": expert_features,
            "meta": {
                "front_yaml_path": str(front_path),
                "back_yaml_path": str(back_path),
                "sample_dt": effective_dt,
                "repeat_count": effective_repeat,
                "feature_dim": int(expert_features.shape[-1]),
            },
        },
        str(out_path),
    )
    return str(out_path), (int(expert_features.shape[0]), int(expert_features.shape[1]))