Add AMP get-up pipeline with sequence discriminator and git-sourced expert data

This commit is contained in:
Chen
2026-04-20 15:51:44 +08:00
parent 9e6e7e00f8
commit 995f6522b2
10 changed files with 1226 additions and 443 deletions

View File

@@ -0,0 +1,62 @@
# AMP Tools for `get_up`
This folder contains all AMP-related code for `rl_game/get_up`.
## Files
- `amp_rewards.py`: AMP discriminator + reward function used by training config.
- `amp_motion.py`: Build AMP expert features from local get-up keyframe YAML files.
- `migrate_legged_lab_expert_template.py`: Template converter for migrating external expert data (for example legged_lab outputs) to `expert_features.pt`.
## Quick start
Generate expert features from current local keyframes:
```bash
python rl_game/get_up/train.py --amp_from_keyframes --headless
```
Convert an external motion/expert file to the AMP `expert_features.pt` format:
```bash
python rl_game/get_up/amp/migrate_legged_lab_expert_template.py \
--input /path/to/source_data.pt \
--output rl_game/get_up/amp/expert_features.pt \
--input_key expert_features \
--feature_dim 55 \
--repeat 4
```
Convert downloaded `legged_lab` motion pickles directly:
```bash
python rl_game/get_up/amp/migrate_legged_lab_expert_template.py \
--input third_party/legged_lab/source/legged_lab/legged_lab/data/MotionData/g1_29dof/amp/walk_and_run \
--input_glob "*.pkl" \
--target_dof 23 \
--feature_dim 55 \
--clip_weight_mode uniform \
--output rl_game/get_up/amp/expert_features.pt
```
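For get-up data from the Gitee `legged_lab` repository, use the `git_getup_focus` clip-weight mode, which gives sampling weight only to the focused get-up clip (and falls back to uniform weights if that clip is not among the inputs):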
```bash
python rl_game/get_up/amp/migrate_legged_lab_expert_template.py \
--input third_party/legged_lab_gitee/source/legged_lab/legged_lab/data/MotionData/g1_29dof/amp/get_up \
--input_glob "*.pkl" \
--target_dof 23 \
--feature_dim 55 \
--clip_weight_mode git_getup_focus \
--output rl_game/get_up/amp/expert_features.pt
```
Then train with online AMP discriminator:
```bash
python rl_game/get_up/train.py \
--amp_train_discriminator \
--amp_expert_features rl_game/get_up/amp/expert_features.pt \
--amp_reward_weight 0.6 \
--headless
```

View File

@@ -0,0 +1,7 @@
from .amp_motion import build_amp_expert_features_from_getup_keyframes
from .amp_rewards import amp_style_prior_reward
# Names re-exported as the public API of this package.
__all__ = [
"amp_style_prior_reward",
"build_amp_expert_features_from_getup_keyframes",
]

View File

@@ -0,0 +1,186 @@
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
import yaml
# Side-less joint names whose value is negated when it is applied to a
# "Right_"-prefixed joint (used by _map_motors_to_joint_targets).
RIGHT_JOINT_SIGN_FLIP = {
"Shoulder_Roll",
"Elbow_Yaw",
"Hip_Roll",
"Hip_Yaw",
"Ankle_Roll",
}
# Joint name as requested by the caller -> name looked up in the keyframe motor table.
JOINT_NAME_ALIAS = {
"AAHead_yaw": "Head_yaw",
}
def _safe_load_yaml(path: Path) -> dict[str, Any]:
    """Parse a keyframe YAML file and return its top-level mapping.

    A document that parses to a falsy value (for example an empty file) is
    treated as ``{}``.

    Raises:
        ValueError: if the parsed document is not a mapping.
    """
    with path.open("r", encoding="utf-8") as stream:
        document = yaml.safe_load(stream)
    if not document:
        document = {}
    if isinstance(document, dict):
        return document
    raise ValueError(f"Invalid keyframe yaml: {path}")
def _build_interpolated_motor_table(
    keyframes: list[dict[str, Any]],
    sample_dt: float,
) -> tuple[torch.Tensor, list[str], torch.Tensor]:
    """Resample keyframed motor targets onto a uniform time grid.

    Each keyframe contributes a ``delta`` (seconds, default 0.1, floored at
    1e-4) and a ``motor_positions`` mapping of motor name -> degrees. Keyframe
    timestamps are the running sum of the deltas. Motors missing from a
    keyframe inherit the previous keyframe's value (0.0 for the first one).

    Args:
        keyframes: Non-empty list of keyframe dicts.
        sample_dt: Sampling period in seconds (floored at 1e-3).

    Returns:
        ``(table, names, sample_times)`` where ``table`` is ``[T, M]`` in
        radians, ``names`` is the sorted list of the ``M`` motor names and
        ``sample_times`` is the ``[T]`` tensor of sample timestamps, clamped
        to the span between the first and last keyframe timestamps.

    Raises:
        ValueError: if there are no keyframes or no motor positions at all.
    """
    if len(keyframes) == 0:
        raise ValueError("No keyframes in yaml")

    motor_names: set[str] = set()
    for frame in keyframes:
        motors = frame.get("motor_positions", {})
        if isinstance(motors, dict):
            motor_names.update(motors.keys())
    names = sorted(motor_names)
    if len(names) == 0:
        raise ValueError("No motor_positions found in keyframes")

    frame_count = len(keyframes)
    name_to_idx = {name: idx for idx, name in enumerate(names)}
    times = torch.zeros(frame_count, dtype=torch.float32)
    # NaN marks "motor not specified in this keyframe"; filled in below.
    values = torch.full((frame_count, len(names)), float("nan"), dtype=torch.float32)
    elapsed = 0.0
    for i, frame in enumerate(keyframes):
        elapsed += max(float(frame.get("delta", 0.1)), 1e-4)
        times[i] = elapsed
        motors = frame.get("motor_positions", {})
        if not isinstance(motors, dict):
            continue
        for motor_name, motor_deg in motors.items():
            values[i, name_to_idx[motor_name]] = float(motor_deg) * (torch.pi / 180.0)

    # Forward-fill unspecified motors so every row is fully defined.
    values[0] = torch.nan_to_num(values[0], nan=0.0)
    for i in range(1, frame_count):
        values[i] = torch.where(torch.isnan(values[i]), values[i - 1], values[i])

    sample_dt = max(float(sample_dt), 1e-3)
    sample_times = torch.arange(0.0, float(times[-1]) + 1e-6, sample_dt, dtype=torch.float32)
    sample_times[0] = max(sample_times[0], 1e-6)
    sample_times = torch.clamp(sample_times, min=times[0], max=times[-1])

    if frame_count == 1:
        # A single keyframe has no segment to interpolate over; previously the
        # index clamp below produced lower == -1 and index_select raised.
        interpolated = values[0].unsqueeze(0).repeat(sample_times.shape[0], 1)
        return interpolated, names, sample_times

    # For each sample find the keyframe segment [lower, upper] that contains it.
    upper = torch.searchsorted(times, sample_times, right=True)
    upper = torch.clamp(upper, min=1, max=frame_count - 1)
    lower = upper - 1
    t0 = times.index_select(0, lower)
    t1 = times.index_select(0, upper)
    v0 = values.index_select(0, lower)
    v1 = values.index_select(0, upper)
    # The epsilon keeps the division finite for (near-)coincident timestamps.
    alpha = ((sample_times - t0) / (t1 - t0 + 1e-6)).unsqueeze(-1)
    interpolated = v0 + alpha * (v1 - v0)
    return interpolated, names, sample_times
def _map_motors_to_joint_targets(
    motor_table: torch.Tensor,
    motor_names: list[str],
    joint_names: list[str],
) -> torch.Tensor:
    """Scatter motor-table columns into the requested joint order.

    For each joint the (aliased) joint name is looked up in ``motor_names``
    first; if absent, the name with ``Left_``/``Right_`` removed is tried.
    Joints with no matching motor keep a target of zero.

    Returns:
        A float32 tensor with one row per row of ``motor_table`` and one
        column per entry of ``joint_names``.
    """
    column_of = {name: col for col, name in enumerate(motor_names)}
    targets = torch.zeros((motor_table.shape[0], len(joint_names)), dtype=torch.float32)
    for joint_col, joint in enumerate(joint_names):
        resolved = JOINT_NAME_ALIAS.get(joint, joint)
        sideless = resolved.replace("Left_", "").replace("Right_", "")
        # Prefer an exact (per-side) entry, then the shared side-less entry.
        source = next((cand for cand in (resolved, sideless) if cand in column_of), None)
        if source is None:
            continue
        # NOTE(review): the flip is applied even when an explicit "Right_..."
        # motor entry was matched, not only for shared side-less entries —
        # confirm that this is intended.
        mirrored = resolved.startswith("Right_") and sideless in RIGHT_JOINT_SIGN_FLIP
        factor = -1.0 if mirrored else 1.0
        targets[:, joint_col] = factor * motor_table[:, column_of[source]]
    return targets
def _features_from_keyframes(
    keyframe_yaml_path: str,
    joint_names: list[str],
    sample_dt: float,
    repeat_count: int,
) -> torch.Tensor:
    """Turn one keyframe YAML file into a feature matrix.

    Column layout: joint positions, finite-difference joint velocities, three
    zero columns for root linear velocity, three zero columns for root angular
    velocity, and a constant ``(0, 0, -1)`` projected-gravity vector. The
    whole matrix is tiled ``repeat_count`` times (at least once) along dim 0.

    Raises:
        ValueError: if the ``keyframes`` entry of the YAML is not a list.
    """
    yaml_path = Path(keyframe_yaml_path).expanduser().resolve()
    keyframes = _safe_load_yaml(yaml_path).get("keyframes", [])
    if not isinstance(keyframes, list):
        raise ValueError(f"Invalid keyframe list in {keyframe_yaml_path}")

    motor_table, motor_names, _ = _build_interpolated_motor_table(keyframes, sample_dt=sample_dt)
    joint_pos = _map_motors_to_joint_targets(motor_table, motor_names, joint_names)
    sample_count = joint_pos.shape[0]

    # Backward differences; the first row copies the second so it is not zero.
    joint_vel = torch.zeros_like(joint_pos)
    if sample_count > 1:
        joint_vel[1:] = torch.diff(joint_pos, dim=0) / max(sample_dt, 1e-3)
        joint_vel[0] = joint_vel[1]

    # Root linear + angular velocity are not available from keyframes.
    root_motion = torch.zeros((sample_count, 6), dtype=torch.float32)
    gravity = torch.zeros((sample_count, 3), dtype=torch.float32)
    gravity[:, 2] = -1.0

    features = torch.cat([joint_pos, joint_vel, root_motion, gravity], dim=-1)
    tiles = max(int(repeat_count), 1)
    return features.repeat(tiles, 1) if tiles > 1 else features
def build_amp_expert_features_from_getup_keyframes(
    *,
    front_yaml_path: str,
    back_yaml_path: str,
    joint_names: list[str],
    output_path: str,
    sample_dt: float = 0.04,
    repeat_count: int = 16,
) -> tuple[str, tuple[int, int]]:
    """Build and save expert features from the front and back get-up keyframes.

    Features of the front file are followed by those of the back file. The
    result is written with ``torch.save`` as a dict holding
    ``expert_features`` and a ``meta`` record; parent directories of
    ``output_path`` are created when missing.

    Returns:
        ``(resolved_output_path, (row_count, feature_dim))``.
    """
    per_file = [
        _features_from_keyframes(
            keyframe_yaml_path=yaml_path,
            joint_names=joint_names,
            sample_dt=sample_dt,
            repeat_count=repeat_count,
        )
        for yaml_path in (front_yaml_path, back_yaml_path)
    ]
    expert_features = torch.cat(per_file, dim=0).contiguous()
    row_count, feature_dim = (int(size) for size in expert_features.shape)

    destination = Path(output_path).expanduser().resolve()
    destination.parent.mkdir(parents=True, exist_ok=True)
    meta = {
        "front_yaml_path": str(Path(front_yaml_path).expanduser().resolve()),
        "back_yaml_path": str(Path(back_yaml_path).expanduser().resolve()),
        "sample_dt": float(sample_dt),
        "repeat_count": int(repeat_count),
        "feature_dim": feature_dim,
    }
    torch.save({"expert_features": expert_features, "meta": meta}, str(destination))
    return str(destination), (row_count, feature_dim)

View File

@@ -0,0 +1,457 @@
from __future__ import annotations
import pickle
from pathlib import Path
import joblib
import numpy as np
import torch
import torch.nn as nn
from isaaclab.envs import ManagerBasedRLEnv
def _safe_tensor(x: torch.Tensor, nan: float = 0.0, pos: float = 1e3, neg: float = -1e3) -> torch.Tensor:
    """Return ``x`` with NaN, +inf and -inf replaced by ``nan``, ``pos`` and ``neg``."""
    return x.nan_to_num(nan=nan, posinf=pos, neginf=neg)
class AMPDiscriminator(nn.Module):
    """Lightweight discriminator used by online AMP updates.

    The network is a stack of ``Linear -> LayerNorm -> SiLU`` groups, one per
    entry of ``hidden_dims``, followed by a ``Linear`` layer with one output.
    """

    def __init__(self, input_dim: int, hidden_dims: tuple[int, ...]):
        super().__init__()
        widths = (input_dim, *hidden_dims)
        stack: list[nn.Module] = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.LayerNorm(fan_out), nn.SiLU()]
        stack.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one raw (pre-sigmoid) value per input row."""
        return self.net(x)
def _to_tensor_2d(value) -> torch.Tensor | None:
    """Convert a tensor, ndarray or list into a float32 2-D tensor.

    1-D input is promoted to a single row. ``None`` is returned for
    unsupported types, for data that cannot be converted (ragged lists,
    object-dtype arrays, ...) and for anything that is not 1-D or 2-D.
    """
    if isinstance(value, torch.Tensor):
        t = value.float()
    elif isinstance(value, (np.ndarray, list)):
        # ndarray conversion can fail as well (e.g. dtype=object); treat it
        # like an unconvertible list rather than letting the error escape.
        try:
            t = torch.as_tensor(value, dtype=torch.float32)
        except (TypeError, ValueError, RuntimeError):
            return None
    else:
        return None
    if t.ndim == 1:
        t = t.unsqueeze(0)
    if t.ndim != 2:
        return None
    return t
def _normalize_feature_dim(x: torch.Tensor, feature_dim: int) -> torch.Tensor:
    """Truncate or zero-pad the last dimension of a 2-D tensor to ``feature_dim``.

    ``x`` itself is returned when the width already matches.
    """
    width = x.shape[-1]
    if width == feature_dim:
        return x
    if width > feature_dim:
        return x[:, :feature_dim]
    # new_zeros inherits dtype *and* device from x, so padding also works for
    # tensors that do not live on the default device.
    pad = x.new_zeros((x.shape[0], feature_dim - width))
    return torch.cat([x, pad], dim=-1)
def _extract_expert_bank_from_payload(payload, feature_dim: int) -> dict | None:
    """Build an expert clip bank from a loaded payload.

    Dict payloads are read as follows:

    * ``expert_clips`` (list): one clip per convertible entry.
    * ``clip_names`` / ``clip_weights``: used only when their length matches
      the number of loaded clips.
    * ``expert_features`` / ``features`` / ``obs``: the first convertible one
      is used as a single clip, but only when no explicit clips were loaded.
      (Previously it was appended in addition to the explicit clips.)

    Any other payload is converted directly and used as one clip. Every clip
    is truncated/padded to ``feature_dim`` columns, and clips with fewer than
    two rows are dropped together with their name and weight.

    Returns:
        ``None`` when no usable clip remains, otherwise a dict with
        ``flat_features`` (all clips concatenated), ``clips``, ``clip_names``
        and ``clip_weights`` (non-negative, one entry per clip).
    """
    clip_tensors: list[torch.Tensor] = []
    clip_names: list[str] = []
    clip_weights_tensor: torch.Tensor | None = None

    if isinstance(payload, dict):
        clip_values = payload.get("expert_clips", None)
        if isinstance(clip_values, list):
            for i, clip_value in enumerate(clip_values):
                clip_tensor = _to_tensor_2d(clip_value)
                if clip_tensor is None:
                    continue
                clip_tensors.append(_normalize_feature_dim(clip_tensor, feature_dim))
                clip_names.append(f"clip_{i}")

        raw_names = payload.get("clip_names", None)
        if isinstance(raw_names, list) and len(raw_names) == len(clip_tensors):
            clip_names = [str(n) for n in raw_names]

        raw_weights = payload.get("clip_weights", None)
        if raw_weights is not None:
            clip_weights_tensor = _to_tensor_2d(raw_weights)
        if clip_weights_tensor is not None:
            clip_weights_tensor = clip_weights_tensor.reshape(-1)
            if clip_weights_tensor.shape[0] != len(clip_tensors):
                clip_weights_tensor = None

        if not clip_tensors:
            # Fall back to a flat feature matrix only when there are no
            # explicit clips, so sampled windows never span clip boundaries
            # and the data is not duplicated in flat_features.
            clip_names = []
            clip_weights_tensor = None
            for key in ("expert_features", "features", "obs"):
                tensor = _to_tensor_2d(payload.get(key, None))
                if tensor is not None:
                    clip_tensors.append(_normalize_feature_dim(tensor, feature_dim))
                    clip_names.append(key)
                    break
    else:
        tensor = _to_tensor_2d(payload)
        if tensor is not None:
            clip_tensors.append(_normalize_feature_dim(tensor, feature_dim))
            clip_names.append("expert_features")

    # Drop clips that are too short, keeping names and weights aligned.
    keep = [i for i, clip in enumerate(clip_tensors) if clip.shape[0] >= 2]
    if len(keep) == 0:
        return None
    if len(keep) != len(clip_tensors):
        clip_tensors = [clip_tensors[i] for i in keep]
        clip_names = [clip_names[i] for i in keep]
        if clip_weights_tensor is not None:
            clip_weights_tensor = clip_weights_tensor[keep]

    if clip_weights_tensor is None:
        clip_weights_tensor = torch.ones(len(clip_tensors), dtype=torch.float32)
    else:
        clip_weights_tensor = torch.clamp(clip_weights_tensor.float(), min=0.0)
        if float(torch.sum(clip_weights_tensor).item()) <= 0.0:
            clip_weights_tensor = torch.ones(len(clip_tensors), dtype=torch.float32)

    return {
        "flat_features": torch.cat(clip_tensors, dim=0),
        "clips": clip_tensors,
        "clip_names": clip_names,
        "clip_weights": clip_weights_tensor,
    }
def _load_amp_expert_features(
    expert_features_path: str,
    device: str,
    feature_dim: int,
    fallback_samples: int,
) -> dict | None:
    """Load expert AMP features bank. Returns None when file is unavailable.

    The file is tried with ``torch.load``, then ``pickle``, then ``joblib``.
    A file that exists but cannot be read by any of them produces a warning
    and ``None``.

    Args:
        expert_features_path: Path to the expert file; empty means "none".
        device: Device the returned tensors are moved to.
        feature_dim: Column count every clip is normalised to.
        fallback_samples: Minimum row count of ``flat_features``; smaller
            banks are tiled until they reach it.

    Returns:
        ``None`` or a dict with ``flat_features``, ``clips``, ``clip_names``
        and ``clip_weights`` (normalised to sum to 1).
    """
    import warnings  # function-scope import: the block stays self-contained

    if not expert_features_path:
        return None
    p = Path(expert_features_path).expanduser()
    if not p.is_file():
        return None

    # NOTE(security): every loader below unpickles arbitrary objects, so only
    # point this at expert files you produced or otherwise trust.
    try:
        try:
            # The expert file is a general pickle (dict holding numpy arrays),
            # which the weights_only=True default of newer PyTorch rejects.
            payload = torch.load(str(p), map_location="cpu", weights_only=False)
        except TypeError:
            # torch.load of older PyTorch versions has no weights_only keyword.
            payload = torch.load(str(p), map_location="cpu")
    except Exception:
        try:
            with p.open("rb") as f:
                payload = pickle.load(f)
        except Exception:
            try:
                payload = joblib.load(str(p))
            except Exception as err:
                warnings.warn(f"Could not load AMP expert features from {p}: {err!r}")
                return None

    bank = _extract_expert_bank_from_payload(payload, feature_dim=feature_dim)
    if bank is None:
        return None

    flat_features = bank["flat_features"].float()
    if flat_features.shape[0] < fallback_samples:
        # Ceil-divide so that the tiled bank has at least fallback_samples rows.
        reps = int((fallback_samples + flat_features.shape[0] - 1) // flat_features.shape[0])
        flat_features = flat_features.repeat(reps, 1)

    clip_tensors = [clip.float() for clip in bank["clips"]]
    clip_weights = torch.clamp(bank["clip_weights"].float(), min=0.0)
    if float(torch.sum(clip_weights).item()) <= 0.0:
        clip_weights = torch.ones_like(clip_weights)
    clip_weights = clip_weights / torch.sum(clip_weights)

    return {
        "flat_features": flat_features.to(device=device),
        "clips": [c.to(device=device) for c in clip_tensors],
        "clip_names": bank["clip_names"],
        "clip_weights": clip_weights.to(device=device),
    }
def _policy_sequence_features(
    env: ManagerBasedRLEnv,
    current_features: torch.Tensor,
    history_steps: int,
) -> torch.Tensor:
    """Build rolling policy sequence features [N, H, D].

    The window is cached under ``env.extras["amp_policy_hist_cache"]["hist"]``.
    It is (re)created, filled with the current frame, whenever the cached
    tensor is missing or its shape no longer matches; otherwise it is shifted
    by one step in place and the current frame is written to the last slot.
    The returned tensor is the cached tensor itself.
    """
    steps = max(int(history_steps), 1)
    key = "amp_policy_hist_cache"
    store = env.extras.get(key, None)
    if not isinstance(store, dict):
        store = {}
        env.extras[key] = store

    window = store.get("hist", None)
    reusable = (
        isinstance(window, torch.Tensor)
        and window.shape[0] == env.num_envs
        and window.shape[1] == steps
        and window.shape[2] == current_features.shape[1]
    )
    if not reusable:
        window = current_features.unsqueeze(1).repeat(1, steps, 1)

    # NOTE(review): nothing here clears a row when its environment resets, so
    # a window can mix frames from two episodes — confirm this is acceptable.
    if steps > 1:
        # clone() avoids reading from memory that is being overwritten.
        window[:, :-1] = window[:, 1:].clone()
    window[:, -1] = current_features

    store["hist"] = window
    env.extras[key] = store
    return window
def _sample_expert_sequence_batch(
    expert_bank: dict,
    batch_size: int,
    history_steps: int,
    device: str,
) -> torch.Tensor:
    """Sample expert sequence batch [B, H, D] with clip-weighted sampling.

    A clip is drawn per sample according to ``expert_bank["clip_weights"]``,
    then a window of ``history_steps`` consecutive frames is taken at a
    uniformly random start. Clips shorter than the window start at frame 0
    and repeat their last frame as padding.

    Window starts are drawn once per distinct clip instead of once per sample,
    which avoids one host/device synchronisation per sample.

    Args:
        expert_bank: Dict with ``clips`` (list of ``[T, D]`` tensors) and
            ``clip_weights`` (one non-negative weight per clip).
        batch_size: Number of sequences to draw.
        history_steps: Window length ``H``.
        device: Used only for the empty result when there are no clips; the
            sampled batch lives on the device of the clips.
    """
    clips = expert_bank["clips"]
    clip_weights = expert_bank["clip_weights"]
    if len(clips) <= 0:
        return torch.empty((0, history_steps, 0), device=device)

    reference = clips[0]
    out_device = reference.device
    clip_ids = torch.multinomial(clip_weights, num_samples=batch_size, replacement=True).to(out_device)
    frame_offsets = torch.arange(history_steps, device=out_device)
    batch = torch.empty(
        (batch_size, history_steps, reference.shape[-1]),
        dtype=reference.dtype,
        device=out_device,
    )
    for clip_id in torch.unique(clip_ids).tolist():
        clip = clips[int(clip_id)]
        rows = torch.nonzero(clip_ids == clip_id, as_tuple=True)[0]
        clip_len = int(clip.shape[0])
        max_start = max(clip_len - history_steps, 0)
        starts = torch.randint(0, max_start + 1, (rows.shape[0],), device=out_device)
        # Clamping to the last frame reproduces "pad with the final frame"
        # for clips shorter than the window.
        frame_ids = torch.clamp(starts.unsqueeze(1) + frame_offsets, max=clip_len - 1)
        batch[rows] = clip[frame_ids]
    return batch
def _get_amp_state(
env: ManagerBasedRLEnv,
amp_enabled: bool,
amp_model_path: str,
amp_train_enabled: bool,
amp_expert_features_path: str,
feature_dim: int,
disc_hidden_dim: int,
disc_hidden_layers: int,
disc_lr: float,
disc_weight_decay: float,
disc_min_expert_samples: int,
disc_history_steps: int,
):
"""Get cached AMP state (frozen jit or trainable discriminator)."""
# The state dict lives in env.extras["amp_state_cache"] and is reused as long
# as the signature tuple built below is unchanged.
cache_key = "amp_state_cache"
# Clamp the discriminator hyper-parameters to usable minimums.
hidden_layers = max(int(disc_hidden_layers), 1)
hidden_dim = max(int(disc_hidden_dim), 16)
history_steps = max(int(disc_history_steps), 1)
# NOTE: disc_min_expert_samples is not part of the signature, so changing
# only that argument does not rebuild the cached state.
state_sig = (
bool(amp_enabled),
str(amp_model_path),
bool(amp_train_enabled),
str(amp_expert_features_path),
int(feature_dim),
history_steps,
hidden_dim,
hidden_layers,
float(disc_lr),
float(disc_weight_decay),
)
cached = env.extras.get(cache_key, None)
if isinstance(cached, dict) and cached.get("sig") == state_sig:
return cached
# Fresh state; "mode" stays "disabled" unless one of the branches below
# succeeds in creating or loading a model.
state = {
"sig": state_sig,
"mode": "disabled",
"model": None,
"optimizer": None,
"expert_bank": None,
"disc_history_steps": history_steps,
"step": 0,
"last_loss": 0.0,
"last_acc_policy": 0.0,
"last_acc_expert": 0.0,
}
# Trainable mode is only activated when the expert bank loads successfully.
if amp_train_enabled:
expert_bank = _load_amp_expert_features(
amp_expert_features_path,
device=env.device,
feature_dim=feature_dim,
fallback_samples=max(disc_min_expert_samples, 512),
)
if expert_bank is not None:
# The discriminator input is the flattened window: feature_dim * history_steps.
model = AMPDiscriminator(input_dim=feature_dim * history_steps, hidden_dims=tuple([hidden_dim] * hidden_layers)).to(env.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=float(disc_lr), weight_decay=float(disc_weight_decay))
state["mode"] = "trainable"
state["model"] = model
state["optimizer"] = optimizer
state["expert_bank"] = expert_bank
# Frozen TorchScript model: loaded with torch.jit.load and switched to eval mode.
elif amp_enabled and amp_model_path:
model_path = Path(amp_model_path).expanduser()
if model_path.is_file():
try:
model = torch.jit.load(str(model_path), map_location=env.device)
model.eval()
state["mode"] = "jit"
state["model"] = model
except Exception:
# Load failures are swallowed; "mode" and "model" keep their initial values.
pass
env.extras[cache_key] = state
return state
def _build_amp_features(env: ManagerBasedRLEnv, feature_clip: float = 8.0) -> torch.Tensor:
    """Build AMP-style discriminator features from robot kinematics.

    Concatenates, along the last dimension: joint positions minus default
    joint positions, joint velocities, ``root_lin_vel_w``, ``root_ang_vel_w``
    and ``projected_gravity_b``. NaN becomes 0, +/-inf becomes
    +/-``feature_clip`` and every value is clamped to
    ``[-feature_clip, feature_clip]``.
    """
    data = env.scene["robot"].data
    parts = (
        data.joint_pos - data.default_joint_pos,
        data.joint_vel,
        data.root_lin_vel_w,
        data.root_ang_vel_w,
        data.projected_gravity_b,
    )
    stacked = torch.cat(parts, dim=-1)
    stacked = torch.nan_to_num(stacked, nan=0.0, posinf=feature_clip, neginf=-feature_clip)
    return stacked.clamp(min=-feature_clip, max=feature_clip)
def amp_style_prior_reward(
env: ManagerBasedRLEnv,
amp_enabled: bool = False,
amp_model_path: str = "",
amp_train_enabled: bool = False,
amp_expert_features_path: str = "",
disc_hidden_dim: int = 256,
disc_hidden_layers: int = 2,
disc_lr: float = 3e-4,
disc_weight_decay: float = 1e-6,
disc_update_interval: int = 4,
disc_batch_size: int = 1024,
disc_min_expert_samples: int = 2048,
disc_history_steps: int = 4,
feature_clip: float = 8.0,
logit_scale: float = 1.0,
amp_reward_gain: float = 1.0,
internal_reward_scale: float = 1.0,
) -> torch.Tensor:
"""AMP style prior reward with optional online discriminator training."""
# Defaults that are reported when no model is available.
zeros = torch.zeros(env.num_envs, device=env.device)
amp_score = zeros
model_loaded = 0.0
amp_train_active = 0.0
disc_loss = 0.0
disc_acc_policy = 0.0
disc_acc_expert = 0.0
# Current-frame features and the cached (or newly created) AMP state.
amp_features = _build_amp_features(env, feature_clip=feature_clip)
amp_state = _get_amp_state(
env=env,
amp_enabled=amp_enabled,
amp_model_path=amp_model_path,
amp_train_enabled=amp_train_enabled,
amp_expert_features_path=amp_expert_features_path,
feature_dim=int(amp_features.shape[-1]),
disc_hidden_dim=disc_hidden_dim,
disc_hidden_layers=disc_hidden_layers,
disc_lr=disc_lr,
disc_weight_decay=disc_weight_decay,
disc_min_expert_samples=disc_min_expert_samples,
disc_history_steps=disc_history_steps,
)
discriminator = amp_state.get("model", None)
history_steps = max(int(amp_state.get("disc_history_steps", disc_history_steps)), 1)
# Rolling per-env feature window, flattened to one row per environment.
policy_seq = _policy_sequence_features(env, amp_features, history_steps=history_steps)
policy_seq_flat = policy_seq.reshape(policy_seq.shape[0], -1)
if discriminator is not None:
model_loaded = 1.0
# Online discriminator update, performed every `disc_update_interval` calls.
if amp_state.get("mode") == "trainable" and discriminator is not None:
amp_train_active = 1.0
optimizer = amp_state.get("optimizer", None)
expert_bank = amp_state.get("expert_bank", None)
amp_state["step"] = int(amp_state.get("step", 0)) + 1
update_interval = max(int(disc_update_interval), 1)
# The batch size is floored at 32.
batch_size = max(int(disc_batch_size), 32)
if optimizer is not None and isinstance(expert_bank, dict) and amp_state["step"] % update_interval == 0:
policy_features = policy_seq_flat.detach()
policy_count = policy_features.shape[0]
# Subsample policy rows (with replacement) when there are more rows than the batch size.
if policy_count > batch_size:
policy_ids = torch.randint(0, policy_count, (batch_size,), device=env.device)
policy_batch = policy_features.index_select(0, policy_ids)
else:
policy_batch = policy_features
# Draw as many expert sequences as there are policy rows in the batch.
expert_seq = _sample_expert_sequence_batch(
expert_bank=expert_bank,
batch_size=policy_batch.shape[0],
history_steps=history_steps,
device=env.device,
)
expert_batch = expert_seq.reshape(expert_seq.shape[0], -1)
discriminator.train()
optimizer.zero_grad(set_to_none=True)
logits_expert = discriminator(expert_batch).squeeze(-1)
logits_policy = discriminator(policy_batch).squeeze(-1)
# Binary cross-entropy: expert sequences are labelled 1, policy sequences 0.
loss_expert = nn.functional.binary_cross_entropy_with_logits(logits_expert, torch.ones_like(logits_expert))
loss_policy = nn.functional.binary_cross_entropy_with_logits(logits_policy, torch.zeros_like(logits_policy))
loss = 0.5 * (loss_expert + loss_policy)
loss.backward()
nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
optimizer.step()
discriminator.eval()
# Accuracies are computed from the pre-update logits of this batch.
with torch.no_grad():
disc_loss = float(loss.detach().item())
disc_acc_expert = float((torch.sigmoid(logits_expert) > 0.5).float().mean().item())
disc_acc_policy = float((torch.sigmoid(logits_policy) < 0.5).float().mean().item())
amp_state["last_loss"] = disc_loss
amp_state["last_acc_expert"] = disc_acc_expert
amp_state["last_acc_policy"] = disc_acc_policy
# Not an update step: report the metrics stored by the last update.
else:
disc_loss = float(amp_state.get("last_loss", 0.0))
disc_acc_expert = float(amp_state.get("last_acc_expert", 0.0))
disc_acc_policy = float(amp_state.get("last_acc_policy", 0.0))
# Score the current policy windows with the loaded model (no gradients).
if discriminator is not None:
discriminator.eval()
with torch.no_grad():
logits = discriminator(policy_seq_flat)
if isinstance(logits, (tuple, list)):
logits = logits[0]
if logits.ndim > 1:
logits = logits.squeeze(-1)
elif amp_enabled and amp_model_path:
# For external scripted models, try temporal then fallback to single frame.
model = amp_state.get("model", None)
if model is not None:
with torch.no_grad():
try:
logits = model(policy_seq_flat)
except Exception:
logits = model(amp_features)
if isinstance(logits, (tuple, list)):
logits = logits[0]
if logits.ndim > 1:
logits = logits.squeeze(-1)
# NOTE(review): `logits` is only assigned inside the branches above — confirm
# that every path reaching the next line has defined it.
# Sanitise the logits, squash them with a sigmoid, then sanitise the score.
logits = _safe_tensor(logits, nan=0.0, pos=20.0, neg=-20.0)
amp_score = torch.sigmoid(logit_scale * logits)
amp_score = _safe_tensor(amp_score, nan=0.0, pos=1.0, neg=0.0)
amp_reward = _safe_tensor(amp_reward_gain * amp_score, nan=0.0, pos=10.0, neg=0.0)
# Publish scalar diagnostics through env.extras["log"].
log_dict = env.extras.get("log", {})
if isinstance(log_dict, dict):
log_dict["amp_score_mean"] = torch.mean(amp_score).detach().item()
log_dict["amp_reward_mean"] = torch.mean(amp_reward).detach().item()
log_dict["amp_model_loaded_mean"] = model_loaded
log_dict["amp_train_active_mean"] = amp_train_active
log_dict["amp_disc_loss_mean"] = disc_loss
log_dict["amp_disc_acc_policy_mean"] = disc_acc_policy
log_dict["amp_disc_acc_expert_mean"] = disc_acc_expert
log_dict["amp_disc_history_steps"] = float(history_steps)
env.extras["log"] = log_dict
return internal_reward_scale * amp_reward

View File

@@ -0,0 +1,258 @@
from __future__ import annotations
import argparse
import json
import pickle
from glob import glob
from pathlib import Path
import numpy as np
import joblib
try:
import torch
except Exception:
torch = None
# Substring of the clip name that receives all sampling weight in the
# "git_getup_focus" clip-weight mode.
GIT_GETUP_FOCUS_CLIP = "fallAndGetUp2_subject2_1200_1370"
def _load_payload(path: Path):
    """Load a source file, choosing the loader from the file suffix.

    ``.pt``/``.pth`` use ``torch.load`` (CPU), ``.pkl`` uses ``pickle`` with
    ``joblib`` as fallback, ``.json`` uses ``json``.

    Raises:
        RuntimeError: for ``.pt``/``.pth`` when torch is not installed.
        ValueError: for any other suffix.
    """
    suffix = path.suffix.lower()
    # NOTE(security): torch.load, pickle and joblib execute arbitrary pickled
    # code; only convert files from sources you trust.
    if suffix in (".pt", ".pth"):
        if torch is None:
            raise RuntimeError("Loading .pt/.pth requires torch to be installed.")
        try:
            # Source files may hold numpy arrays or other non-tensor objects,
            # which the weights_only=True default of newer PyTorch rejects.
            return torch.load(str(path), map_location="cpu", weights_only=False)
        except TypeError:
            # torch.load of older PyTorch versions has no weights_only keyword.
            return torch.load(str(path), map_location="cpu")
    if suffix == ".pkl":
        try:
            with path.open("rb") as f:
                return pickle.load(f)
        except Exception:
            return joblib.load(str(path))
    if suffix == ".json":
        with path.open("r", encoding="utf-8") as f:
            return json.load(f)
    raise ValueError(f"Unsupported file type: {path.suffix}")
def _to_numpy(payload, key_hint: str = "") -> np.ndarray:
    """Extract a float32 array from a tensor, array, list or dict payload.

    For dicts, ``key_hint`` (if non-empty) is tried first, then a fixed list
    of well-known keys, and finally every value in insertion order; the first
    tensor-like value wins.

    Raises:
        ValueError: if nothing tensor-like is found.
    """

    def as_float32(candidate):
        """Return ``candidate`` as a float32 ndarray, or None if unsupported."""
        if torch is not None and isinstance(candidate, torch.Tensor):
            return candidate.detach().cpu().numpy().astype(np.float32)
        if isinstance(candidate, np.ndarray):
            return candidate.astype(np.float32)
        if isinstance(candidate, list):
            return np.asarray(candidate, dtype=np.float32)
        return None

    direct = as_float32(payload)
    if direct is not None:
        return direct

    if isinstance(payload, dict):
        preferred_keys = ([key_hint] if key_hint else []) + [
            "expert_features",
            "features",
            "observations",
            "obs",
            "joint_pos",
            "motion",
            "data",
        ]
        preferred_values = [payload[key] for key in preferred_keys if key and key in payload]
        for candidate in (*preferred_values, *payload.values()):
            converted = as_float32(candidate)
            if converted is not None:
                return converted

    raise ValueError("Could not locate tensor-like payload. Provide --input_key when needed.")
def _ensure_2d(x: np.ndarray) -> np.ndarray:
    """Return ``x`` as a float32 ``[N, D]`` array.

    A 1-D array becomes a single row. Both paths now return float32 (the 1-D
    path used to return the input dtype unchanged).

    Raises:
        ValueError: if ``x`` is neither 1-D nor 2-D.
    """
    if x.ndim == 1:
        x = x[None, :]
    if x.ndim != 2:
        raise ValueError(f"Expected 2D array [N, D], got shape={tuple(x.shape)}")
    return x.astype(np.float32)
def _from_legged_lab_motion_dict(payload: dict, target_dof: int) -> np.ndarray:
    """
    Convert legged_lab motion pickle payload to AMP feature tensor [N, D].
    Expected keys: fps, root_pos, dof_pos.

    Output columns: ``dof_pos`` (truncated or zero-padded to ``target_dof``),
    finite-difference ``dof_vel``, finite-difference root linear velocity,
    three zero columns for root angular velocity, and a constant
    ``(0, 0, -1)`` gravity vector. ``fps`` defaults to 30.

    ``root_pos`` is replaced by zeros unless it is 2-D, has one row per
    ``dof_pos`` row and at least three columns; only its first three columns
    are used.

    Raises:
        ValueError: if ``dof_pos`` is missing or not 2-D.
    """
    dof_pos = np.asarray(payload.get("dof_pos", None))
    root_pos = np.asarray(payload.get("root_pos", None))
    fps = float(payload.get("fps", 30.0))
    if dof_pos.ndim != 2:
        raise ValueError("legged_lab payload missing valid dof_pos [N, M].")

    frame_count = dof_pos.shape[0]
    # A root_pos that is not [N, >=3] cannot fill the 3-column velocity block.
    if root_pos.ndim != 2 or root_pos.shape[0] != frame_count or root_pos.shape[1] < 3:
        root_pos = np.zeros((frame_count, 3), dtype=np.float32)
    dt = 1.0 / max(fps, 1e-3)
    dof_pos = dof_pos.astype(np.float32)
    root_pos = root_pos[:, :3].astype(np.float32)

    # Current get_up AMP feature dim expects 23 dof by default.
    # NOTE(review): keeping the first `target_dof` columns assumes the source
    # joint order lines up with the target robot's — confirm the mapping.
    target_dof = int(target_dof)
    if dof_pos.shape[1] >= target_dof:
        dof_pos = dof_pos[:, :target_dof]
    else:
        dof_pos = np.pad(dof_pos, ((0, 0), (0, target_dof - dof_pos.shape[1])), mode="constant")

    # Backward differences; row 0 copies row 1 so it is not left at zero.
    dof_vel = np.zeros_like(dof_pos, dtype=np.float32)
    root_lin = np.zeros((frame_count, 3), dtype=np.float32)
    if frame_count > 1:
        dof_vel[1:] = (dof_pos[1:] - dof_pos[:-1]) / dt
        dof_vel[0] = dof_vel[1]
        root_lin[1:] = (root_pos[1:] - root_pos[:-1]) / dt
        root_lin[0] = root_lin[1]

    root_ang = np.zeros((frame_count, 3), dtype=np.float32)
    # NOTE(review): gravity is written as a constant (0, 0, -1) for every
    # frame regardless of the motion — confirm this suits get-up clips.
    gravity = np.zeros((frame_count, 3), dtype=np.float32)
    gravity[:, 2] = -1.0
    return np.concatenate([dof_pos, dof_vel, root_lin, root_ang, gravity], axis=-1).astype(np.float32)
def _normalize_dim(x: np.ndarray, feature_dim: int) -> np.ndarray:
    """Truncate or zero-pad the columns of a 2-D array to ``feature_dim``.

    ``x`` itself is returned when the width already matches.
    """
    width = x.shape[1]
    if width > feature_dim:
        return x[:, :feature_dim]
    if width < feature_dim:
        filler = np.zeros((x.shape[0], feature_dim - width), dtype=np.float32)
        return np.concatenate([x, filler], axis=-1)
    return x
def _collect_input_files(input_path: Path, glob_pattern: str) -> list[Path]:
    """Resolve ``input_path`` to the list of source files to convert.

    A file is returned as a one-element list; for a directory the sorted
    matches of ``glob_pattern`` inside it are returned.

    Raises:
        FileNotFoundError: if the path is neither a file nor a directory, or
            the directory contains no match.
    """
    if input_path.is_file():
        return [input_path]
    matches: list[Path] = []
    if input_path.is_dir():
        matches = sorted(Path(hit) for hit in glob(str(input_path / glob_pattern)))
    if not matches:
        raise FileNotFoundError(f"No source files found for input={input_path} pattern={glob_pattern}")
    return matches
def _build_clip_weights(clip_names: list[str], weight_mode: str) -> np.ndarray:
    """Return one float32 sampling weight per clip.

    ``"git_getup_focus"`` gives weight 1 to clips whose name contains
    ``GIT_GETUP_FOCUS_CLIP`` and 0 to the rest, falling back to all ones when
    no clip matches. Every other mode (including ``"uniform"``) yields all
    ones. An empty name list yields an empty array.
    """
    count = len(clip_names)
    if count == 0:
        return np.zeros((0,), dtype=np.float32)
    if weight_mode != "git_getup_focus":
        return np.ones((count,), dtype=np.float32)
    focused = np.array(
        [1.0 if GIT_GETUP_FOCUS_CLIP in name else 0.0 for name in clip_names],
        dtype=np.float32,
    )
    if float(np.sum(focused)) > 0.0:
        return focused
    return np.ones((count,), dtype=np.float32)
def main():
    """Command-line entry point of the expert-data converter.

    Loads every input file as one clip, normalises each clip to
    ``--feature_dim`` columns, optionally repeats the clip list, computes the
    clip weights and writes everything to ``--output`` (with ``torch.save``
    when torch is importable, otherwise with ``pickle``).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Template converter: migrate legged_lab (or other) motion/expert data "
            "to AMP expert_features.pt for rl_game/get_up."
        )
    )
    parser.add_argument("--input", required=True, type=str, help="Path to source motion/expert file or directory.")
    parser.add_argument("--output", required=True, type=str, help="Output path for expert_features.pt.")
    parser.add_argument(
        "--input_key",
        type=str,
        default="",
        help="Optional key to locate tensor inside input payload dict.",
    )
    parser.add_argument(
        "--feature_dim",
        type=int,
        default=55,
        help="Target AMP feature dimension. For current get_up config, default=55.",
    )
    parser.add_argument(
        "--input_glob",
        type=str,
        default="*.pkl",
        help="Glob used when --input is a directory.",
    )
    parser.add_argument(
        "--target_dof",
        type=int,
        default=23,
        help="Target dof count when converting legged_lab pkl dof_pos.",
    )
    parser.add_argument(
        "--clip_weight_mode",
        type=str,
        default="git_getup_focus",
        choices=["git_getup_focus", "uniform"],
        help="Clip sampling weight mode for expert sequence training.",
    )
    parser.add_argument("--repeat", type=int, default=1, help="Repeat samples for small datasets.")
    args = parser.parse_args()

    feature_dim = int(args.feature_dim)
    target_dof = int(args.target_dof)
    repeat = max(int(args.repeat), 1)

    in_path = Path(args.input).expanduser().resolve()
    input_files = _collect_input_files(in_path, args.input_glob)

    # One clip per source file; legged_lab motion pickles get the dedicated
    # converter, everything else goes through the generic extractor.
    clip_arrays: list[np.ndarray] = []
    clip_names: list[str] = []
    for source_file in input_files:
        loaded = _load_payload(source_file)
        is_motion_pickle = (
            isinstance(loaded, dict) and "dof_pos" in loaded and source_file.suffix.lower() == ".pkl"
        )
        if is_motion_pickle:
            features = _from_legged_lab_motion_dict(loaded, target_dof=target_dof)
        else:
            features = _to_numpy(loaded, key_hint=args.input_key)
        features = _ensure_2d(features)
        features = _normalize_dim(features.astype(np.float32), feature_dim)
        clip_arrays.append(features)
        clip_names.append(source_file.stem)

    # Repetition duplicates whole clips (and their names), not single frames.
    if repeat > 1:
        clip_arrays = clip_arrays * repeat
        clip_names = clip_names * repeat

    clip_weights = _build_clip_weights(clip_names, weight_mode=args.clip_weight_mode)
    x = np.concatenate(clip_arrays, axis=0).astype(np.float32)
    clip_lengths = [int(clip.shape[0]) for clip in clip_arrays]
    # Start row of every clip inside the concatenated matrix.
    clip_offsets = [int(v) for v in np.cumsum([0] + clip_lengths)[:-1]]

    out_path = Path(args.output).expanduser().resolve()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "expert_features": x,
        "expert_clips": clip_arrays,
        "clip_names": clip_names,
        "clip_weights": clip_weights,
        "clip_lengths": clip_lengths,
        "clip_offsets": clip_offsets,
        "meta": {
            "source": str(in_path),
            "source_count": len(input_files),
            "input_key": args.input_key,
            "feature_dim": feature_dim,
            "target_dof": target_dof,
            "input_glob": args.input_glob,
            "clip_weight_mode": args.clip_weight_mode,
            "git_getup_focus_clip": GIT_GETUP_FOCUS_CLIP,
            "repeat": int(repeat),
        },
    }
    if torch is None:
        with out_path.open("wb") as f:
            pickle.dump(payload, f)
    else:
        torch.save(payload, str(out_path))
    print(f"[OK] Saved AMP expert features -> {out_path} shape={tuple(x.shape)}")


if __name__ == "__main__":
    main()

View File

@@ -1,204 +0,0 @@
import argparse
import math
from pathlib import Path
import torch
import yaml
# AMP feature order must match `_build_amp_features` in `t1_env_cfg.py`:
# [joint_pos_rel(23), joint_vel(23), root_lin_vel(3), root_ang_vel(3), projected_gravity(3)].
# Joint order of the per-frame joint-position vector built by load_sequence.
T1_JOINT_NAMES = [
"AAHead_yaw",
"Head_pitch",
"Left_Shoulder_Pitch",
"Left_Shoulder_Roll",
"Left_Elbow_Pitch",
"Left_Elbow_Yaw",
"Right_Shoulder_Pitch",
"Right_Shoulder_Roll",
"Right_Elbow_Pitch",
"Right_Elbow_Yaw",
"Waist",
"Left_Hip_Pitch",
"Left_Hip_Roll",
"Left_Hip_Yaw",
"Left_Knee_Pitch",
"Left_Ankle_Pitch",
"Left_Ankle_Roll",
"Right_Hip_Pitch",
"Right_Hip_Roll",
"Right_Hip_Yaw",
"Right_Knee_Pitch",
"Right_Ankle_Pitch",
"Right_Ankle_Roll",
]
# Reverse lookup: joint name -> index into T1_JOINT_NAMES.
JOINT_TO_IDX = {name: i for i, name in enumerate(T1_JOINT_NAMES)}
# Mirror rules aligned with `behaviors/custom/keyframe/keyframe.py`.
# Readable motor name -> (joint names it drives, inverse-direction flag);
# consumed by decode_keyframe_motor_positions.
MOTOR_SYMMETRY = {
"Head_yaw": (("Head_yaw",), False),
"Head_pitch": (("Head_pitch",), False),
"Shoulder_Pitch": (("Left_Shoulder_Pitch", "Right_Shoulder_Pitch"), False),
"Shoulder_Roll": (("Left_Shoulder_Roll", "Right_Shoulder_Roll"), True),
"Elbow_Pitch": (("Left_Elbow_Pitch", "Right_Elbow_Pitch"), False),
"Elbow_Yaw": (("Left_Elbow_Yaw", "Right_Elbow_Yaw"), True),
"Waist": (("Waist",), False),
"Hip_Pitch": (("Left_Hip_Pitch", "Right_Hip_Pitch"), False),
"Hip_Roll": (("Left_Hip_Roll", "Right_Hip_Roll"), True),
"Hip_Yaw": (("Left_Hip_Yaw", "Right_Hip_Yaw"), True),
"Knee_Pitch": (("Left_Knee_Pitch", "Right_Knee_Pitch"), False),
"Ankle_Pitch": (("Left_Ankle_Pitch", "Right_Ankle_Pitch"), False),
"Ankle_Roll": (("Left_Ankle_Roll", "Right_Ankle_Roll"), True),
}
# Output-name overrides applied to decoded joint names.
READABLE_TO_POLICY = {"Head_yaw": "AAHead_yaw"}
def decode_keyframe_motor_positions(raw_motor_positions: dict[str, float]) -> dict[str, float]:
    """Expand one keyframe's motor table into per-joint angles in radians.

    Args:
        raw_motor_positions: Mapping of keyframe motor name -> angle in degrees.

    Returns:
        Mapping of joint name (after ``READABLE_TO_POLICY`` renaming) -> angle
        in radians.

    Sign handling for names listed in ``MOTOR_SYMMETRY``: when the entry's
    inverse-direction flag is set, the first listed joint keeps the sign of the
    input value and every following joint is negated; when the flag is clear,
    every listed joint is negated. Names that are not in ``MOTOR_SYMMETRY`` are
    converted to radians with their sign unchanged.
    """
    rad_per_deg = math.pi / 180.0
    decoded: dict[str, float] = {}
    for name, deg in raw_motor_positions.items():
        rule = MOTOR_SYMMETRY.get(name)
        if rule is None:
            # No mirror rule: plain degree -> radian conversion.
            decoded[READABLE_TO_POLICY.get(name, name)] = float(deg) * rad_per_deg
            continue

        targets, inverse = rule
        for position, target in enumerate(targets):
            # Only the first target of an inverse-direction entry keeps the sign.
            keeps_sign = bool(inverse) and position == 0
            signed = deg if keeps_sign else -deg
            decoded[READABLE_TO_POLICY.get(target, target)] = float(signed) * rad_per_deg
    return decoded
def load_sequence(yaml_path: Path) -> list[tuple[float, torch.Tensor]]:
    """Load a keyframe YAML file as a list of ``(delta_seconds, joint_pos)`` pairs.

    Args:
        yaml_path: Path to a YAML document with a top-level ``keyframes`` list.
            Each keyframe may carry ``delta`` (seconds, default 0.1) and
            ``motor_positions`` (motor name -> degrees).

    Returns:
        One tuple per keyframe: the hold duration clamped to at least 1e-3 s,
        and a float32 vector of length ``len(T1_JOINT_NAMES)`` holding the
        decoded joint angles in radians. Joints the keyframe does not mention
        stay at 0.0; decoded names that are not in ``JOINT_TO_IDX`` are ignored.
        An empty list is returned when the file has no keyframes.
    """
    with yaml_path.open("r", encoding="utf-8") as f:
        desc = yaml.safe_load(f) or {}

    # A bare `keyframes:` key parses to None; treat it like a missing key, the
    # same way `motor_positions` is guarded below, instead of failing to iterate.
    keyframes = desc.get("keyframes") or []

    out: list[tuple[float, torch.Tensor]] = []
    for keyframe in keyframes:
        # Clamp so a zero/negative delta never yields a zero-length hold.
        delta_s = max(float(keyframe.get("delta", 0.1)), 1e-3)
        raw = keyframe.get("motor_positions", {}) or {}
        joint_pos = torch.zeros(len(T1_JOINT_NAMES), dtype=torch.float32)
        for j_name, j_val in decode_keyframe_motor_positions(raw).items():
            idx = JOINT_TO_IDX.get(j_name)
            if idx is not None:
                joint_pos[idx] = float(j_val)
        out.append((delta_s, joint_pos))
    return out
def sequence_to_amp_features(
    sequence: list[tuple[float, torch.Tensor]],
    sample_fps: float,
    projected_gravity: tuple[float, float, float],
) -> torch.Tensor:
    """Turn a keyframe sequence into a dense AMP feature tensor.

    Each keyframe pose is held (repeated unchanged) for ``round(delta / dt)``
    frames, at least one, where ``dt = 1 / sample_fps``. Joint velocities are
    backward finite differences of the held poses; the first frame reuses the
    velocity of the second frame. Root linear and angular velocity columns are
    all zero and the projected-gravity columns repeat the given vector.

    Args:
        sequence: ``(delta_seconds, joint_pos)`` pairs, e.g. from ``load_sequence``.
        sample_fps: Sampling rate used to expand keyframe durations.
        projected_gravity: Constant 3-vector written into every frame.

    Returns:
        Tensor of shape ``(N, 2 * J + 9)`` laid out as
        ``[joint_pos(J), joint_vel(J), root_lin_vel(3), root_ang_vel(3), gravity(3)]``
        (55 columns for 23 joints).

    Raises:
        ValueError: If ``sequence`` is empty.
    """
    if not sequence:
        raise ValueError("Empty keyframe sequence.")

    step = 1.0 / max(sample_fps, 1e-6)

    # Zero-order hold: repeat every keyframe pose for its duration.
    held: list[torch.Tensor] = []
    for hold_s, pose in sequence:
        n_hold = max(int(round(hold_s / step)), 1)
        held.extend(pose.clone() for _ in range(n_hold))
    if len(held) == 1:
        # Need two frames so a finite difference exists.
        held.append(held[0].clone())

    pos = torch.stack(held, dim=0)
    n_frames = pos.shape[0]

    vel = torch.zeros_like(pos)
    vel[1:] = (pos[1:] - pos[:-1]) / step
    vel[0] = vel[1]

    root_lin = torch.zeros((n_frames, 3), dtype=torch.float32)
    root_ang = torch.zeros((n_frames, 3), dtype=torch.float32)
    gravity_rows = torch.tensor(projected_gravity, dtype=torch.float32).repeat(n_frames, 1)

    return torch.cat([pos, vel, root_lin, root_ang, gravity_rows], dim=-1)
def main():
    """Command-line entry point: build AMP expert features from keyframe YAMLs.

    Parses the CLI options, loads the front and back get-up keyframe files,
    converts each to AMP features, concatenates them (front first), tiles the
    result ``repeat_cycles`` times (at least once) and saves it with
    ``torch.save`` together with the generation settings.

    Raises:
        FileNotFoundError: If the front or back YAML path is not an existing file.
    """
    parser = argparse.ArgumentParser(description="Build AMP expert features from get_up keyframe YAML files.")
    # Declared as a table so the option set can be read at a glance; the
    # registration order is kept because it is the order shown by --help.
    cli_options = (
        (
            "--front_yaml",
            dict(
                type=str,
                default="behaviors/custom/keyframe/get_up/get_up_front.yaml",
                help="Path to front get-up YAML.",
            ),
        ),
        (
            "--back_yaml",
            dict(
                type=str,
                default="behaviors/custom/keyframe/get_up/get_up_back.yaml",
                help="Path to back get-up YAML.",
            ),
        ),
        (
            "--sample_fps",
            dict(
                type=float,
                default=50.0,
                help="Sampling fps when expanding keyframe durations.",
            ),
        ),
        (
            "--repeat_cycles",
            dict(
                type=int,
                default=200,
                help="How many times to repeat front+back sequences to enlarge dataset.",
            ),
        ),
        (
            "--projected_gravity",
            dict(
                type=float,
                nargs=3,
                default=(0.0, 0.0, -1.0),
                help="Projected gravity feature used for synthesized expert data.",
            ),
        ),
        (
            "--output",
            dict(
                type=str,
                default="rl_game/get_up/amp/expert_features.pt",
                help="Output expert feature file path.",
            ),
        ),
    )
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    front_path = Path(args.front_yaml).expanduser().resolve()
    back_path = Path(args.back_yaml).expanduser().resolve()
    if not front_path.is_file():
        raise FileNotFoundError(f"Front YAML not found: {front_path}")
    if not back_path.is_file():
        raise FileNotFoundError(f"Back YAML not found: {back_path}")

    # Load both files before converting either, then convert front, then back.
    gravity = tuple(args.projected_gravity)
    sequences = [load_sequence(path) for path in (front_path, back_path)]
    clips = [sequence_to_amp_features(seq, args.sample_fps, gravity) for seq in sequences]
    base_feat = torch.cat(clips, dim=0)

    repeat_cycles = max(int(args.repeat_cycles), 1)
    expert_features = base_feat.repeat(repeat_cycles, 1).contiguous()

    # Relative output paths are anchored at the current working directory.
    out_path = Path(args.output).expanduser()
    if not out_path.is_absolute():
        out_path = Path.cwd() / out_path
    out_path.parent.mkdir(parents=True, exist_ok=True)

    payload = {
        "expert_features": expert_features,
        "feature_dim": int(expert_features.shape[1]),
        "num_samples": int(expert_features.shape[0]),
        "source": "get_up_keyframe_yaml",
        "front_yaml": str(front_path),
        "back_yaml": str(back_path),
        "sample_fps": float(args.sample_fps),
        "repeat_cycles": repeat_cycles,
        "projected_gravity": [float(v) for v in args.projected_gravity],
    }
    torch.save(payload, str(out_path))

    print(f"[INFO]: saved expert features -> {out_path}")
    print(f"[INFO]: shape={tuple(expert_features.shape)}")
# Script entry point: run the converter only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
main()

View File

@@ -27,7 +27,7 @@ params:
name: default name: default
config: config:
name: T1_Walking name: T1_GetUp
env_name: rlgym # Isaac Lab 包装器 env_name: rlgym # Isaac Lab 包装器
multi_gpu: False multi_gpu: False
ppo: True ppo: True

View File

@@ -1,6 +1,6 @@
import torch import torch
import torch.nn as nn
from pathlib import Path from pathlib import Path
import yaml
import isaaclab.envs.mdp as mdp import isaaclab.envs.mdp as mdp
from isaaclab.envs import ManagerBasedRLEnvCfg, ManagerBasedRLEnv from isaaclab.envs import ManagerBasedRLEnvCfg, ManagerBasedRLEnv
from isaaclab.managers import ObservationGroupCfg as ObsGroup from isaaclab.managers import ObservationGroupCfg as ObsGroup
@@ -11,8 +11,13 @@ from isaaclab.managers import EventTermCfg as EventTerm
from isaaclab.envs.mdp import JointPositionActionCfg from isaaclab.envs.mdp import JointPositionActionCfg
from isaaclab.managers import SceneEntityCfg from isaaclab.managers import SceneEntityCfg
from isaaclab.utils import configclass from isaaclab.utils import configclass
from rl_game.get_up.amp.amp_rewards import amp_style_prior_reward
from rl_game.get_up.env.t1_env import T1SceneCfg from rl_game.get_up.env.t1_env import T1SceneCfg
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
_DEFAULT_FRONT_KEYFRAME_YAML = str(_PROJECT_ROOT / "behaviors" / "custom" / "keyframe" / "get_up" / "get_up_front.yaml")
_DEFAULT_BACK_KEYFRAME_YAML = str(_PROJECT_ROOT / "behaviors" / "custom" / "keyframe" / "get_up" / "get_up_back.yaml")
def _contact_force_z(env: ManagerBasedRLEnv, sensor_cfg: SceneEntityCfg) -> torch.Tensor: def _contact_force_z(env: ManagerBasedRLEnv, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
"""Sum positive vertical contact force on selected bodies.""" """Sum positive vertical contact force on selected bodies."""
sensor = env.scene.sensors.get(sensor_cfg.name) sensor = env.scene.sensors.get(sensor_cfg.name)
@@ -30,272 +35,216 @@ def _safe_tensor(x: torch.Tensor, nan: float = 0.0, pos: float = 1e3, neg: float
return torch.nan_to_num(x, nan=nan, posinf=pos, neginf=neg) return torch.nan_to_num(x, nan=nan, posinf=pos, neginf=neg)
class AMPDiscriminator(nn.Module): def _resolve_path(path_like: str) -> Path:
"""Lightweight discriminator used by online AMP updates.""" path = Path(path_like).expanduser()
if path.is_absolute():
def __init__(self, input_dim: int, hidden_dims: tuple[int, ...]): return path
super().__init__() return (_PROJECT_ROOT / path).resolve()
layers: list[nn.Module] = []
in_dim = input_dim
for h_dim in hidden_dims:
layers.append(nn.Linear(in_dim, h_dim))
layers.append(nn.LayerNorm(h_dim))
layers.append(nn.SiLU())
in_dim = h_dim
layers.append(nn.Linear(in_dim, 1))
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
def _extract_tensor_from_amp_payload(payload) -> torch.Tensor | None: def _interpolate_keyframes(
if isinstance(payload, torch.Tensor): keyframes: list[dict],
return payload sample_dt: float,
if isinstance(payload, dict): ) -> tuple[torch.Tensor, list[str]]:
for key in ("expert_features", "features", "obs"): """Interpolate sparse keyframes into dense [T, M] motor angle table in radians."""
value = payload.get(key, None) if len(keyframes) == 0:
if isinstance(value, torch.Tensor): raise ValueError("No keyframes in motion yaml")
return value
return None motor_names: set[str] = set()
for frame in keyframes:
motors = frame.get("motor_positions", {})
if isinstance(motors, dict):
motor_names.update(motors.keys())
names = sorted(motor_names)
if len(names) == 0:
raise ValueError("No motor_positions found in keyframes")
frame_count = len(keyframes)
times = torch.zeros(frame_count, dtype=torch.float32)
values = torch.full((frame_count, len(names)), float("nan"), dtype=torch.float32)
name_to_idx = {name: idx for idx, name in enumerate(names)}
elapsed = 0.0
for i, frame in enumerate(keyframes):
elapsed += max(float(frame.get("delta", 0.1)), 1e-4)
times[i] = elapsed
motors = frame.get("motor_positions", {})
if not isinstance(motors, dict):
continue
for motor_name, motor_deg in motors.items():
if motor_name not in name_to_idx:
continue
values[i, name_to_idx[motor_name]] = float(motor_deg) * (torch.pi / 180.0)
values[0] = torch.nan_to_num(values[0], nan=0.0)
for i in range(1, frame_count):
values[i] = torch.where(torch.isnan(values[i]), values[i - 1], values[i])
sample_dt = max(float(sample_dt), 1e-3)
sample_times = torch.arange(0.0, float(times[-1]) + 1e-6, sample_dt, dtype=torch.float32)
sample_times[0] = max(sample_times[0], 1e-6)
sample_times = torch.clamp(sample_times, min=times[0], max=times[-1])
upper = torch.searchsorted(times, sample_times, right=True)
upper = torch.clamp(upper, min=1, max=frame_count - 1)
lower = upper - 1
t0 = times.index_select(0, lower)
t1 = times.index_select(0, upper)
v0 = values.index_select(0, lower)
v1 = values.index_select(0, upper)
alpha = ((sample_times - t0) / (t1 - t0 + 1e-6)).unsqueeze(-1)
interp = v0 + alpha * (v1 - v0)
return interp, names
def _load_amp_expert_features( def _motion_table_from_yaml(
expert_features_path: str, yaml_path: str,
device: str, joint_names: list[str],
feature_dim: int, sample_dt: float,
fallback_samples: int, ) -> tuple[torch.Tensor, float]:
) -> torch.Tensor | None: """
"""Load expert AMP features. Returns None when file is unavailable.""" Build [T, J] target joint motion from get-up keyframes.
if not expert_features_path: Unknown joints default to 0.0 (neutral).
return None """
p = Path(expert_features_path).expanduser() path = _resolve_path(yaml_path)
if not p.is_file(): if not path.is_file():
return None raise FileNotFoundError(f"Motion yaml not found: {path}")
try: with path.open("r", encoding="utf-8") as f:
payload = torch.load(str(p), map_location="cpu") payload = yaml.safe_load(f) or {}
except Exception: keyframes = payload.get("keyframes", [])
return None if not isinstance(keyframes, list):
expert = _extract_tensor_from_amp_payload(payload) raise ValueError(f"Invalid keyframes in yaml: {path}")
if expert is None:
return None motor_table, motor_names = _interpolate_keyframes(keyframes, sample_dt=sample_dt)
expert = expert.float() motor_to_idx = {name: idx for idx, name in enumerate(motor_names)}
if expert.ndim == 1: joint_table = torch.zeros((motor_table.shape[0], len(joint_names)), dtype=torch.float32)
expert = expert.unsqueeze(0)
if expert.ndim != 2: sign_flip_bases = {"Shoulder_Roll", "Elbow_Yaw", "Hip_Roll", "Hip_Yaw", "Ankle_Roll"}
return None for j, joint_name in enumerate(joint_names):
if expert.shape[1] != feature_dim: alias = "Head_yaw" if joint_name == "AAHead_yaw" else joint_name
return None base = alias.replace("Left_", "").replace("Right_", "")
if expert.shape[0] < 2: src = None
return None if alias in motor_to_idx:
if expert.shape[0] < fallback_samples: src = alias
reps = int((fallback_samples + expert.shape[0] - 1) // expert.shape[0]) elif base in motor_to_idx:
expert = expert.repeat(reps, 1) src = base
return expert.to(device=device) if src is None:
continue
sign = -1.0 if alias.startswith("Right_") and base in sign_flip_bases else 1.0
joint_table[:, j] = sign * motor_table[:, motor_to_idx[src]]
duration_s = max(float(motor_table.shape[0] - 1) * sample_dt, sample_dt)
return joint_table, duration_s
def _get_amp_state( def _get_keyframe_motion_cache(
env: ManagerBasedRLEnv, env: ManagerBasedRLEnv,
amp_enabled: bool, front_motion_path: str,
amp_model_path: str, back_motion_path: str,
amp_train_enabled: bool, sample_dt: float,
amp_expert_features_path: str,
feature_dim: int,
disc_hidden_dim: int,
disc_hidden_layers: int,
disc_lr: float,
disc_weight_decay: float,
disc_min_expert_samples: int,
): ):
"""Get cached AMP state (frozen jit or trainable discriminator).""" """Cache interpolated front/back get-up motion priors on env device."""
cache_key = "amp_state_cache" cache_key = "getup_keyframe_motion_cache"
hidden_layers = max(int(disc_hidden_layers), 1) sig = (str(front_motion_path), str(back_motion_path), float(sample_dt))
hidden_dim = max(int(disc_hidden_dim), 16)
state_sig = (
bool(amp_enabled),
str(amp_model_path),
bool(amp_train_enabled),
str(amp_expert_features_path),
int(feature_dim),
hidden_dim,
hidden_layers,
float(disc_lr),
float(disc_weight_decay),
)
cached = env.extras.get(cache_key, None) cached = env.extras.get(cache_key, None)
if isinstance(cached, dict) and cached.get("sig") == state_sig: if isinstance(cached, dict) and cached.get("sig") == sig:
return cached return cached
state = { front_table, front_duration = _motion_table_from_yaml(front_motion_path, T1_JOINT_NAMES, sample_dt)
"sig": state_sig, back_table, back_duration = _motion_table_from_yaml(back_motion_path, T1_JOINT_NAMES, sample_dt)
"mode": "disabled",
"model": None, cache = {
"optimizer": None, "sig": sig,
"expert_features": None, "sample_dt": float(sample_dt),
"step": 0, "front_motion": front_table.to(device=env.device),
"last_loss": 0.0, "back_motion": back_table.to(device=env.device),
"last_acc_policy": 0.0, "front_duration": float(front_duration),
"last_acc_expert": 0.0, "back_duration": float(back_duration),
} }
env.extras[cache_key] = cache
if amp_train_enabled: return cache
expert_features = _load_amp_expert_features(
amp_expert_features_path,
device=env.device,
feature_dim=feature_dim,
fallback_samples=max(disc_min_expert_samples, 512),
)
if expert_features is not None:
model = AMPDiscriminator(input_dim=feature_dim, hidden_dims=tuple([hidden_dim] * hidden_layers)).to(env.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=float(disc_lr), weight_decay=float(disc_weight_decay))
state["mode"] = "trainable"
state["model"] = model
state["optimizer"] = optimizer
state["expert_features"] = expert_features
elif amp_enabled and amp_model_path:
model_path = Path(amp_model_path).expanduser()
if model_path.is_file():
try:
model = torch.jit.load(str(model_path), map_location=env.device)
model.eval()
state["mode"] = "jit"
state["model"] = model
except Exception:
pass
env.extras[cache_key] = state
return state
def _build_amp_features(env: ManagerBasedRLEnv, feature_clip: float = 8.0) -> torch.Tensor: def keyframe_motion_prior_reward(
"""Build AMP-style discriminator features from robot kinematics."""
robot_data = env.scene["robot"].data
joint_pos_rel = robot_data.joint_pos - robot_data.default_joint_pos
joint_vel = robot_data.joint_vel
root_lin_vel = robot_data.root_lin_vel_w
root_ang_vel = robot_data.root_ang_vel_w
projected_gravity = robot_data.projected_gravity_b
amp_features = torch.cat([joint_pos_rel, joint_vel, root_lin_vel, root_ang_vel, projected_gravity], dim=-1)
amp_features = _safe_tensor(amp_features, nan=0.0, pos=feature_clip, neg=-feature_clip)
return torch.clamp(amp_features, min=-feature_clip, max=feature_clip)
def amp_style_prior_reward(
env: ManagerBasedRLEnv, env: ManagerBasedRLEnv,
amp_enabled: bool = False, front_motion_path: str = _DEFAULT_FRONT_KEYFRAME_YAML,
amp_model_path: str = "", back_motion_path: str = _DEFAULT_BACK_KEYFRAME_YAML,
amp_train_enabled: bool = False, sample_dt: float = 0.04,
amp_expert_features_path: str = "", pose_sigma: float = 0.42,
disc_hidden_dim: int = 256, vel_sigma: float = 1.6,
disc_hidden_layers: int = 2, joint_subset: str = "all",
disc_lr: float = 3e-4,
disc_weight_decay: float = 1e-6,
disc_update_interval: int = 4,
disc_batch_size: int = 1024,
disc_min_expert_samples: int = 2048,
feature_clip: float = 8.0,
logit_scale: float = 1.0,
amp_reward_gain: float = 1.0,
internal_reward_scale: float = 1.0, internal_reward_scale: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
"""AMP style prior reward with optional online discriminator training.""" """
zeros = torch.zeros(env.num_envs, device=env.device) DeepMimic-style dense reward from keyframe get-up motions.
amp_score = zeros - mode=1 uses front sequence
model_loaded = 0.0 - mode=0 uses back sequence
amp_train_active = 0.0 """
disc_loss = 0.0 motion_cache = _get_keyframe_motion_cache(
disc_acc_policy = 0.0
disc_acc_expert = 0.0
amp_features = _build_amp_features(env, feature_clip=feature_clip)
amp_state = _get_amp_state(
env=env, env=env,
amp_enabled=amp_enabled, front_motion_path=front_motion_path,
amp_model_path=amp_model_path, back_motion_path=back_motion_path,
amp_train_enabled=amp_train_enabled, sample_dt=sample_dt,
amp_expert_features_path=amp_expert_features_path,
feature_dim=int(amp_features.shape[-1]),
disc_hidden_dim=disc_hidden_dim,
disc_hidden_layers=disc_hidden_layers,
disc_lr=disc_lr,
disc_weight_decay=disc_weight_decay,
disc_min_expert_samples=disc_min_expert_samples,
) )
discriminator = amp_state.get("model", None)
if discriminator is not None:
model_loaded = 1.0
if amp_state.get("mode") == "trainable" and discriminator is not None: # env.extras["getup_mode"]: 1=front, 0=back
amp_train_active = 1.0 getup_mode = env.extras.get("getup_mode", None)
optimizer = amp_state.get("optimizer", None) if not isinstance(getup_mode, torch.Tensor) or getup_mode.shape[0] != env.num_envs:
expert_features = amp_state.get("expert_features", None) getup_mode = torch.zeros(env.num_envs, device=env.device, dtype=torch.long)
amp_state["step"] = int(amp_state.get("step", 0)) + 1 env.extras["getup_mode"] = getup_mode
update_interval = max(int(disc_update_interval), 1) getup_mode = getup_mode.to(dtype=torch.long)
batch_size = max(int(disc_batch_size), 32) use_front = getup_mode == 1
if optimizer is not None and isinstance(expert_features, torch.Tensor) and amp_state["step"] % update_interval == 0: step_dt = env.step_dt
policy_features = amp_features.detach() phase_time = torch.clamp(env.episode_length_buf * step_dt, min=0.0)
policy_count = policy_features.shape[0]
if policy_count > batch_size:
policy_ids = torch.randint(0, policy_count, (batch_size,), device=env.device)
policy_batch = policy_features.index_select(0, policy_ids)
else:
policy_batch = policy_features
expert_count = expert_features.shape[0] front_motion = motion_cache["front_motion"]
expert_ids = torch.randint(0, expert_count, (policy_batch.shape[0],), device=env.device) back_motion = motion_cache["back_motion"]
expert_batch = expert_features.index_select(0, expert_ids) front_idx = torch.clamp((phase_time / sample_dt).to(torch.long), min=0, max=front_motion.shape[0] - 1)
back_idx = torch.clamp((phase_time / sample_dt).to(torch.long), min=0, max=back_motion.shape[0] - 1)
discriminator.train() target_front = front_motion.index_select(0, front_idx)
optimizer.zero_grad(set_to_none=True) target_back = back_motion.index_select(0, back_idx)
logits_expert = discriminator(expert_batch).squeeze(-1) target_pos = torch.where(use_front.unsqueeze(-1), target_front, target_back)
logits_policy = discriminator(policy_batch).squeeze(-1)
loss_expert = nn.functional.binary_cross_entropy_with_logits(logits_expert, torch.ones_like(logits_expert))
loss_policy = nn.functional.binary_cross_entropy_with_logits(logits_policy, torch.zeros_like(logits_policy))
loss = 0.5 * (loss_expert + loss_policy)
loss.backward()
nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
optimizer.step()
discriminator.eval()
with torch.no_grad(): robot_data = env.scene["robot"].data
disc_loss = float(loss.detach().item()) current_pos = _safe_tensor(robot_data.joint_pos)
disc_acc_expert = float((torch.sigmoid(logits_expert) > 0.5).float().mean().item()) current_vel = _safe_tensor(robot_data.joint_vel)
disc_acc_policy = float((torch.sigmoid(logits_policy) < 0.5).float().mean().item())
amp_state["last_loss"] = disc_loss
amp_state["last_acc_expert"] = disc_acc_expert
amp_state["last_acc_policy"] = disc_acc_policy
else:
disc_loss = float(amp_state.get("last_loss", 0.0))
disc_acc_expert = float(amp_state.get("last_acc_expert", 0.0))
disc_acc_policy = float(amp_state.get("last_acc_policy", 0.0))
if discriminator is not None: if joint_subset == "legs":
discriminator.eval() legs_idx, _ = env.scene["robot"].find_joints(".*(Hip|Knee|Ankle).*")
with torch.no_grad(): if len(legs_idx) > 0:
logits = discriminator(amp_features) ids = torch.tensor(legs_idx, device=env.device, dtype=torch.long)
if isinstance(logits, (tuple, list)): current_pos = current_pos.index_select(1, ids)
logits = logits[0] current_vel = current_vel.index_select(1, ids)
if logits.ndim > 1: target_pos = target_pos.index_select(1, ids)
logits = logits.squeeze(-1) elif joint_subset == "core":
logits = _safe_tensor(logits, nan=0.0, pos=20.0, neg=-20.0) core_idx, _ = env.scene["robot"].find_joints(".*(Waist|Hip|Knee|Ankle).*")
amp_score = torch.sigmoid(logit_scale * logits) if len(core_idx) > 0:
amp_score = _safe_tensor(amp_score, nan=0.0, pos=1.0, neg=0.0) ids = torch.tensor(core_idx, device=env.device, dtype=torch.long)
current_pos = current_pos.index_select(1, ids)
current_vel = current_vel.index_select(1, ids)
target_pos = target_pos.index_select(1, ids)
amp_reward = _safe_tensor(amp_reward_gain * amp_score, nan=0.0, pos=10.0, neg=0.0) target_vel = torch.zeros_like(target_pos)
pos_mse = torch.mean(torch.square(current_pos - target_pos), dim=-1)
vel_mse = torch.mean(torch.square(current_vel - target_vel), dim=-1)
pose_sigma = max(float(pose_sigma), 1e-3)
vel_sigma = max(float(vel_sigma), 1e-3)
pose_reward = torch.exp(-pos_mse / pose_sigma)
vel_reward = torch.exp(-vel_mse / vel_sigma)
prior_reward = 0.75 * pose_reward + 0.25 * vel_reward
prior_reward = _safe_tensor(prior_reward, nan=0.0, pos=1.0, neg=0.0)
log_dict = env.extras.get("log", {}) log_dict = env.extras.get("log", {})
if isinstance(log_dict, dict): if isinstance(log_dict, dict):
log_dict["amp_score_mean"] = torch.mean(amp_score).detach().item() log_dict["keyframe_prior_mean"] = torch.mean(prior_reward).detach().item()
log_dict["amp_reward_mean"] = torch.mean(amp_reward).detach().item() log_dict["keyframe_front_ratio"] = torch.mean(use_front.float()).detach().item()
log_dict["amp_model_loaded_mean"] = model_loaded
log_dict["amp_train_active_mean"] = amp_train_active
log_dict["amp_disc_loss_mean"] = disc_loss
log_dict["amp_disc_acc_policy_mean"] = disc_acc_policy
log_dict["amp_disc_acc_expert_mean"] = disc_acc_expert
env.extras["log"] = log_dict env.extras["log"] = log_dict
return internal_reward_scale * amp_reward return internal_reward_scale * prior_reward
def root_height_obs(env: ManagerBasedRLEnv) -> torch.Tensor: def root_height_obs(env: ManagerBasedRLEnv) -> torch.Tensor:
@@ -365,6 +314,10 @@ def reset_root_state_bimodal_lie_pose(
-torch.ones(num_resets, device=env.device), -torch.ones(num_resets, device=env.device),
) )
euler_angles[:, 1] = pitch_mag * pitch_sign euler_angles[:, 1] = pitch_mag * pitch_sign
# Cache get-up mode for motion priors: pitch<0 => front, pitch>=0 => back.
if "getup_mode" not in env.extras or not isinstance(env.extras.get("getup_mode"), torch.Tensor):
env.extras["getup_mode"] = torch.zeros(env.num_envs, device=env.device, dtype=torch.long)
env.extras["getup_mode"][env_ids] = (pitch_sign < 0.0).to(torch.long)
yaw_min, yaw_max = yaw_abs_range yaw_min, yaw_max = yaw_abs_range
yaw_mag = yaw_min + torch.rand(num_resets, device=env.device) * (yaw_max - yaw_min) yaw_mag = yaw_min + torch.rand(num_resets, device=env.device) * (yaw_max - yaw_min)
@@ -778,6 +731,19 @@ class T1GetUpRewardCfg:
"timer_name": "reward_stable_timer", "timer_name": "reward_stable_timer",
}, },
) )
keyframe_motion_prior = RewTerm(
func=keyframe_motion_prior_reward,
weight=0.0,
params={
"front_motion_path": _DEFAULT_FRONT_KEYFRAME_YAML,
"back_motion_path": _DEFAULT_BACK_KEYFRAME_YAML,
"sample_dt": 0.04,
"pose_sigma": 0.42,
"vel_sigma": 1.6,
"joint_subset": "all",
"internal_reward_scale": 1.0,
},
)
# AMP reward is disabled by default until a discriminator model path is provided. # AMP reward is disabled by default until a discriminator model path is provided.
amp_style_prior = RewTerm( amp_style_prior = RewTerm(
func=amp_style_prior_reward, func=amp_style_prior_reward,
@@ -794,6 +760,7 @@ class T1GetUpRewardCfg:
"disc_update_interval": 4, "disc_update_interval": 4,
"disc_batch_size": 1024, "disc_batch_size": 1024,
"disc_min_expert_samples": 2048, "disc_min_expert_samples": 2048,
"disc_history_steps": 4,
"feature_clip": 8.0, "feature_clip": 8.0,
"logit_scale": 1.0, "logit_scale": 1.0,
"amp_reward_gain": 1.0, "amp_reward_gain": 1.0,

View File

@@ -42,7 +42,29 @@ parser.add_argument("--amp_disc_lr", type=float, default=3e-4, help="Learning ra
parser.add_argument("--amp_disc_weight_decay", type=float, default=1e-6, help="Weight decay for AMP discriminator.") parser.add_argument("--amp_disc_weight_decay", type=float, default=1e-6, help="Weight decay for AMP discriminator.")
parser.add_argument("--amp_disc_update_interval", type=int, default=4, help="Train discriminator every N reward calls.") parser.add_argument("--amp_disc_update_interval", type=int, default=4, help="Train discriminator every N reward calls.")
parser.add_argument("--amp_disc_batch_size", type=int, default=1024, help="Discriminator train batch size.") parser.add_argument("--amp_disc_batch_size", type=int, default=1024, help="Discriminator train batch size.")
parser.add_argument("--amp_disc_history_steps", type=int, default=4, help="Temporal history steps for AMP discriminator.")
parser.add_argument("--amp_logit_scale", type=float, default=1.0, help="Scale before sigmoid(logits) for AMP score.") parser.add_argument("--amp_logit_scale", type=float, default=1.0, help="Scale before sigmoid(logits) for AMP score.")
parser.add_argument(
"--amp_from_keyframes",
action="store_true",
help="Generate AMP expert features from get-up keyframe yaml files and enable online discriminator training.",
)
parser.add_argument(
"--amp_keyframe_front",
type=str,
default=os.path.join(PROJECT_ROOT, "behaviors", "custom", "keyframe", "get_up", "get_up_front.yaml"),
help="Front get-up keyframe yaml path for AMP expert generation.",
)
parser.add_argument(
"--amp_keyframe_back",
type=str,
default=os.path.join(PROJECT_ROOT, "behaviors", "custom", "keyframe", "get_up", "get_up_back.yaml"),
help="Back get-up keyframe yaml path for AMP expert generation.",
)
parser.add_argument("--amp_keyframe_dt", type=float, default=0.04, help="Resampling dt for keyframe AMP expert features.")
parser.add_argument("--amp_keyframe_repeat", type=int, default=16, help="Repeat count for each keyframe sequence.")
parser.add_argument("--keyframe_prior_weight", type=float, default=1.0, help="Weight for keyframe motion prior reward.")
parser.add_argument("--disable_keyframe_prior", action="store_true", help="Disable keyframe motion prior reward.")
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args() args_cli = parser.parse_args()
@@ -56,7 +78,8 @@ from rl_games.common.algo_observer import DefaultAlgoObserver
from rl_games.torch_runner import Runner from rl_games.torch_runner import Runner
from rl_games.common import env_configurations, vecenv from rl_games.common import env_configurations, vecenv
from rl_game.get_up.config.t1_env_cfg import T1EnvCfg from rl_game.get_up.amp.amp_motion import build_amp_expert_features_from_getup_keyframes
from rl_game.get_up.config.t1_env_cfg import T1EnvCfg, T1_JOINT_NAMES
class T1MetricObserver(DefaultAlgoObserver): class T1MetricObserver(DefaultAlgoObserver):
@@ -77,6 +100,8 @@ class T1MetricObserver(DefaultAlgoObserver):
"amp_disc_loss_mean", "amp_disc_loss_mean",
"amp_disc_acc_policy_mean", "amp_disc_acc_policy_mean",
"amp_disc_acc_expert_mean", "amp_disc_acc_expert_mean",
"keyframe_prior_mean",
"keyframe_front_ratio",
) )
self._metric_sums: dict[str, float] = {} self._metric_sums: dict[str, float] = {}
self._metric_counts: dict[str, int] = {} self._metric_counts: dict[str, int] = {}
@@ -202,10 +227,33 @@ def main():
task_id = "Isaac-T1-GetUp-v0" task_id = "Isaac-T1-GetUp-v0"
env_cfg = T1EnvCfg() env_cfg = T1EnvCfg()
if args_cli.disable_keyframe_prior:
env_cfg.rewards.keyframe_motion_prior.weight = 0.0
print("[INFO]: keyframe motion prior disabled")
else:
env_cfg.rewards.keyframe_motion_prior.weight = float(args_cli.keyframe_prior_weight)
print(f"[INFO]: keyframe motion prior weight={env_cfg.rewards.keyframe_motion_prior.weight:.3f}")
if args_cli.amp_from_keyframes:
auto_feature_path = os.path.join(os.path.dirname(__file__), "logs", "amp", "expert_features_from_keyframes.pt")
generated_path, feature_shape = build_amp_expert_features_from_getup_keyframes(
front_yaml_path=args_cli.amp_keyframe_front,
back_yaml_path=args_cli.amp_keyframe_back,
joint_names=T1_JOINT_NAMES,
output_path=auto_feature_path,
sample_dt=float(args_cli.amp_keyframe_dt),
repeat_count=int(args_cli.amp_keyframe_repeat),
)
args_cli.amp_expert_features = generated_path
args_cli.amp_train_discriminator = True
print(f"[INFO]: AMP expert features generated at {generated_path}, shape={feature_shape}")
amp_cfg = env_cfg.rewards.amp_style_prior amp_cfg = env_cfg.rewards.amp_style_prior
amp_cfg.params["logit_scale"] = float(args_cli.amp_logit_scale) amp_cfg.params["logit_scale"] = float(args_cli.amp_logit_scale)
if args_cli.amp_train_discriminator: if args_cli.amp_train_discriminator:
expert_path = os.path.abspath(os.path.expanduser(args_cli.amp_expert_features)) if args_cli.amp_expert_features else "" expert_path = os.path.abspath(os.path.expanduser(args_cli.amp_expert_features)) if args_cli.amp_expert_features else ""
if not expert_path:
raise ValueError("--amp_train_discriminator requires --amp_expert_features or --amp_from_keyframes.")
amp_cfg.weight = float(args_cli.amp_reward_weight) amp_cfg.weight = float(args_cli.amp_reward_weight)
amp_cfg.params["amp_train_enabled"] = True amp_cfg.params["amp_train_enabled"] = True
amp_cfg.params["amp_enabled"] = False amp_cfg.params["amp_enabled"] = False
@@ -216,8 +264,10 @@ def main():
amp_cfg.params["disc_weight_decay"] = float(args_cli.amp_disc_weight_decay) amp_cfg.params["disc_weight_decay"] = float(args_cli.amp_disc_weight_decay)
amp_cfg.params["disc_update_interval"] = int(args_cli.amp_disc_update_interval) amp_cfg.params["disc_update_interval"] = int(args_cli.amp_disc_update_interval)
amp_cfg.params["disc_batch_size"] = int(args_cli.amp_disc_batch_size) amp_cfg.params["disc_batch_size"] = int(args_cli.amp_disc_batch_size)
print(f"[INFO]: AMP online discriminator enabled, expert_features={expert_path or '<missing>'}") amp_cfg.params["disc_history_steps"] = int(args_cli.amp_disc_history_steps)
print(f"[INFO]: AMP online discriminator enabled, expert_features={expert_path}")
print(f"[INFO]: AMP reward weight={amp_cfg.weight:.3f}") print(f"[INFO]: AMP reward weight={amp_cfg.weight:.3f}")
print(f"[INFO]: AMP discriminator history_steps={amp_cfg.params['disc_history_steps']}")
elif args_cli.amp_model: elif args_cli.amp_model:
amp_model_path = os.path.abspath(os.path.expanduser(args_cli.amp_model)) amp_model_path = os.path.abspath(os.path.expanduser(args_cli.amp_model))
amp_cfg.weight = float(args_cli.amp_reward_weight) amp_cfg.weight = float(args_cli.amp_reward_weight)