diff --git a/rl_game/get_up/amp/README.md b/rl_game/get_up/amp/README.md new file mode 100644 index 0000000..c49b416 --- /dev/null +++ b/rl_game/get_up/amp/README.md @@ -0,0 +1,62 @@ +# AMP Tools for `get_up` + +This folder contains all AMP-related code for `rl_game/get_up`. + +## Files + +- `amp_rewards.py`: AMP discriminator + reward function used by training config. +- `amp_motion.py`: Build AMP expert features from local get-up keyframe YAML files. +- `migrate_legged_lab_expert_template.py`: Template converter for migrating external expert data (for example legged_lab outputs) to `expert_features.pt`. + +## Quick start + +Generate expert features from current local keyframes: + +```bash +python rl_game/get_up/train.py --amp_from_keyframes --headless +``` + +Convert external motion/expert file to AMP template: + +```bash +python rl_game/get_up/amp/migrate_legged_lab_expert_template.py \ + --input /path/to/source_data.pt \ + --output rl_game/get_up/amp/expert_features.pt \ + --input_key expert_features \ + --feature_dim 55 \ + --repeat 4 +``` + +Convert downloaded `legged_lab` motion pickles directly: + +```bash +python rl_game/get_up/amp/migrate_legged_lab_expert_template.py \ + --input third_party/legged_lab/source/legged_lab/legged_lab/data/MotionData/g1_29dof/amp/walk_and_run \ + --input_glob "*.pkl" \ + --target_dof 23 \ + --feature_dim 55 \ + --clip_weight_mode uniform \ + --output rl_game/get_up/amp/expert_features.pt +``` + +For get-up data from gitee legged_lab, use git-like focused clip sampling: + +```bash +python rl_game/get_up/amp/migrate_legged_lab_expert_template.py \ + --input third_party/legged_lab_gitee/source/legged_lab/legged_lab/data/MotionData/g1_29dof/amp/get_up \ + --input_glob "*.pkl" \ + --target_dof 23 \ + --feature_dim 55 \ + --clip_weight_mode git_getup_focus \ + --output rl_game/get_up/amp/expert_features.pt +``` + +Then train with online AMP discriminator: + +```bash +python rl_game/get_up/train.py \ + --amp_train_discriminator \ + 
def _build_interpolated_motor_table(
    keyframes: list[dict[str, Any]],
    sample_dt: float,
) -> tuple[torch.Tensor, list[str], torch.Tensor]:
    """Resample keyframed motor angles onto a uniform time grid.

    Each keyframe contributes a ``delta`` (seconds, default 0.1, floored at
    1e-4) and a ``motor_positions`` mapping of motor name -> angle in degrees.
    Keyframe ``i`` is placed at the cumulative sum of deltas up to and
    including ``i``. Motors missing from a keyframe inherit the previous
    keyframe's value; motors missing from the first keyframe start at 0.

    Args:
        keyframes: Non-empty list of keyframe dicts.
        sample_dt: Sampling period in seconds (floored at 1e-3).

    Returns:
        ``(table, names, sample_times)`` where ``table`` is a float32 tensor
        of shape ``[num_samples, num_motors]`` in radians, ``names`` is the
        sorted list of motor names (column order of ``table``), and
        ``sample_times`` holds the clamped sample timestamps.

    Raises:
        ValueError: If ``keyframes`` is empty or no keyframe carries a
            ``motor_positions`` dict.
    """
    if not keyframes:
        raise ValueError("No keyframes in yaml")

    motor_names: set[str] = set()
    for frame in keyframes:
        motors = frame.get("motor_positions", {})
        if isinstance(motors, dict):
            motor_names.update(motors.keys())
    names = sorted(motor_names)
    if not names:
        raise ValueError("No motor_positions found in keyframes")

    frame_count = len(keyframes)
    name_to_idx = {name: idx for idx, name in enumerate(names)}
    times = torch.zeros(frame_count, dtype=torch.float32)
    # NaN marks "motor not specified in this keyframe"; filled in below.
    values = torch.full((frame_count, len(names)), float("nan"), dtype=torch.float32)

    elapsed = 0.0
    for i, frame in enumerate(keyframes):
        elapsed += max(float(frame.get("delta", 0.1)), 1e-4)
        times[i] = elapsed
        motors = frame.get("motor_positions", {})
        if not isinstance(motors, dict):
            continue
        for motor_name, motor_deg in motors.items():
            values[i, name_to_idx[motor_name]] = float(motor_deg) * (torch.pi / 180.0)

    # Forward-fill unspecified motors so every row is fully defined.
    values[0] = torch.nan_to_num(values[0], nan=0.0)
    for i in range(1, frame_count):
        values[i] = torch.where(torch.isnan(values[i]), values[i - 1], values[i])

    sample_dt = max(float(sample_dt), 1e-3)
    t_first = float(times[0])
    t_last = float(times[-1])
    sample_times = torch.arange(0.0, t_last + 1e-6, sample_dt, dtype=torch.float32)
    # Samples before the first keyframe hold the first pose (t_first >= 1e-4,
    # so no separate nudge of sample_times[0] is needed).
    sample_times = torch.clamp(sample_times, min=t_first, max=t_last)

    if frame_count == 1:
        # A single keyframe has no segment to interpolate over; indexing a
        # "previous" keyframe would be out of range, so hold the only pose.
        held = values.expand(sample_times.shape[0], -1).clone()
        return held, names, sample_times

    upper = torch.searchsorted(times, sample_times, right=True)
    upper = torch.clamp(upper, min=1, max=frame_count - 1)
    lower = upper - 1

    t0 = times.index_select(0, lower)
    t1 = times.index_select(0, upper)
    v0 = values.index_select(0, lower)
    v1 = values.index_select(0, upper)
    # The 1e-6 keeps the division finite for (near-)coincident keyframes.
    alpha = ((sample_times - t0) / (t1 - t0 + 1e-6)).unsqueeze(-1)
    interpolated = v0 + alpha * (v1 - v0)

    return interpolated, names, sample_times
def _features_from_keyframes(
    keyframe_yaml_path: str,
    joint_names: list[str],
    sample_dt: float,
    repeat_count: int,
) -> torch.Tensor:
    """Turn one keyframe YAML file into a block of AMP feature rows.

    Each row is ``[joint_pos, joint_vel, root_lin(3), root_ang(3),
    projected_gravity(3)]``. Joint velocities are finite differences of the
    resampled joint positions; root velocities are zero and the gravity
    column is fixed at ``(0, 0, -1)``. The whole block is tiled
    ``repeat_count`` times (at least once).

    Raises:
        ValueError: If the YAML ``keyframes`` entry is not a list.
    """
    yaml_path = Path(keyframe_yaml_path).expanduser().resolve()
    keyframes = _safe_load_yaml(yaml_path).get("keyframes", [])
    if not isinstance(keyframes, list):
        raise ValueError(f"Invalid keyframe list in {keyframe_yaml_path}")

    motor_table, motor_names, _ = _build_interpolated_motor_table(keyframes, sample_dt=sample_dt)
    positions = _map_motors_to_joint_targets(motor_table, motor_names, joint_names)
    frame_total = positions.shape[0]

    velocities = torch.zeros_like(positions)
    if frame_total > 1:
        step = max(sample_dt, 1e-3)
        velocities[1:] = (positions[1:] - positions[:-1]) / step
        # No backward difference exists for the first row; reuse the second.
        velocities[0] = velocities[1]

    # Root linear + angular velocity: six zero columns.
    root_motion = torch.zeros((frame_total, 6), dtype=torch.float32)
    gravity = torch.tensor([0.0, 0.0, -1.0], dtype=torch.float32).repeat(frame_total, 1)
    features = torch.cat([positions, velocities, root_motion, gravity], dim=-1)

    copies = max(int(repeat_count), 1)
    return features.repeat(copies, 1) if copies > 1 else features
class AMPDiscriminator(nn.Module):
    """Small MLP discriminator used for online AMP updates.

    Every hidden stage is ``Linear -> LayerNorm -> SiLU``; a final ``Linear``
    maps to a single logit. The modules live in ``self.net`` in that order.
    """

    def __init__(self, input_dim: int, hidden_dims: tuple[int, ...]):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        stages: list[nn.Module] = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend((nn.Linear(fan_in, fan_out), nn.LayerNorm(fan_out), nn.SiLU()))
        # Output head: one logit per input row.
        stages.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the discriminator logits for ``x``."""
        return self.net(x)
value.float() + elif isinstance(value, np.ndarray): + t = torch.as_tensor(value, dtype=torch.float32) + elif isinstance(value, list): + try: + t = torch.as_tensor(value, dtype=torch.float32) + except Exception: + return None + else: + return None + if t.ndim == 1: + t = t.unsqueeze(0) + if t.ndim != 2: + return None + return t + + +def _normalize_feature_dim(x: torch.Tensor, feature_dim: int) -> torch.Tensor: + if x.shape[-1] == feature_dim: + return x + if x.shape[-1] > feature_dim: + return x[:, :feature_dim] + pad = torch.zeros((x.shape[0], feature_dim - x.shape[1]), dtype=x.dtype) + return torch.cat([x, pad], dim=-1) + + +def _extract_expert_bank_from_payload(payload, feature_dim: int) -> dict | None: + clip_tensors: list[torch.Tensor] = [] + clip_names: list[str] = [] + clip_weights_tensor: torch.Tensor | None = None + + if isinstance(payload, dict): + clip_values = payload.get("expert_clips", None) + if isinstance(clip_values, list): + for i, clip_value in enumerate(clip_values): + clip_tensor = _to_tensor_2d(clip_value) + if clip_tensor is None: + continue + clip_tensors.append(_normalize_feature_dim(clip_tensor, feature_dim)) + clip_names.append(f"clip_{i}") + raw_names = payload.get("clip_names", None) + if isinstance(raw_names, list) and len(raw_names) == len(clip_tensors): + clip_names = [str(n) for n in raw_names] + raw_weights = payload.get("clip_weights", None) + clip_weights_tensor = _to_tensor_2d(raw_weights) if raw_weights is not None else None + if clip_weights_tensor is not None: + clip_weights_tensor = clip_weights_tensor.reshape(-1) + if clip_weights_tensor.shape[0] != len(clip_tensors): + clip_weights_tensor = None + + for key in ("expert_features", "features", "obs"): + value = payload.get(key, None) + tensor = _to_tensor_2d(value) + if tensor is not None: + clip_tensors.append(_normalize_feature_dim(tensor, feature_dim)) + clip_names.append(key) + break + else: + tensor = _to_tensor_2d(payload) + if tensor is not None: + 
clip_tensors.append(_normalize_feature_dim(tensor, feature_dim)) + clip_names.append("expert_features") + + clip_tensors = [c for c in clip_tensors if c.shape[0] >= 2] + if len(clip_tensors) == 0: + return None + + if clip_weights_tensor is None: + clip_weights_tensor = torch.ones(len(clip_tensors), dtype=torch.float32) + else: + clip_weights_tensor = torch.clamp(clip_weights_tensor.float(), min=0.0) + if float(torch.sum(clip_weights_tensor).item()) <= 0.0: + clip_weights_tensor = torch.ones(len(clip_tensors), dtype=torch.float32) + + flat_features = torch.cat(clip_tensors, dim=0) + return { + "flat_features": flat_features, + "clips": clip_tensors, + "clip_names": clip_names, + "clip_weights": clip_weights_tensor, + } + + +def _load_amp_expert_features( + expert_features_path: str, + device: str, + feature_dim: int, + fallback_samples: int, +) -> dict | None: + """Load expert AMP features bank. Returns None when file is unavailable.""" + if not expert_features_path: + return None + p = Path(expert_features_path).expanduser() + if not p.is_file(): + return None + try: + payload = torch.load(str(p), map_location="cpu") + except Exception: + try: + with p.open("rb") as f: + payload = pickle.load(f) + except Exception: + try: + payload = joblib.load(str(p)) + except Exception: + return None + bank = _extract_expert_bank_from_payload(payload, feature_dim=feature_dim) + if bank is None: + return None + flat_features = bank["flat_features"].float() + if flat_features.shape[0] < fallback_samples: + reps = int((fallback_samples + flat_features.shape[0] - 1) // flat_features.shape[0]) + flat_features = flat_features.repeat(reps, 1) + + clip_tensors = [clip.float() for clip in bank["clips"]] + clip_weights = bank["clip_weights"].float() + clip_weights = torch.clamp(clip_weights, min=0.0) + if float(torch.sum(clip_weights).item()) <= 0.0: + clip_weights = torch.ones_like(clip_weights) + clip_weights = clip_weights / torch.sum(clip_weights) + + return { + "flat_features": 
def _sample_expert_sequence_batch(
    expert_bank: dict,
    batch_size: int,
    history_steps: int,
    device: str,
) -> torch.Tensor:
    """Sample expert sequence batch [B, H, D] with clip-weighted sampling.

    A clip index is drawn per sample from ``expert_bank["clip_weights"]``
    (with replacement). From a clip with at least ``history_steps`` rows a
    contiguous window is taken at a uniformly random start; a shorter clip is
    used whole and padded by repeating its last row.

    Args:
        expert_bank: Dict with ``clips`` (list of 2-D tensors) and
            ``clip_weights`` (1-D tensor of sampling weights).
        batch_size: Number of sequences to draw.
        history_steps: Window length H.
        device: Device used only for the empty result; sampled sequences
            stay on the device of the clips they are sliced from.

    Returns:
        Tensor of shape ``[batch_size, history_steps, D]``, or an empty
        ``[0, history_steps, D]`` tensor when there are no clips or
        ``batch_size`` is not positive (``D`` is 0 without clips).
    """
    clips = expert_bank["clips"]
    clip_weights = expert_bank["clip_weights"]
    batch_size = int(batch_size)
    if len(clips) <= 0 or batch_size <= 0:
        feature_dim = int(clips[0].shape[-1]) if len(clips) > 0 else 0
        return torch.empty((0, history_steps, feature_dim), device=device)

    clip_ids = torch.multinomial(clip_weights, num_samples=batch_size, replacement=True).tolist()
    # Draw every window offset in one CPU call: a per-sample
    # randint(..., device=device).item() would force one host/device
    # synchronisation per sequence when the bank lives on an accelerator.
    unit_draws = torch.rand(batch_size).tolist()

    seq_list: list[torch.Tensor] = []
    for clip_id, unit in zip(clip_ids, unit_draws):
        clip = clips[int(clip_id)]
        clip_len = int(clip.shape[0])
        if clip_len >= history_steps:
            max_start = clip_len - history_steps
            # unit is in [0, 1): scale to an integer start in [0, max_start].
            start = min(int(unit * (max_start + 1)), max_start)
            seq = clip[start : start + history_steps]
        else:
            pad = clip[-1:].repeat(history_steps - clip_len, 1)
            seq = torch.cat([clip, pad], dim=0)
        seq_list.append(seq)
    return torch.stack(seq_list, dim=0)
def _build_amp_features(env: ManagerBasedRLEnv, feature_clip: float = 8.0) -> torch.Tensor:
    """Assemble the per-env AMP feature vector from the robot's state.

    Concatenates, in order: joint positions relative to their defaults, joint
    velocities, root linear velocity (``root_lin_vel_w``), root angular
    velocity (``root_ang_vel_w``) and ``projected_gravity_b``. NaNs become 0
    and every entry is limited to ``[-feature_clip, feature_clip]``.
    """
    data = env.scene["robot"].data
    parts = [
        data.joint_pos - data.default_joint_pos,
        data.joint_vel,
        data.root_lin_vel_w,
        data.root_ang_vel_w,
        data.projected_gravity_b,
    ]
    stacked = torch.cat(parts, dim=-1)
    # Map NaN/inf to finite values first, then bound everything else.
    finite = torch.nan_to_num(stacked, nan=0.0, posinf=feature_clip, neginf=-feature_clip)
    return finite.clamp(min=-feature_clip, max=feature_clip)
amp_expert_features_path=amp_expert_features_path, + feature_dim=int(amp_features.shape[-1]), + disc_hidden_dim=disc_hidden_dim, + disc_hidden_layers=disc_hidden_layers, + disc_lr=disc_lr, + disc_weight_decay=disc_weight_decay, + disc_min_expert_samples=disc_min_expert_samples, + disc_history_steps=disc_history_steps, + ) + discriminator = amp_state.get("model", None) + history_steps = max(int(amp_state.get("disc_history_steps", disc_history_steps)), 1) + policy_seq = _policy_sequence_features(env, amp_features, history_steps=history_steps) + policy_seq_flat = policy_seq.reshape(policy_seq.shape[0], -1) + if discriminator is not None: + model_loaded = 1.0 + + if amp_state.get("mode") == "trainable" and discriminator is not None: + amp_train_active = 1.0 + optimizer = amp_state.get("optimizer", None) + expert_bank = amp_state.get("expert_bank", None) + amp_state["step"] = int(amp_state.get("step", 0)) + 1 + update_interval = max(int(disc_update_interval), 1) + batch_size = max(int(disc_batch_size), 32) + + if optimizer is not None and isinstance(expert_bank, dict) and amp_state["step"] % update_interval == 0: + policy_features = policy_seq_flat.detach() + policy_count = policy_features.shape[0] + if policy_count > batch_size: + policy_ids = torch.randint(0, policy_count, (batch_size,), device=env.device) + policy_batch = policy_features.index_select(0, policy_ids) + else: + policy_batch = policy_features + + expert_seq = _sample_expert_sequence_batch( + expert_bank=expert_bank, + batch_size=policy_batch.shape[0], + history_steps=history_steps, + device=env.device, + ) + expert_batch = expert_seq.reshape(expert_seq.shape[0], -1) + + discriminator.train() + optimizer.zero_grad(set_to_none=True) + logits_expert = discriminator(expert_batch).squeeze(-1) + logits_policy = discriminator(policy_batch).squeeze(-1) + loss_expert = nn.functional.binary_cross_entropy_with_logits(logits_expert, torch.ones_like(logits_expert)) + loss_policy = 
nn.functional.binary_cross_entropy_with_logits(logits_policy, torch.zeros_like(logits_policy)) + loss = 0.5 * (loss_expert + loss_policy) + loss.backward() + nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) + optimizer.step() + discriminator.eval() + + with torch.no_grad(): + disc_loss = float(loss.detach().item()) + disc_acc_expert = float((torch.sigmoid(logits_expert) > 0.5).float().mean().item()) + disc_acc_policy = float((torch.sigmoid(logits_policy) < 0.5).float().mean().item()) + amp_state["last_loss"] = disc_loss + amp_state["last_acc_expert"] = disc_acc_expert + amp_state["last_acc_policy"] = disc_acc_policy + else: + disc_loss = float(amp_state.get("last_loss", 0.0)) + disc_acc_expert = float(amp_state.get("last_acc_expert", 0.0)) + disc_acc_policy = float(amp_state.get("last_acc_policy", 0.0)) + + if discriminator is not None: + discriminator.eval() + with torch.no_grad(): + logits = discriminator(policy_seq_flat) + if isinstance(logits, (tuple, list)): + logits = logits[0] + if logits.ndim > 1: + logits = logits.squeeze(-1) + elif amp_enabled and amp_model_path: + # For external scripted models, try temporal then fallback to single frame. 
+ model = amp_state.get("model", None) + if model is not None: + with torch.no_grad(): + try: + logits = model(policy_seq_flat) + except Exception: + logits = model(amp_features) + if isinstance(logits, (tuple, list)): + logits = logits[0] + if logits.ndim > 1: + logits = logits.squeeze(-1) + logits = _safe_tensor(logits, nan=0.0, pos=20.0, neg=-20.0) + amp_score = torch.sigmoid(logit_scale * logits) + amp_score = _safe_tensor(amp_score, nan=0.0, pos=1.0, neg=0.0) + + amp_reward = _safe_tensor(amp_reward_gain * amp_score, nan=0.0, pos=10.0, neg=0.0) + + log_dict = env.extras.get("log", {}) + if isinstance(log_dict, dict): + log_dict["amp_score_mean"] = torch.mean(amp_score).detach().item() + log_dict["amp_reward_mean"] = torch.mean(amp_reward).detach().item() + log_dict["amp_model_loaded_mean"] = model_loaded + log_dict["amp_train_active_mean"] = amp_train_active + log_dict["amp_disc_loss_mean"] = disc_loss + log_dict["amp_disc_acc_policy_mean"] = disc_acc_policy + log_dict["amp_disc_acc_expert_mean"] = disc_acc_expert + log_dict["amp_disc_history_steps"] = float(history_steps) + env.extras["log"] = log_dict + + return internal_reward_scale * amp_reward diff --git a/rl_game/get_up/amp/expert_features.pt b/rl_game/get_up/amp/expert_features.pt index 30b083e..cda9097 100644 Binary files a/rl_game/get_up/amp/expert_features.pt and b/rl_game/get_up/amp/expert_features.pt differ diff --git a/rl_game/get_up/amp/migrate_legged_lab_expert_template.py b/rl_game/get_up/amp/migrate_legged_lab_expert_template.py new file mode 100644 index 0000000..776d039 --- /dev/null +++ b/rl_game/get_up/amp/migrate_legged_lab_expert_template.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +import argparse +import json +import pickle +from glob import glob +from pathlib import Path + +import numpy as np +import joblib + +try: + import torch +except Exception: + torch = None + +GIT_GETUP_FOCUS_CLIP = "fallAndGetUp2_subject2_1200_1370" + + +def _load_payload(path: Path): + if 
def _to_numpy(payload, key_hint: str = "") -> np.ndarray:
    """Locate the feature array inside ``payload`` and return it as float32.

    ``payload`` may itself be a tensor, ndarray or list, or a dict holding
    one. For dicts, ``key_hint`` (when non-empty) is tried first, then the
    well-known keys ``expert_features``, ``features``, ``observations``,
    ``obs``, ``joint_pos``, ``motion`` and ``data``, then every remaining
    value in insertion order. Entries that are not array-like or cannot be
    converted to float32 (for example a list of clip names) are skipped
    rather than aborting the search.

    Raises:
        ValueError: If no convertible array is found.
    """

    def _as_float_array(value):
        # Return a float32 ndarray for array-like numeric input, else None.
        if torch is not None and isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy().astype(np.float32)
        try:
            if isinstance(value, np.ndarray):
                return value.astype(np.float32)
            if isinstance(value, list):
                return np.asarray(value, dtype=np.float32)
        except (TypeError, ValueError):
            return None
        return None

    direct = _as_float_array(payload)
    if direct is not None:
        return direct

    if isinstance(payload, dict):
        preferred_keys = [
            key_hint,
            "expert_features",
            "features",
            "observations",
            "obs",
            "joint_pos",
            "motion",
            "data",
        ]
        for key in preferred_keys:
            if key and key in payload:
                found = _as_float_array(payload[key])
                if found is not None:
                    return found
        for value in payload.values():
            found = _as_float_array(value)
            if found is not None:
                return found

    raise ValueError("Could not locate tensor-like payload. Provide --input_key when needed.")
def _build_clip_weights(clip_names: list[str], weight_mode: str) -> np.ndarray:
    """Return one float32 sampling weight per clip name.

    In ``"git_getup_focus"`` mode a clip gets weight 1 when its name contains
    ``GIT_GETUP_FOCUS_CLIP`` and 0 otherwise; if no name matches, all clips
    get weight 1. Every other mode (including ``"uniform"``) yields all ones.
    An empty name list yields an empty array.
    """
    count = len(clip_names)
    uniform = np.ones((count,), dtype=np.float32)
    if count == 0 or weight_mode != "git_getup_focus":
        return uniform

    focus = np.array(
        [1.0 if GIT_GETUP_FOCUS_CLIP in name else 0.0 for name in clip_names],
        dtype=np.float32,
    )
    # Nothing matched the focus clip: fall back to uniform sampling.
    return focus if float(np.sum(focus)) > 0.0 else uniform
main(): + parser = argparse.ArgumentParser( + description=( + "Template converter: migrate legged_lab (or other) motion/expert data " + "to AMP expert_features.pt for rl_game/get_up." + ) + ) + parser.add_argument("--input", required=True, type=str, help="Path to source motion/expert file or directory.") + parser.add_argument("--output", required=True, type=str, help="Output path for expert_features.pt.") + parser.add_argument( + "--input_key", + type=str, + default="", + help="Optional key to locate tensor inside input payload dict.", + ) + parser.add_argument( + "--feature_dim", + type=int, + default=55, + help="Target AMP feature dimension. For current get_up config, default=55.", + ) + parser.add_argument( + "--input_glob", + type=str, + default="*.pkl", + help="Glob used when --input is a directory.", + ) + parser.add_argument( + "--target_dof", + type=int, + default=23, + help="Target dof count when converting legged_lab pkl dof_pos.", + ) + parser.add_argument( + "--clip_weight_mode", + type=str, + default="git_getup_focus", + choices=["git_getup_focus", "uniform"], + help="Clip sampling weight mode for expert sequence training.", + ) + parser.add_argument("--repeat", type=int, default=1, help="Repeat samples for small datasets.") + args = parser.parse_args() + + in_path = Path(args.input).expanduser().resolve() + input_files = _collect_input_files(in_path, args.input_glob) + clip_arrays: list[np.ndarray] = [] + clip_names: list[str] = [] + for f in input_files: + payload = _load_payload(f) + if isinstance(payload, dict) and "dof_pos" in payload and f.suffix.lower() == ".pkl": + x = _from_legged_lab_motion_dict(payload, target_dof=int(args.target_dof)) + else: + x = _to_numpy(payload, key_hint=args.input_key) + x = _ensure_2d(x) + x = _normalize_dim(x.astype(np.float32), int(args.feature_dim)) + clip_arrays.append(x) + clip_names.append(f.stem) + + repeat = max(int(args.repeat), 1) + if repeat > 1: + clip_arrays = clip_arrays * repeat + clip_names = 
clip_names * repeat + + clip_weights = _build_clip_weights(clip_names, weight_mode=args.clip_weight_mode) + x = np.concatenate(clip_arrays, axis=0).astype(np.float32) + clip_lengths = [int(c.shape[0]) for c in clip_arrays] + clip_offsets: list[int] = [] + cursor = 0 + for length in clip_lengths: + clip_offsets.append(cursor) + cursor += length + repeat = max(int(args.repeat), 1) + + out_path = Path(args.output).expanduser().resolve() + out_path.parent.mkdir(parents=True, exist_ok=True) + payload = { + "expert_features": x, + "expert_clips": clip_arrays, + "clip_names": clip_names, + "clip_weights": clip_weights, + "clip_lengths": clip_lengths, + "clip_offsets": clip_offsets, + "meta": { + "source": str(in_path), + "source_count": len(input_files), + "input_key": args.input_key, + "feature_dim": int(args.feature_dim), + "target_dof": int(args.target_dof), + "input_glob": args.input_glob, + "clip_weight_mode": args.clip_weight_mode, + "git_getup_focus_clip": GIT_GETUP_FOCUS_CLIP, + "repeat": int(repeat), + }, + } + if torch is not None: + torch.save(payload, str(out_path)) + else: + with out_path.open("wb") as f: + pickle.dump(payload, f) + print(f"[OK] Saved AMP expert features -> {out_path} shape={tuple(x.shape)}") + + +if __name__ == "__main__": + main() diff --git a/rl_game/get_up/build_amp_expert_features_from_keyframes.py b/rl_game/get_up/build_amp_expert_features_from_keyframes.py deleted file mode 100644 index 5b5d096..0000000 --- a/rl_game/get_up/build_amp_expert_features_from_keyframes.py +++ /dev/null @@ -1,204 +0,0 @@ -import argparse -import math -from pathlib import Path - -import torch -import yaml - -# AMP feature order must match `_build_amp_features` in `t1_env_cfg.py`: -# [joint_pos_rel(23), joint_vel(23), root_lin_vel(3), root_ang_vel(3), projected_gravity(3)]. 
-T1_JOINT_NAMES = [ - "AAHead_yaw", - "Head_pitch", - "Left_Shoulder_Pitch", - "Left_Shoulder_Roll", - "Left_Elbow_Pitch", - "Left_Elbow_Yaw", - "Right_Shoulder_Pitch", - "Right_Shoulder_Roll", - "Right_Elbow_Pitch", - "Right_Elbow_Yaw", - "Waist", - "Left_Hip_Pitch", - "Left_Hip_Roll", - "Left_Hip_Yaw", - "Left_Knee_Pitch", - "Left_Ankle_Pitch", - "Left_Ankle_Roll", - "Right_Hip_Pitch", - "Right_Hip_Roll", - "Right_Hip_Yaw", - "Right_Knee_Pitch", - "Right_Ankle_Pitch", - "Right_Ankle_Roll", -] - -JOINT_TO_IDX = {name: i for i, name in enumerate(T1_JOINT_NAMES)} - -# Mirror rules aligned with `behaviors/custom/keyframe/keyframe.py`. -MOTOR_SYMMETRY = { - "Head_yaw": (("Head_yaw",), False), - "Head_pitch": (("Head_pitch",), False), - "Shoulder_Pitch": (("Left_Shoulder_Pitch", "Right_Shoulder_Pitch"), False), - "Shoulder_Roll": (("Left_Shoulder_Roll", "Right_Shoulder_Roll"), True), - "Elbow_Pitch": (("Left_Elbow_Pitch", "Right_Elbow_Pitch"), False), - "Elbow_Yaw": (("Left_Elbow_Yaw", "Right_Elbow_Yaw"), True), - "Waist": (("Waist",), False), - "Hip_Pitch": (("Left_Hip_Pitch", "Right_Hip_Pitch"), False), - "Hip_Roll": (("Left_Hip_Roll", "Right_Hip_Roll"), True), - "Hip_Yaw": (("Left_Hip_Yaw", "Right_Hip_Yaw"), True), - "Knee_Pitch": (("Left_Knee_Pitch", "Right_Knee_Pitch"), False), - "Ankle_Pitch": (("Left_Ankle_Pitch", "Right_Ankle_Pitch"), False), - "Ankle_Roll": (("Left_Ankle_Roll", "Right_Ankle_Roll"), True), -} -READABLE_TO_POLICY = {"Head_yaw": "AAHead_yaw"} - - -def decode_keyframe_motor_positions(raw_motor_positions: dict[str, float]) -> dict[str, float]: - """Decode one keyframe into per-joint radians.""" - out: dict[str, float] = {} - deg_to_rad = math.pi / 180.0 - for readable_name, position_deg in raw_motor_positions.items(): - if readable_name in MOTOR_SYMMETRY: - motor_names, is_inverse_direction = MOTOR_SYMMETRY[readable_name] - invert_state = bool(is_inverse_direction) - for motor_name in motor_names: - signed_deg = position_deg if invert_state else 
-position_deg - invert_state = False - out_name = READABLE_TO_POLICY.get(motor_name, motor_name) - out[out_name] = float(signed_deg) * deg_to_rad - else: - out_name = READABLE_TO_POLICY.get(readable_name, readable_name) - out[out_name] = float(position_deg) * deg_to_rad - return out - - -def load_sequence(yaml_path: Path) -> list[tuple[float, torch.Tensor]]: - """Load yaml keyframes -> list[(delta_seconds, joint_pos_vec)].""" - with yaml_path.open("r", encoding="utf-8") as f: - desc = yaml.safe_load(f) or {} - out: list[tuple[float, torch.Tensor]] = [] - for keyframe in desc.get("keyframes", []): - delta_s = max(float(keyframe.get("delta", 0.1)), 1e-3) - raw = keyframe.get("motor_positions", {}) or {} - decoded = decode_keyframe_motor_positions(raw) - joint_pos = torch.zeros(len(T1_JOINT_NAMES), dtype=torch.float32) - for j_name, j_val in decoded.items(): - idx = JOINT_TO_IDX.get(j_name, None) - if idx is not None: - joint_pos[idx] = float(j_val) - out.append((delta_s, joint_pos)) - return out - - -def sequence_to_amp_features( - sequence: list[tuple[float, torch.Tensor]], - sample_fps: float, - projected_gravity: tuple[float, float, float], -) -> torch.Tensor: - """Convert decoded sequence into AMP features tensor (N, 55).""" - if len(sequence) == 0: - raise ValueError("Empty keyframe sequence.") - dt = 1.0 / max(sample_fps, 1e-6) - grav = torch.tensor(projected_gravity, dtype=torch.float32) - - frames_joint_pos: list[torch.Tensor] = [] - for delta_s, joint_pos in sequence: - repeat = max(int(round(delta_s / dt)), 1) - for _ in range(repeat): - frames_joint_pos.append(joint_pos.clone()) - if len(frames_joint_pos) < 2: - frames_joint_pos.append(frames_joint_pos[0].clone()) - - pos = torch.stack(frames_joint_pos, dim=0) - vel = torch.zeros_like(pos) - vel[1:] = (pos[1:] - pos[:-1]) / dt - vel[0] = vel[1] - - root_lin = torch.zeros((pos.shape[0], 3), dtype=torch.float32) - root_ang = torch.zeros((pos.shape[0], 3), dtype=torch.float32) - grav_batch = 
grav.unsqueeze(0).repeat(pos.shape[0], 1) - return torch.cat([pos, vel, root_lin, root_ang, grav_batch], dim=-1) - - -def main(): - parser = argparse.ArgumentParser(description="Build AMP expert features from get_up keyframe YAML files.") - parser.add_argument( - "--front_yaml", - type=str, - default="behaviors/custom/keyframe/get_up/get_up_front.yaml", - help="Path to front get-up YAML.", - ) - parser.add_argument( - "--back_yaml", - type=str, - default="behaviors/custom/keyframe/get_up/get_up_back.yaml", - help="Path to back get-up YAML.", - ) - parser.add_argument( - "--sample_fps", - type=float, - default=50.0, - help="Sampling fps when expanding keyframe durations.", - ) - parser.add_argument( - "--repeat_cycles", - type=int, - default=200, - help="How many times to repeat front+back sequences to enlarge dataset.", - ) - parser.add_argument( - "--projected_gravity", - type=float, - nargs=3, - default=(0.0, 0.0, -1.0), - help="Projected gravity feature used for synthesized expert data.", - ) - parser.add_argument( - "--output", - type=str, - default="rl_game/get_up/amp/expert_features.pt", - help="Output expert feature file path.", - ) - args = parser.parse_args() - - front_path = Path(args.front_yaml).expanduser().resolve() - back_path = Path(args.back_yaml).expanduser().resolve() - if not front_path.is_file(): - raise FileNotFoundError(f"Front YAML not found: {front_path}") - if not back_path.is_file(): - raise FileNotFoundError(f"Back YAML not found: {back_path}") - - front_seq = load_sequence(front_path) - back_seq = load_sequence(back_path) - front_feat = sequence_to_amp_features(front_seq, args.sample_fps, tuple(args.projected_gravity)) - back_feat = sequence_to_amp_features(back_seq, args.sample_fps, tuple(args.projected_gravity)) - base_feat = torch.cat([front_feat, back_feat], dim=0) - - repeat_cycles = max(int(args.repeat_cycles), 1) - expert_features = base_feat.repeat(repeat_cycles, 1).contiguous() - - out_path = Path(args.output).expanduser() - if 
not out_path.is_absolute(): - out_path = Path.cwd() / out_path - out_path.parent.mkdir(parents=True, exist_ok=True) - torch.save( - { - "expert_features": expert_features, - "feature_dim": int(expert_features.shape[1]), - "num_samples": int(expert_features.shape[0]), - "source": "get_up_keyframe_yaml", - "front_yaml": str(front_path), - "back_yaml": str(back_path), - "sample_fps": float(args.sample_fps), - "repeat_cycles": repeat_cycles, - "projected_gravity": [float(v) for v in args.projected_gravity], - }, - str(out_path), - ) - print(f"[INFO]: saved expert features -> {out_path}") - print(f"[INFO]: shape={tuple(expert_features.shape)}") - - -if __name__ == "__main__": - main() diff --git a/rl_game/get_up/config/ppo_cfg.yaml b/rl_game/get_up/config/ppo_cfg.yaml index d3c173d..7bb7b7b 100644 --- a/rl_game/get_up/config/ppo_cfg.yaml +++ b/rl_game/get_up/config/ppo_cfg.yaml @@ -27,7 +27,7 @@ params: name: default config: - name: T1_Walking + name: T1_GetUp env_name: rlgym # Isaac Lab 包装器 multi_gpu: False ppo: True diff --git a/rl_game/get_up/config/t1_env_cfg.py b/rl_game/get_up/config/t1_env_cfg.py index 25dc8fb..973276d 100644 --- a/rl_game/get_up/config/t1_env_cfg.py +++ b/rl_game/get_up/config/t1_env_cfg.py @@ -1,6 +1,6 @@ import torch -import torch.nn as nn from pathlib import Path +import yaml import isaaclab.envs.mdp as mdp from isaaclab.envs import ManagerBasedRLEnvCfg, ManagerBasedRLEnv from isaaclab.managers import ObservationGroupCfg as ObsGroup @@ -11,8 +11,13 @@ from isaaclab.managers import EventTermCfg as EventTerm from isaaclab.envs.mdp import JointPositionActionCfg from isaaclab.managers import SceneEntityCfg from isaaclab.utils import configclass +from rl_game.get_up.amp.amp_rewards import amp_style_prior_reward from rl_game.get_up.env.t1_env import T1SceneCfg +_PROJECT_ROOT = Path(__file__).resolve().parents[3] +_DEFAULT_FRONT_KEYFRAME_YAML = str(_PROJECT_ROOT / "behaviors" / "custom" / "keyframe" / "get_up" / "get_up_front.yaml") 
+_DEFAULT_BACK_KEYFRAME_YAML = str(_PROJECT_ROOT / "behaviors" / "custom" / "keyframe" / "get_up" / "get_up_back.yaml") + def _contact_force_z(env: ManagerBasedRLEnv, sensor_cfg: SceneEntityCfg) -> torch.Tensor: """Sum positive vertical contact force on selected bodies.""" sensor = env.scene.sensors.get(sensor_cfg.name) @@ -30,272 +35,216 @@ def _safe_tensor(x: torch.Tensor, nan: float = 0.0, pos: float = 1e3, neg: float return torch.nan_to_num(x, nan=nan, posinf=pos, neginf=neg) -class AMPDiscriminator(nn.Module): - """Lightweight discriminator used by online AMP updates.""" - - def __init__(self, input_dim: int, hidden_dims: tuple[int, ...]): - super().__init__() - layers: list[nn.Module] = [] - in_dim = input_dim - for h_dim in hidden_dims: - layers.append(nn.Linear(in_dim, h_dim)) - layers.append(nn.LayerNorm(h_dim)) - layers.append(nn.SiLU()) - in_dim = h_dim - layers.append(nn.Linear(in_dim, 1)) - self.net = nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.net(x) +def _resolve_path(path_like: str) -> Path: + path = Path(path_like).expanduser() + if path.is_absolute(): + return path + return (_PROJECT_ROOT / path).resolve() -def _extract_tensor_from_amp_payload(payload) -> torch.Tensor | None: - if isinstance(payload, torch.Tensor): - return payload - if isinstance(payload, dict): - for key in ("expert_features", "features", "obs"): - value = payload.get(key, None) - if isinstance(value, torch.Tensor): - return value - return None +def _interpolate_keyframes( + keyframes: list[dict], + sample_dt: float, +) -> tuple[torch.Tensor, list[str]]: + """Interpolate sparse keyframes into dense [T, M] motor angle table in radians.""" + if len(keyframes) == 0: + raise ValueError("No keyframes in motion yaml") + + motor_names: set[str] = set() + for frame in keyframes: + motors = frame.get("motor_positions", {}) + if isinstance(motors, dict): + motor_names.update(motors.keys()) + names = sorted(motor_names) + if len(names) == 0: 
+ raise ValueError("No motor_positions found in keyframes") + + frame_count = len(keyframes) + times = torch.zeros(frame_count, dtype=torch.float32) + values = torch.full((frame_count, len(names)), float("nan"), dtype=torch.float32) + name_to_idx = {name: idx for idx, name in enumerate(names)} + + elapsed = 0.0 + for i, frame in enumerate(keyframes): + elapsed += max(float(frame.get("delta", 0.1)), 1e-4) + times[i] = elapsed + motors = frame.get("motor_positions", {}) + if not isinstance(motors, dict): + continue + for motor_name, motor_deg in motors.items(): + if motor_name not in name_to_idx: + continue + values[i, name_to_idx[motor_name]] = float(motor_deg) * (torch.pi / 180.0) + + values[0] = torch.nan_to_num(values[0], nan=0.0) + for i in range(1, frame_count): + values[i] = torch.where(torch.isnan(values[i]), values[i - 1], values[i]) + + sample_dt = max(float(sample_dt), 1e-3) + sample_times = torch.arange(0.0, float(times[-1]) + 1e-6, sample_dt, dtype=torch.float32) + sample_times[0] = max(sample_times[0], 1e-6) + sample_times = torch.clamp(sample_times, min=times[0], max=times[-1]) + + upper = torch.searchsorted(times, sample_times, right=True) + upper = torch.clamp(upper, min=1, max=frame_count - 1) + lower = upper - 1 + + t0 = times.index_select(0, lower) + t1 = times.index_select(0, upper) + v0 = values.index_select(0, lower) + v1 = values.index_select(0, upper) + alpha = ((sample_times - t0) / (t1 - t0 + 1e-6)).unsqueeze(-1) + interp = v0 + alpha * (v1 - v0) + return interp, names -def _load_amp_expert_features( - expert_features_path: str, - device: str, - feature_dim: int, - fallback_samples: int, -) -> torch.Tensor | None: - """Load expert AMP features. 
Returns None when file is unavailable.""" - if not expert_features_path: - return None - p = Path(expert_features_path).expanduser() - if not p.is_file(): - return None - try: - payload = torch.load(str(p), map_location="cpu") - except Exception: - return None - expert = _extract_tensor_from_amp_payload(payload) - if expert is None: - return None - expert = expert.float() - if expert.ndim == 1: - expert = expert.unsqueeze(0) - if expert.ndim != 2: - return None - if expert.shape[1] != feature_dim: - return None - if expert.shape[0] < 2: - return None - if expert.shape[0] < fallback_samples: - reps = int((fallback_samples + expert.shape[0] - 1) // expert.shape[0]) - expert = expert.repeat(reps, 1) - return expert.to(device=device) +def _motion_table_from_yaml( + yaml_path: str, + joint_names: list[str], + sample_dt: float, +) -> tuple[torch.Tensor, float]: + """ + Build [T, J] target joint motion from get-up keyframes. + Unknown joints default to 0.0 (neutral). + """ + path = _resolve_path(yaml_path) + if not path.is_file(): + raise FileNotFoundError(f"Motion yaml not found: {path}") + with path.open("r", encoding="utf-8") as f: + payload = yaml.safe_load(f) or {} + keyframes = payload.get("keyframes", []) + if not isinstance(keyframes, list): + raise ValueError(f"Invalid keyframes in yaml: {path}") + + motor_table, motor_names = _interpolate_keyframes(keyframes, sample_dt=sample_dt) + motor_to_idx = {name: idx for idx, name in enumerate(motor_names)} + joint_table = torch.zeros((motor_table.shape[0], len(joint_names)), dtype=torch.float32) + + sign_flip_bases = {"Shoulder_Roll", "Elbow_Yaw", "Hip_Roll", "Hip_Yaw", "Ankle_Roll"} + for j, joint_name in enumerate(joint_names): + alias = "Head_yaw" if joint_name == "AAHead_yaw" else joint_name + base = alias.replace("Left_", "").replace("Right_", "") + src = None + if alias in motor_to_idx: + src = alias + elif base in motor_to_idx: + src = base + if src is None: + continue + sign = -1.0 if alias.startswith("Right_") 
and base in sign_flip_bases else 1.0 + joint_table[:, j] = sign * motor_table[:, motor_to_idx[src]] + + duration_s = max(float(motor_table.shape[0] - 1) * sample_dt, sample_dt) + return joint_table, duration_s -def _get_amp_state( +def _get_keyframe_motion_cache( env: ManagerBasedRLEnv, - amp_enabled: bool, - amp_model_path: str, - amp_train_enabled: bool, - amp_expert_features_path: str, - feature_dim: int, - disc_hidden_dim: int, - disc_hidden_layers: int, - disc_lr: float, - disc_weight_decay: float, - disc_min_expert_samples: int, + front_motion_path: str, + back_motion_path: str, + sample_dt: float, ): - """Get cached AMP state (frozen jit or trainable discriminator).""" - cache_key = "amp_state_cache" - hidden_layers = max(int(disc_hidden_layers), 1) - hidden_dim = max(int(disc_hidden_dim), 16) - state_sig = ( - bool(amp_enabled), - str(amp_model_path), - bool(amp_train_enabled), - str(amp_expert_features_path), - int(feature_dim), - hidden_dim, - hidden_layers, - float(disc_lr), - float(disc_weight_decay), - ) + """Cache interpolated front/back get-up motion priors on env device.""" + cache_key = "getup_keyframe_motion_cache" + sig = (str(front_motion_path), str(back_motion_path), float(sample_dt)) cached = env.extras.get(cache_key, None) - if isinstance(cached, dict) and cached.get("sig") == state_sig: + if isinstance(cached, dict) and cached.get("sig") == sig: return cached - state = { - "sig": state_sig, - "mode": "disabled", - "model": None, - "optimizer": None, - "expert_features": None, - "step": 0, - "last_loss": 0.0, - "last_acc_policy": 0.0, - "last_acc_expert": 0.0, + front_table, front_duration = _motion_table_from_yaml(front_motion_path, T1_JOINT_NAMES, sample_dt) + back_table, back_duration = _motion_table_from_yaml(back_motion_path, T1_JOINT_NAMES, sample_dt) + + cache = { + "sig": sig, + "sample_dt": float(sample_dt), + "front_motion": front_table.to(device=env.device), + "back_motion": back_table.to(device=env.device), + "front_duration": 
float(front_duration), + "back_duration": float(back_duration), } - - if amp_train_enabled: - expert_features = _load_amp_expert_features( - amp_expert_features_path, - device=env.device, - feature_dim=feature_dim, - fallback_samples=max(disc_min_expert_samples, 512), - ) - if expert_features is not None: - model = AMPDiscriminator(input_dim=feature_dim, hidden_dims=tuple([hidden_dim] * hidden_layers)).to(env.device) - optimizer = torch.optim.AdamW(model.parameters(), lr=float(disc_lr), weight_decay=float(disc_weight_decay)) - state["mode"] = "trainable" - state["model"] = model - state["optimizer"] = optimizer - state["expert_features"] = expert_features - elif amp_enabled and amp_model_path: - model_path = Path(amp_model_path).expanduser() - if model_path.is_file(): - try: - model = torch.jit.load(str(model_path), map_location=env.device) - model.eval() - state["mode"] = "jit" - state["model"] = model - except Exception: - pass - - env.extras[cache_key] = state - return state + env.extras[cache_key] = cache + return cache -def _build_amp_features(env: ManagerBasedRLEnv, feature_clip: float = 8.0) -> torch.Tensor: - """Build AMP-style discriminator features from robot kinematics.""" - robot_data = env.scene["robot"].data - joint_pos_rel = robot_data.joint_pos - robot_data.default_joint_pos - joint_vel = robot_data.joint_vel - root_lin_vel = robot_data.root_lin_vel_w - root_ang_vel = robot_data.root_ang_vel_w - projected_gravity = robot_data.projected_gravity_b - amp_features = torch.cat([joint_pos_rel, joint_vel, root_lin_vel, root_ang_vel, projected_gravity], dim=-1) - amp_features = _safe_tensor(amp_features, nan=0.0, pos=feature_clip, neg=-feature_clip) - return torch.clamp(amp_features, min=-feature_clip, max=feature_clip) - - -def amp_style_prior_reward( +def keyframe_motion_prior_reward( env: ManagerBasedRLEnv, - amp_enabled: bool = False, - amp_model_path: str = "", - amp_train_enabled: bool = False, - amp_expert_features_path: str = "", - disc_hidden_dim: 
int = 256, - disc_hidden_layers: int = 2, - disc_lr: float = 3e-4, - disc_weight_decay: float = 1e-6, - disc_update_interval: int = 4, - disc_batch_size: int = 1024, - disc_min_expert_samples: int = 2048, - feature_clip: float = 8.0, - logit_scale: float = 1.0, - amp_reward_gain: float = 1.0, + front_motion_path: str = _DEFAULT_FRONT_KEYFRAME_YAML, + back_motion_path: str = _DEFAULT_BACK_KEYFRAME_YAML, + sample_dt: float = 0.04, + pose_sigma: float = 0.42, + vel_sigma: float = 1.6, + joint_subset: str = "all", internal_reward_scale: float = 1.0, ) -> torch.Tensor: - """AMP style prior reward with optional online discriminator training.""" - zeros = torch.zeros(env.num_envs, device=env.device) - amp_score = zeros - model_loaded = 0.0 - amp_train_active = 0.0 - disc_loss = 0.0 - disc_acc_policy = 0.0 - disc_acc_expert = 0.0 - - amp_features = _build_amp_features(env, feature_clip=feature_clip) - amp_state = _get_amp_state( + """ + DeepMimic-style dense reward from keyframe get-up motions. 
+ - mode=1 uses front sequence + - mode=0 uses back sequence + """ + motion_cache = _get_keyframe_motion_cache( env=env, - amp_enabled=amp_enabled, - amp_model_path=amp_model_path, - amp_train_enabled=amp_train_enabled, - amp_expert_features_path=amp_expert_features_path, - feature_dim=int(amp_features.shape[-1]), - disc_hidden_dim=disc_hidden_dim, - disc_hidden_layers=disc_hidden_layers, - disc_lr=disc_lr, - disc_weight_decay=disc_weight_decay, - disc_min_expert_samples=disc_min_expert_samples, + front_motion_path=front_motion_path, + back_motion_path=back_motion_path, + sample_dt=sample_dt, ) - discriminator = amp_state.get("model", None) - if discriminator is not None: - model_loaded = 1.0 - if amp_state.get("mode") == "trainable" and discriminator is not None: - amp_train_active = 1.0 - optimizer = amp_state.get("optimizer", None) - expert_features = amp_state.get("expert_features", None) - amp_state["step"] = int(amp_state.get("step", 0)) + 1 - update_interval = max(int(disc_update_interval), 1) - batch_size = max(int(disc_batch_size), 32) + # env.extras["getup_mode"]: 1=front, 0=back + getup_mode = env.extras.get("getup_mode", None) + if not isinstance(getup_mode, torch.Tensor) or getup_mode.shape[0] != env.num_envs: + getup_mode = torch.zeros(env.num_envs, device=env.device, dtype=torch.long) + env.extras["getup_mode"] = getup_mode + getup_mode = getup_mode.to(dtype=torch.long) + use_front = getup_mode == 1 - if optimizer is not None and isinstance(expert_features, torch.Tensor) and amp_state["step"] % update_interval == 0: - policy_features = amp_features.detach() - policy_count = policy_features.shape[0] - if policy_count > batch_size: - policy_ids = torch.randint(0, policy_count, (batch_size,), device=env.device) - policy_batch = policy_features.index_select(0, policy_ids) - else: - policy_batch = policy_features + step_dt = env.step_dt + phase_time = torch.clamp(env.episode_length_buf * step_dt, min=0.0) - expert_count = expert_features.shape[0] - 
expert_ids = torch.randint(0, expert_count, (policy_batch.shape[0],), device=env.device) - expert_batch = expert_features.index_select(0, expert_ids) + front_motion = motion_cache["front_motion"] + back_motion = motion_cache["back_motion"] + front_idx = torch.clamp((phase_time / sample_dt).to(torch.long), min=0, max=front_motion.shape[0] - 1) + back_idx = torch.clamp((phase_time / sample_dt).to(torch.long), min=0, max=back_motion.shape[0] - 1) - discriminator.train() - optimizer.zero_grad(set_to_none=True) - logits_expert = discriminator(expert_batch).squeeze(-1) - logits_policy = discriminator(policy_batch).squeeze(-1) - loss_expert = nn.functional.binary_cross_entropy_with_logits(logits_expert, torch.ones_like(logits_expert)) - loss_policy = nn.functional.binary_cross_entropy_with_logits(logits_policy, torch.zeros_like(logits_policy)) - loss = 0.5 * (loss_expert + loss_policy) - loss.backward() - nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) - optimizer.step() - discriminator.eval() + target_front = front_motion.index_select(0, front_idx) + target_back = back_motion.index_select(0, back_idx) + target_pos = torch.where(use_front.unsqueeze(-1), target_front, target_back) - with torch.no_grad(): - disc_loss = float(loss.detach().item()) - disc_acc_expert = float((torch.sigmoid(logits_expert) > 0.5).float().mean().item()) - disc_acc_policy = float((torch.sigmoid(logits_policy) < 0.5).float().mean().item()) - amp_state["last_loss"] = disc_loss - amp_state["last_acc_expert"] = disc_acc_expert - amp_state["last_acc_policy"] = disc_acc_policy - else: - disc_loss = float(amp_state.get("last_loss", 0.0)) - disc_acc_expert = float(amp_state.get("last_acc_expert", 0.0)) - disc_acc_policy = float(amp_state.get("last_acc_policy", 0.0)) + robot_data = env.scene["robot"].data + current_pos = _safe_tensor(robot_data.joint_pos) + current_vel = _safe_tensor(robot_data.joint_vel) - if discriminator is not None: - discriminator.eval() - with torch.no_grad(): - 
logits = discriminator(amp_features) - if isinstance(logits, (tuple, list)): - logits = logits[0] - if logits.ndim > 1: - logits = logits.squeeze(-1) - logits = _safe_tensor(logits, nan=0.0, pos=20.0, neg=-20.0) - amp_score = torch.sigmoid(logit_scale * logits) - amp_score = _safe_tensor(amp_score, nan=0.0, pos=1.0, neg=0.0) + if joint_subset == "legs": + legs_idx, _ = env.scene["robot"].find_joints(".*(Hip|Knee|Ankle).*") + if len(legs_idx) > 0: + ids = torch.tensor(legs_idx, device=env.device, dtype=torch.long) + current_pos = current_pos.index_select(1, ids) + current_vel = current_vel.index_select(1, ids) + target_pos = target_pos.index_select(1, ids) + elif joint_subset == "core": + core_idx, _ = env.scene["robot"].find_joints(".*(Waist|Hip|Knee|Ankle).*") + if len(core_idx) > 0: + ids = torch.tensor(core_idx, device=env.device, dtype=torch.long) + current_pos = current_pos.index_select(1, ids) + current_vel = current_vel.index_select(1, ids) + target_pos = target_pos.index_select(1, ids) - amp_reward = _safe_tensor(amp_reward_gain * amp_score, nan=0.0, pos=10.0, neg=0.0) + target_vel = torch.zeros_like(target_pos) + pos_mse = torch.mean(torch.square(current_pos - target_pos), dim=-1) + vel_mse = torch.mean(torch.square(current_vel - target_vel), dim=-1) + + pose_sigma = max(float(pose_sigma), 1e-3) + vel_sigma = max(float(vel_sigma), 1e-3) + pose_reward = torch.exp(-pos_mse / pose_sigma) + vel_reward = torch.exp(-vel_mse / vel_sigma) + prior_reward = 0.75 * pose_reward + 0.25 * vel_reward + prior_reward = _safe_tensor(prior_reward, nan=0.0, pos=1.0, neg=0.0) log_dict = env.extras.get("log", {}) if isinstance(log_dict, dict): - log_dict["amp_score_mean"] = torch.mean(amp_score).detach().item() - log_dict["amp_reward_mean"] = torch.mean(amp_reward).detach().item() - log_dict["amp_model_loaded_mean"] = model_loaded - log_dict["amp_train_active_mean"] = amp_train_active - log_dict["amp_disc_loss_mean"] = disc_loss - log_dict["amp_disc_acc_policy_mean"] = 
disc_acc_policy - log_dict["amp_disc_acc_expert_mean"] = disc_acc_expert + log_dict["keyframe_prior_mean"] = torch.mean(prior_reward).detach().item() + log_dict["keyframe_front_ratio"] = torch.mean(use_front.float()).detach().item() env.extras["log"] = log_dict - return internal_reward_scale * amp_reward + return internal_reward_scale * prior_reward def root_height_obs(env: ManagerBasedRLEnv) -> torch.Tensor: @@ -365,6 +314,10 @@ def reset_root_state_bimodal_lie_pose( -torch.ones(num_resets, device=env.device), ) euler_angles[:, 1] = pitch_mag * pitch_sign + # Cache get-up mode for motion priors: pitch<0 => front, pitch>=0 => back. + if "getup_mode" not in env.extras or not isinstance(env.extras.get("getup_mode"), torch.Tensor): + env.extras["getup_mode"] = torch.zeros(env.num_envs, device=env.device, dtype=torch.long) + env.extras["getup_mode"][env_ids] = (pitch_sign < 0.0).to(torch.long) yaw_min, yaw_max = yaw_abs_range yaw_mag = yaw_min + torch.rand(num_resets, device=env.device) * (yaw_max - yaw_min) @@ -778,6 +731,19 @@ class T1GetUpRewardCfg: "timer_name": "reward_stable_timer", }, ) + keyframe_motion_prior = RewTerm( + func=keyframe_motion_prior_reward, + weight=0.0, + params={ + "front_motion_path": _DEFAULT_FRONT_KEYFRAME_YAML, + "back_motion_path": _DEFAULT_BACK_KEYFRAME_YAML, + "sample_dt": 0.04, + "pose_sigma": 0.42, + "vel_sigma": 1.6, + "joint_subset": "all", + "internal_reward_scale": 1.0, + }, + ) # AMP reward is disabled by default until a discriminator model path is provided. 
amp_style_prior = RewTerm( func=amp_style_prior_reward, @@ -794,6 +760,7 @@ class T1GetUpRewardCfg: "disc_update_interval": 4, "disc_batch_size": 1024, "disc_min_expert_samples": 2048, + "disc_history_steps": 4, "feature_clip": 8.0, "logit_scale": 1.0, "amp_reward_gain": 1.0, diff --git a/rl_game/get_up/train.py b/rl_game/get_up/train.py index 36a4187..70a71d1 100644 --- a/rl_game/get_up/train.py +++ b/rl_game/get_up/train.py @@ -42,7 +42,29 @@ parser.add_argument("--amp_disc_lr", type=float, default=3e-4, help="Learning ra parser.add_argument("--amp_disc_weight_decay", type=float, default=1e-6, help="Weight decay for AMP discriminator.") parser.add_argument("--amp_disc_update_interval", type=int, default=4, help="Train discriminator every N reward calls.") parser.add_argument("--amp_disc_batch_size", type=int, default=1024, help="Discriminator train batch size.") +parser.add_argument("--amp_disc_history_steps", type=int, default=4, help="Temporal history steps for AMP discriminator.") parser.add_argument("--amp_logit_scale", type=float, default=1.0, help="Scale before sigmoid(logits) for AMP score.") +parser.add_argument( + "--amp_from_keyframes", + action="store_true", + help="Generate AMP expert features from get-up keyframe yaml files and enable online discriminator training.", +) +parser.add_argument( + "--amp_keyframe_front", + type=str, + default=os.path.join(PROJECT_ROOT, "behaviors", "custom", "keyframe", "get_up", "get_up_front.yaml"), + help="Front get-up keyframe yaml path for AMP expert generation.", +) +parser.add_argument( + "--amp_keyframe_back", + type=str, + default=os.path.join(PROJECT_ROOT, "behaviors", "custom", "keyframe", "get_up", "get_up_back.yaml"), + help="Back get-up keyframe yaml path for AMP expert generation.", +) +parser.add_argument("--amp_keyframe_dt", type=float, default=0.04, help="Resampling dt for keyframe AMP expert features.") +parser.add_argument("--amp_keyframe_repeat", type=int, default=16, help="Repeat count for each 
keyframe sequence.") +parser.add_argument("--keyframe_prior_weight", type=float, default=1.0, help="Weight for keyframe motion prior reward.") +parser.add_argument("--disable_keyframe_prior", action="store_true", help="Disable keyframe motion prior reward.") AppLauncher.add_app_launcher_args(parser) args_cli = parser.parse_args() @@ -56,7 +78,8 @@ from rl_games.common.algo_observer import DefaultAlgoObserver from rl_games.torch_runner import Runner from rl_games.common import env_configurations, vecenv -from rl_game.get_up.config.t1_env_cfg import T1EnvCfg +from rl_game.get_up.amp.amp_motion import build_amp_expert_features_from_getup_keyframes +from rl_game.get_up.config.t1_env_cfg import T1EnvCfg, T1_JOINT_NAMES class T1MetricObserver(DefaultAlgoObserver): @@ -77,6 +100,8 @@ class T1MetricObserver(DefaultAlgoObserver): "amp_disc_loss_mean", "amp_disc_acc_policy_mean", "amp_disc_acc_expert_mean", + "keyframe_prior_mean", + "keyframe_front_ratio", ) self._metric_sums: dict[str, float] = {} self._metric_counts: dict[str, int] = {} @@ -202,10 +227,33 @@ def main(): task_id = "Isaac-T1-GetUp-v0" env_cfg = T1EnvCfg() + if args_cli.disable_keyframe_prior: + env_cfg.rewards.keyframe_motion_prior.weight = 0.0 + print("[INFO]: keyframe motion prior disabled") + else: + env_cfg.rewards.keyframe_motion_prior.weight = float(args_cli.keyframe_prior_weight) + print(f"[INFO]: keyframe motion prior weight={env_cfg.rewards.keyframe_motion_prior.weight:.3f}") + + if args_cli.amp_from_keyframes: + auto_feature_path = os.path.join(os.path.dirname(__file__), "logs", "amp", "expert_features_from_keyframes.pt") + generated_path, feature_shape = build_amp_expert_features_from_getup_keyframes( + front_yaml_path=args_cli.amp_keyframe_front, + back_yaml_path=args_cli.amp_keyframe_back, + joint_names=T1_JOINT_NAMES, + output_path=auto_feature_path, + sample_dt=float(args_cli.amp_keyframe_dt), + repeat_count=int(args_cli.amp_keyframe_repeat), + ) + args_cli.amp_expert_features = 
generated_path + args_cli.amp_train_discriminator = True + print(f"[INFO]: AMP expert features generated at {generated_path}, shape={feature_shape}") + amp_cfg = env_cfg.rewards.amp_style_prior amp_cfg.params["logit_scale"] = float(args_cli.amp_logit_scale) if args_cli.amp_train_discriminator: expert_path = os.path.abspath(os.path.expanduser(args_cli.amp_expert_features)) if args_cli.amp_expert_features else "" + if not expert_path: + raise ValueError("--amp_train_discriminator requires --amp_expert_features or --amp_from_keyframes.") amp_cfg.weight = float(args_cli.amp_reward_weight) amp_cfg.params["amp_train_enabled"] = True amp_cfg.params["amp_enabled"] = False @@ -216,8 +264,10 @@ def main(): amp_cfg.params["disc_weight_decay"] = float(args_cli.amp_disc_weight_decay) amp_cfg.params["disc_update_interval"] = int(args_cli.amp_disc_update_interval) amp_cfg.params["disc_batch_size"] = int(args_cli.amp_disc_batch_size) - print(f"[INFO]: AMP online discriminator enabled, expert_features={expert_path or ''}") + amp_cfg.params["disc_history_steps"] = int(args_cli.amp_disc_history_steps) + print(f"[INFO]: AMP online discriminator enabled, expert_features={expert_path}") print(f"[INFO]: AMP reward weight={amp_cfg.weight:.3f}") + print(f"[INFO]: AMP discriminator history_steps={amp_cfg.params['disc_history_steps']}") elif args_cli.amp_model: amp_model_path = os.path.abspath(os.path.expanduser(args_cli.amp_model)) amp_cfg.weight = float(args_cli.amp_reward_weight)