Add AMP get-up pipeline with sequence discriminator and git-sourced expert data

This commit is contained in:
Chen
2026-04-20 15:51:44 +08:00
parent 9e6e7e00f8
commit 995f6522b2
10 changed files with 1226 additions and 443 deletions

View File

@@ -0,0 +1,62 @@
# AMP Tools for `get_up`
This folder contains all AMP-related code for `rl_game/get_up`.
## Files
- `amp_rewards.py`: AMP discriminator + reward function used by training config.
- `amp_motion.py`: Build AMP expert features from local get-up keyframe YAML files.
- `migrate_legged_lab_expert_template.py`: Template converter for migrating external expert data (for example legged_lab outputs) to `expert_features.pt`.
## Quick start
Generate expert features from current local keyframes:
```bash
python rl_game/get_up/train.py --amp_from_keyframes --headless
```
Convert external motion/expert file to AMP template:
```bash
python rl_game/get_up/amp/migrate_legged_lab_expert_template.py \
--input /path/to/source_data.pt \
--output rl_game/get_up/amp/expert_features.pt \
--input_key expert_features \
--feature_dim 55 \
--repeat 4
```
Convert downloaded `legged_lab` motion pickles directly:
```bash
python rl_game/get_up/amp/migrate_legged_lab_expert_template.py \
--input third_party/legged_lab/source/legged_lab/legged_lab/data/MotionData/g1_29dof/amp/walk_and_run \
--input_glob "*.pkl" \
--target_dof 23 \
--feature_dim 55 \
--clip_weight_mode uniform \
--output rl_game/get_up/amp/expert_features.pt
```
For get-up data from gitee legged_lab, use git-like focused clip sampling:
```bash
python rl_game/get_up/amp/migrate_legged_lab_expert_template.py \
--input third_party/legged_lab_gitee/source/legged_lab/legged_lab/data/MotionData/g1_29dof/amp/get_up \
--input_glob "*.pkl" \
--target_dof 23 \
--feature_dim 55 \
--clip_weight_mode git_getup_focus \
--output rl_game/get_up/amp/expert_features.pt
```
Then train with online AMP discriminator:
```bash
python rl_game/get_up/train.py \
--amp_train_discriminator \
--amp_expert_features rl_game/get_up/amp/expert_features.pt \
--amp_reward_weight 0.6 \
--headless
```

View File

@@ -0,0 +1,7 @@
from .amp_motion import build_amp_expert_features_from_getup_keyframes
from .amp_rewards import amp_style_prior_reward
__all__ = [
"amp_style_prior_reward",
"build_amp_expert_features_from_getup_keyframes",
]

View File

@@ -0,0 +1,186 @@
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
import yaml
RIGHT_JOINT_SIGN_FLIP = {
"Shoulder_Roll",
"Elbow_Yaw",
"Hip_Roll",
"Hip_Yaw",
"Ankle_Roll",
}
JOINT_NAME_ALIAS = {
"AAHead_yaw": "Head_yaw",
}
def _safe_load_yaml(path: Path) -> dict[str, Any]:
with path.open("r", encoding="utf-8") as f:
payload = yaml.safe_load(f) or {}
if not isinstance(payload, dict):
raise ValueError(f"Invalid keyframe yaml: {path}")
return payload
def _build_interpolated_motor_table(
keyframes: list[dict[str, Any]],
sample_dt: float,
) -> tuple[torch.Tensor, list[str], torch.Tensor]:
if len(keyframes) == 0:
raise ValueError("No keyframes in yaml")
durations: list[float] = []
motor_names: set[str] = set()
for frame in keyframes:
durations.append(float(frame.get("delta", 0.1)))
motors = frame.get("motor_positions", {})
if isinstance(motors, dict):
motor_names.update(motors.keys())
names = sorted(motor_names)
if len(names) == 0:
raise ValueError("No motor_positions found in keyframes")
frame_count = len(keyframes)
motor_count = len(names)
times = torch.zeros(frame_count, dtype=torch.float32)
values = torch.full((frame_count, motor_count), float("nan"), dtype=torch.float32)
elapsed = 0.0
name_to_idx = {name: idx for idx, name in enumerate(names)}
for i, frame in enumerate(keyframes):
elapsed += max(float(frame.get("delta", 0.1)), 1e-4)
times[i] = elapsed
motors = frame.get("motor_positions", {})
if not isinstance(motors, dict):
continue
for motor_name, motor_deg in motors.items():
if motor_name not in name_to_idx:
continue
values[i, name_to_idx[motor_name]] = float(motor_deg) * (torch.pi / 180.0)
first_valid = torch.nan_to_num(values[0], nan=0.0)
values[0] = first_valid
for i in range(1, frame_count):
cur = values[i]
prev = values[i - 1]
values[i] = torch.where(torch.isnan(cur), prev, cur)
sample_dt = max(float(sample_dt), 1e-3)
sample_times = torch.arange(0.0, float(times[-1]) + 1e-6, sample_dt, dtype=torch.float32)
sample_times[0] = max(sample_times[0], 1e-6)
sample_times = torch.clamp(sample_times, min=times[0], max=times[-1])
upper = torch.searchsorted(times, sample_times, right=True)
upper = torch.clamp(upper, min=1, max=frame_count - 1)
lower = upper - 1
t0 = times.index_select(0, lower)
t1 = times.index_select(0, upper)
v0 = values.index_select(0, lower)
v1 = values.index_select(0, upper)
alpha = ((sample_times - t0) / (t1 - t0 + 1e-6)).unsqueeze(-1)
interpolated = v0 + alpha * (v1 - v0)
return interpolated, names, sample_times
def _map_motors_to_joint_targets(
motor_table: torch.Tensor,
motor_names: list[str],
joint_names: list[str],
) -> torch.Tensor:
motor_to_idx = {name: idx for idx, name in enumerate(motor_names)}
joint_targets = torch.zeros((motor_table.shape[0], len(joint_names)), dtype=torch.float32)
for joint_idx, joint_name in enumerate(joint_names):
alias_name = JOINT_NAME_ALIAS.get(joint_name, joint_name)
base_name = alias_name.replace("Left_", "").replace("Right_", "")
source_name = None
if alias_name in motor_to_idx:
source_name = alias_name
elif base_name in motor_to_idx:
source_name = base_name
if source_name is None:
continue
sign = 1.0
if alias_name.startswith("Right_") and base_name in RIGHT_JOINT_SIGN_FLIP:
sign = -1.0
joint_targets[:, joint_idx] = sign * motor_table[:, motor_to_idx[source_name]]
return joint_targets
def _features_from_keyframes(
keyframe_yaml_path: str,
joint_names: list[str],
sample_dt: float,
repeat_count: int,
) -> torch.Tensor:
payload = _safe_load_yaml(Path(keyframe_yaml_path).expanduser().resolve())
keyframes = payload.get("keyframes", [])
if not isinstance(keyframes, list):
raise ValueError(f"Invalid keyframe list in {keyframe_yaml_path}")
motor_table, motor_names, _ = _build_interpolated_motor_table(keyframes, sample_dt=sample_dt)
joint_pos = _map_motors_to_joint_targets(motor_table, motor_names, joint_names)
joint_vel = torch.zeros_like(joint_pos)
if joint_pos.shape[0] > 1:
joint_vel[1:] = (joint_pos[1:] - joint_pos[:-1]) / max(sample_dt, 1e-3)
joint_vel[0] = joint_vel[1]
root_lin = torch.zeros((joint_pos.shape[0], 3), dtype=torch.float32)
root_ang = torch.zeros((joint_pos.shape[0], 3), dtype=torch.float32)
projected_gravity = torch.zeros((joint_pos.shape[0], 3), dtype=torch.float32)
projected_gravity[:, 2] = -1.0
features = torch.cat([joint_pos, joint_vel, root_lin, root_ang, projected_gravity], dim=-1)
repeat_count = max(int(repeat_count), 1)
if repeat_count > 1:
features = features.repeat(repeat_count, 1)
return features
def build_amp_expert_features_from_getup_keyframes(
*,
front_yaml_path: str,
back_yaml_path: str,
joint_names: list[str],
output_path: str,
sample_dt: float = 0.04,
repeat_count: int = 16,
) -> tuple[str, tuple[int, int]]:
"""Generate AMP expert feature tensor from front/back get-up keyframe yaml files."""
front_features = _features_from_keyframes(
keyframe_yaml_path=front_yaml_path,
joint_names=joint_names,
sample_dt=sample_dt,
repeat_count=repeat_count,
)
back_features = _features_from_keyframes(
keyframe_yaml_path=back_yaml_path,
joint_names=joint_names,
sample_dt=sample_dt,
repeat_count=repeat_count,
)
expert_features = torch.cat([front_features, back_features], dim=0).contiguous()
out_path = Path(output_path).expanduser().resolve()
out_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(
{
"expert_features": expert_features,
"meta": {
"front_yaml_path": str(Path(front_yaml_path).expanduser().resolve()),
"back_yaml_path": str(Path(back_yaml_path).expanduser().resolve()),
"sample_dt": float(sample_dt),
"repeat_count": int(repeat_count),
"feature_dim": int(expert_features.shape[-1]),
},
},
str(out_path),
)
return str(out_path), (int(expert_features.shape[0]), int(expert_features.shape[1]))

View File

@@ -0,0 +1,457 @@
from __future__ import annotations
import pickle
from pathlib import Path
import joblib
import numpy as np
import torch
import torch.nn as nn
from isaaclab.envs import ManagerBasedRLEnv
def _safe_tensor(x: torch.Tensor, nan: float = 0.0, pos: float = 1e3, neg: float = -1e3) -> torch.Tensor:
return torch.nan_to_num(x, nan=nan, posinf=pos, neginf=neg)
class AMPDiscriminator(nn.Module):
"""Lightweight discriminator used by online AMP updates."""
def __init__(self, input_dim: int, hidden_dims: tuple[int, ...]):
super().__init__()
layers: list[nn.Module] = []
in_dim = input_dim
for h_dim in hidden_dims:
layers.append(nn.Linear(in_dim, h_dim))
layers.append(nn.LayerNorm(h_dim))
layers.append(nn.SiLU())
in_dim = h_dim
layers.append(nn.Linear(in_dim, 1))
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
def _to_tensor_2d(value) -> torch.Tensor | None:
if isinstance(value, torch.Tensor):
t = value.float()
elif isinstance(value, np.ndarray):
t = torch.as_tensor(value, dtype=torch.float32)
elif isinstance(value, list):
try:
t = torch.as_tensor(value, dtype=torch.float32)
except Exception:
return None
else:
return None
if t.ndim == 1:
t = t.unsqueeze(0)
if t.ndim != 2:
return None
return t
def _normalize_feature_dim(x: torch.Tensor, feature_dim: int) -> torch.Tensor:
if x.shape[-1] == feature_dim:
return x
if x.shape[-1] > feature_dim:
return x[:, :feature_dim]
pad = torch.zeros((x.shape[0], feature_dim - x.shape[1]), dtype=x.dtype)
return torch.cat([x, pad], dim=-1)
def _extract_expert_bank_from_payload(payload, feature_dim: int) -> dict | None:
clip_tensors: list[torch.Tensor] = []
clip_names: list[str] = []
clip_weights_tensor: torch.Tensor | None = None
if isinstance(payload, dict):
clip_values = payload.get("expert_clips", None)
if isinstance(clip_values, list):
for i, clip_value in enumerate(clip_values):
clip_tensor = _to_tensor_2d(clip_value)
if clip_tensor is None:
continue
clip_tensors.append(_normalize_feature_dim(clip_tensor, feature_dim))
clip_names.append(f"clip_{i}")
raw_names = payload.get("clip_names", None)
if isinstance(raw_names, list) and len(raw_names) == len(clip_tensors):
clip_names = [str(n) for n in raw_names]
raw_weights = payload.get("clip_weights", None)
clip_weights_tensor = _to_tensor_2d(raw_weights) if raw_weights is not None else None
if clip_weights_tensor is not None:
clip_weights_tensor = clip_weights_tensor.reshape(-1)
if clip_weights_tensor.shape[0] != len(clip_tensors):
clip_weights_tensor = None
for key in ("expert_features", "features", "obs"):
value = payload.get(key, None)
tensor = _to_tensor_2d(value)
if tensor is not None:
clip_tensors.append(_normalize_feature_dim(tensor, feature_dim))
clip_names.append(key)
break
else:
tensor = _to_tensor_2d(payload)
if tensor is not None:
clip_tensors.append(_normalize_feature_dim(tensor, feature_dim))
clip_names.append("expert_features")
clip_tensors = [c for c in clip_tensors if c.shape[0] >= 2]
if len(clip_tensors) == 0:
return None
if clip_weights_tensor is None:
clip_weights_tensor = torch.ones(len(clip_tensors), dtype=torch.float32)
else:
clip_weights_tensor = torch.clamp(clip_weights_tensor.float(), min=0.0)
if float(torch.sum(clip_weights_tensor).item()) <= 0.0:
clip_weights_tensor = torch.ones(len(clip_tensors), dtype=torch.float32)
flat_features = torch.cat(clip_tensors, dim=0)
return {
"flat_features": flat_features,
"clips": clip_tensors,
"clip_names": clip_names,
"clip_weights": clip_weights_tensor,
}
def _load_amp_expert_features(
expert_features_path: str,
device: str,
feature_dim: int,
fallback_samples: int,
) -> dict | None:
"""Load expert AMP features bank. Returns None when file is unavailable."""
if not expert_features_path:
return None
p = Path(expert_features_path).expanduser()
if not p.is_file():
return None
try:
payload = torch.load(str(p), map_location="cpu")
except Exception:
try:
with p.open("rb") as f:
payload = pickle.load(f)
except Exception:
try:
payload = joblib.load(str(p))
except Exception:
return None
bank = _extract_expert_bank_from_payload(payload, feature_dim=feature_dim)
if bank is None:
return None
flat_features = bank["flat_features"].float()
if flat_features.shape[0] < fallback_samples:
reps = int((fallback_samples + flat_features.shape[0] - 1) // flat_features.shape[0])
flat_features = flat_features.repeat(reps, 1)
clip_tensors = [clip.float() for clip in bank["clips"]]
clip_weights = bank["clip_weights"].float()
clip_weights = torch.clamp(clip_weights, min=0.0)
if float(torch.sum(clip_weights).item()) <= 0.0:
clip_weights = torch.ones_like(clip_weights)
clip_weights = clip_weights / torch.sum(clip_weights)
return {
"flat_features": flat_features.to(device=device),
"clips": [c.to(device=device) for c in clip_tensors],
"clip_names": bank["clip_names"],
"clip_weights": clip_weights.to(device=device),
}
def _policy_sequence_features(
env: ManagerBasedRLEnv,
current_features: torch.Tensor,
history_steps: int,
) -> torch.Tensor:
"""Build rolling policy sequence features [N, H, D]."""
history_steps = max(int(history_steps), 1)
cache_key = "amp_policy_hist_cache"
cache = env.extras.get(cache_key, None)
if not isinstance(cache, dict):
cache = {}
env.extras[cache_key] = cache
hist = cache.get("hist", None)
if (
not isinstance(hist, torch.Tensor)
or hist.shape[0] != env.num_envs
or hist.shape[1] != history_steps
or hist.shape[2] != current_features.shape[1]
):
hist = current_features.unsqueeze(1).repeat(1, history_steps, 1)
if history_steps > 1:
hist[:, :-1] = hist[:, 1:].clone()
hist[:, -1] = current_features
cache["hist"] = hist
env.extras[cache_key] = cache
return hist
def _sample_expert_sequence_batch(
expert_bank: dict,
batch_size: int,
history_steps: int,
device: str,
) -> torch.Tensor:
"""Sample expert sequence batch [B, H, D] with clip-weighted sampling."""
clips = expert_bank["clips"]
clip_weights = expert_bank["clip_weights"]
clip_count = len(clips)
if clip_count <= 0:
return torch.empty((0, history_steps, 0), device=device)
clip_ids = torch.multinomial(clip_weights, num_samples=batch_size, replacement=True)
seq_list: list[torch.Tensor] = []
for clip_id in clip_ids.tolist():
clip = clips[int(clip_id)]
clip_len = int(clip.shape[0])
if clip_len >= history_steps:
max_start = clip_len - history_steps
if max_start > 0:
start = int(torch.randint(0, max_start + 1, (1,), device=device).item())
else:
start = 0
seq = clip[start : start + history_steps]
else:
pad = clip[-1:].repeat(history_steps - clip_len, 1)
seq = torch.cat([clip, pad], dim=0)
seq_list.append(seq)
return torch.stack(seq_list, dim=0)
def _get_amp_state(
env: ManagerBasedRLEnv,
amp_enabled: bool,
amp_model_path: str,
amp_train_enabled: bool,
amp_expert_features_path: str,
feature_dim: int,
disc_hidden_dim: int,
disc_hidden_layers: int,
disc_lr: float,
disc_weight_decay: float,
disc_min_expert_samples: int,
disc_history_steps: int,
):
"""Get cached AMP state (frozen jit or trainable discriminator)."""
cache_key = "amp_state_cache"
hidden_layers = max(int(disc_hidden_layers), 1)
hidden_dim = max(int(disc_hidden_dim), 16)
history_steps = max(int(disc_history_steps), 1)
state_sig = (
bool(amp_enabled),
str(amp_model_path),
bool(amp_train_enabled),
str(amp_expert_features_path),
int(feature_dim),
history_steps,
hidden_dim,
hidden_layers,
float(disc_lr),
float(disc_weight_decay),
)
cached = env.extras.get(cache_key, None)
if isinstance(cached, dict) and cached.get("sig") == state_sig:
return cached
state = {
"sig": state_sig,
"mode": "disabled",
"model": None,
"optimizer": None,
"expert_bank": None,
"disc_history_steps": history_steps,
"step": 0,
"last_loss": 0.0,
"last_acc_policy": 0.0,
"last_acc_expert": 0.0,
}
if amp_train_enabled:
expert_bank = _load_amp_expert_features(
amp_expert_features_path,
device=env.device,
feature_dim=feature_dim,
fallback_samples=max(disc_min_expert_samples, 512),
)
if expert_bank is not None:
model = AMPDiscriminator(input_dim=feature_dim * history_steps, hidden_dims=tuple([hidden_dim] * hidden_layers)).to(env.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=float(disc_lr), weight_decay=float(disc_weight_decay))
state["mode"] = "trainable"
state["model"] = model
state["optimizer"] = optimizer
state["expert_bank"] = expert_bank
elif amp_enabled and amp_model_path:
model_path = Path(amp_model_path).expanduser()
if model_path.is_file():
try:
model = torch.jit.load(str(model_path), map_location=env.device)
model.eval()
state["mode"] = "jit"
state["model"] = model
except Exception:
pass
env.extras[cache_key] = state
return state
def _build_amp_features(env: ManagerBasedRLEnv, feature_clip: float = 8.0) -> torch.Tensor:
"""Build AMP-style discriminator features from robot kinematics."""
robot_data = env.scene["robot"].data
joint_pos_rel = robot_data.joint_pos - robot_data.default_joint_pos
joint_vel = robot_data.joint_vel
root_lin_vel = robot_data.root_lin_vel_w
root_ang_vel = robot_data.root_ang_vel_w
projected_gravity = robot_data.projected_gravity_b
amp_features = torch.cat([joint_pos_rel, joint_vel, root_lin_vel, root_ang_vel, projected_gravity], dim=-1)
amp_features = _safe_tensor(amp_features, nan=0.0, pos=feature_clip, neg=-feature_clip)
return torch.clamp(amp_features, min=-feature_clip, max=feature_clip)
def amp_style_prior_reward(
env: ManagerBasedRLEnv,
amp_enabled: bool = False,
amp_model_path: str = "",
amp_train_enabled: bool = False,
amp_expert_features_path: str = "",
disc_hidden_dim: int = 256,
disc_hidden_layers: int = 2,
disc_lr: float = 3e-4,
disc_weight_decay: float = 1e-6,
disc_update_interval: int = 4,
disc_batch_size: int = 1024,
disc_min_expert_samples: int = 2048,
disc_history_steps: int = 4,
feature_clip: float = 8.0,
logit_scale: float = 1.0,
amp_reward_gain: float = 1.0,
internal_reward_scale: float = 1.0,
) -> torch.Tensor:
"""AMP style prior reward with optional online discriminator training."""
zeros = torch.zeros(env.num_envs, device=env.device)
amp_score = zeros
model_loaded = 0.0
amp_train_active = 0.0
disc_loss = 0.0
disc_acc_policy = 0.0
disc_acc_expert = 0.0
amp_features = _build_amp_features(env, feature_clip=feature_clip)
amp_state = _get_amp_state(
env=env,
amp_enabled=amp_enabled,
amp_model_path=amp_model_path,
amp_train_enabled=amp_train_enabled,
amp_expert_features_path=amp_expert_features_path,
feature_dim=int(amp_features.shape[-1]),
disc_hidden_dim=disc_hidden_dim,
disc_hidden_layers=disc_hidden_layers,
disc_lr=disc_lr,
disc_weight_decay=disc_weight_decay,
disc_min_expert_samples=disc_min_expert_samples,
disc_history_steps=disc_history_steps,
)
discriminator = amp_state.get("model", None)
history_steps = max(int(amp_state.get("disc_history_steps", disc_history_steps)), 1)
policy_seq = _policy_sequence_features(env, amp_features, history_steps=history_steps)
policy_seq_flat = policy_seq.reshape(policy_seq.shape[0], -1)
if discriminator is not None:
model_loaded = 1.0
if amp_state.get("mode") == "trainable" and discriminator is not None:
amp_train_active = 1.0
optimizer = amp_state.get("optimizer", None)
expert_bank = amp_state.get("expert_bank", None)
amp_state["step"] = int(amp_state.get("step", 0)) + 1
update_interval = max(int(disc_update_interval), 1)
batch_size = max(int(disc_batch_size), 32)
if optimizer is not None and isinstance(expert_bank, dict) and amp_state["step"] % update_interval == 0:
policy_features = policy_seq_flat.detach()
policy_count = policy_features.shape[0]
if policy_count > batch_size:
policy_ids = torch.randint(0, policy_count, (batch_size,), device=env.device)
policy_batch = policy_features.index_select(0, policy_ids)
else:
policy_batch = policy_features
expert_seq = _sample_expert_sequence_batch(
expert_bank=expert_bank,
batch_size=policy_batch.shape[0],
history_steps=history_steps,
device=env.device,
)
expert_batch = expert_seq.reshape(expert_seq.shape[0], -1)
discriminator.train()
optimizer.zero_grad(set_to_none=True)
logits_expert = discriminator(expert_batch).squeeze(-1)
logits_policy = discriminator(policy_batch).squeeze(-1)
loss_expert = nn.functional.binary_cross_entropy_with_logits(logits_expert, torch.ones_like(logits_expert))
loss_policy = nn.functional.binary_cross_entropy_with_logits(logits_policy, torch.zeros_like(logits_policy))
loss = 0.5 * (loss_expert + loss_policy)
loss.backward()
nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
optimizer.step()
discriminator.eval()
with torch.no_grad():
disc_loss = float(loss.detach().item())
disc_acc_expert = float((torch.sigmoid(logits_expert) > 0.5).float().mean().item())
disc_acc_policy = float((torch.sigmoid(logits_policy) < 0.5).float().mean().item())
amp_state["last_loss"] = disc_loss
amp_state["last_acc_expert"] = disc_acc_expert
amp_state["last_acc_policy"] = disc_acc_policy
else:
disc_loss = float(amp_state.get("last_loss", 0.0))
disc_acc_expert = float(amp_state.get("last_acc_expert", 0.0))
disc_acc_policy = float(amp_state.get("last_acc_policy", 0.0))
if discriminator is not None:
discriminator.eval()
with torch.no_grad():
logits = discriminator(policy_seq_flat)
if isinstance(logits, (tuple, list)):
logits = logits[0]
if logits.ndim > 1:
logits = logits.squeeze(-1)
elif amp_enabled and amp_model_path:
# For external scripted models, try temporal then fallback to single frame.
model = amp_state.get("model", None)
if model is not None:
with torch.no_grad():
try:
logits = model(policy_seq_flat)
except Exception:
logits = model(amp_features)
if isinstance(logits, (tuple, list)):
logits = logits[0]
if logits.ndim > 1:
logits = logits.squeeze(-1)
logits = _safe_tensor(logits, nan=0.0, pos=20.0, neg=-20.0)
amp_score = torch.sigmoid(logit_scale * logits)
amp_score = _safe_tensor(amp_score, nan=0.0, pos=1.0, neg=0.0)
amp_reward = _safe_tensor(amp_reward_gain * amp_score, nan=0.0, pos=10.0, neg=0.0)
log_dict = env.extras.get("log", {})
if isinstance(log_dict, dict):
log_dict["amp_score_mean"] = torch.mean(amp_score).detach().item()
log_dict["amp_reward_mean"] = torch.mean(amp_reward).detach().item()
log_dict["amp_model_loaded_mean"] = model_loaded
log_dict["amp_train_active_mean"] = amp_train_active
log_dict["amp_disc_loss_mean"] = disc_loss
log_dict["amp_disc_acc_policy_mean"] = disc_acc_policy
log_dict["amp_disc_acc_expert_mean"] = disc_acc_expert
log_dict["amp_disc_history_steps"] = float(history_steps)
env.extras["log"] = log_dict
return internal_reward_scale * amp_reward

View File

@@ -0,0 +1,258 @@
from __future__ import annotations
import argparse
import json
import pickle
from glob import glob
from pathlib import Path
import numpy as np
import joblib
try:
import torch
except Exception:
torch = None
GIT_GETUP_FOCUS_CLIP = "fallAndGetUp2_subject2_1200_1370"
def _load_payload(path: Path):
if path.suffix.lower() in (".pt", ".pth"):
if torch is None:
raise RuntimeError("Loading .pt/.pth requires torch to be installed.")
return torch.load(str(path), map_location="cpu")
if path.suffix.lower() in (".pkl",):
try:
with path.open("rb") as f:
return pickle.load(f)
except Exception:
return joblib.load(str(path))
if path.suffix.lower() in (".json",):
with path.open("r", encoding="utf-8") as f:
return json.load(f)
raise ValueError(f"Unsupported file type: {path.suffix}")
def _to_numpy(payload, key_hint: str = "") -> np.ndarray:
if torch is not None and isinstance(payload, torch.Tensor):
return payload.detach().cpu().numpy().astype(np.float32)
if isinstance(payload, np.ndarray):
return payload.astype(np.float32)
if isinstance(payload, list):
return np.asarray(payload, dtype=np.float32)
if isinstance(payload, dict):
candidate_keys = [key_hint] if key_hint else []
candidate_keys += [
"expert_features",
"features",
"observations",
"obs",
"joint_pos",
"motion",
"data",
]
for key in candidate_keys:
if key and key in payload:
value = payload[key]
if torch is not None and isinstance(value, torch.Tensor):
return value.detach().cpu().numpy().astype(np.float32)
if isinstance(value, np.ndarray):
return value.astype(np.float32)
if isinstance(value, list):
return np.asarray(value, dtype=np.float32)
for value in payload.values():
if torch is not None and isinstance(value, torch.Tensor):
return value.detach().cpu().numpy().astype(np.float32)
if isinstance(value, np.ndarray):
return value.astype(np.float32)
if isinstance(value, list):
return np.asarray(value, dtype=np.float32)
raise ValueError("Could not locate tensor-like payload. Provide --input_key when needed.")
def _ensure_2d(x: np.ndarray) -> np.ndarray:
if x.ndim == 1:
return x[None, :]
if x.ndim != 2:
raise ValueError(f"Expected 2D array [N, D], got shape={tuple(x.shape)}")
return x.astype(np.float32)
def _from_legged_lab_motion_dict(payload: dict, target_dof: int) -> np.ndarray:
"""
Convert legged_lab motion pickle payload to AMP feature tensor [N, D].
Expected keys: fps, root_pos, dof_pos.
"""
dof_pos = np.asarray(payload.get("dof_pos", None))
root_pos = np.asarray(payload.get("root_pos", None))
fps = float(payload.get("fps", 30.0))
if dof_pos.ndim != 2:
raise ValueError("legged_lab payload missing valid dof_pos [N, M].")
if root_pos.ndim != 2 or root_pos.shape[0] != dof_pos.shape[0]:
root_pos = np.zeros((dof_pos.shape[0], 3), dtype=np.float32)
dt = 1.0 / max(fps, 1e-3)
dof_pos = dof_pos.astype(np.float32)
root_pos = root_pos.astype(np.float32)
# Current get_up AMP feature dim expects 23 dof by default.
target_dof = int(target_dof)
if dof_pos.shape[1] >= target_dof:
dof_pos = dof_pos[:, :target_dof]
else:
dof_pos = np.pad(dof_pos, ((0, 0), (0, target_dof - dof_pos.shape[1])), mode="constant")
dof_vel = np.zeros_like(dof_pos, dtype=np.float32)
if dof_pos.shape[0] > 1:
dof_vel[1:] = (dof_pos[1:] - dof_pos[:-1]) / dt
dof_vel[0] = dof_vel[1]
root_lin = np.zeros((dof_pos.shape[0], 3), dtype=np.float32)
if root_pos.shape[0] > 1:
root_lin[1:] = (root_pos[1:] - root_pos[:-1]) / dt
root_lin[0] = root_lin[1]
root_ang = np.zeros((dof_pos.shape[0], 3), dtype=np.float32)
gravity = np.zeros((dof_pos.shape[0], 3), dtype=np.float32)
gravity[:, 2] = -1.0
x = np.concatenate([dof_pos, dof_vel, root_lin, root_ang, gravity], axis=-1).astype(np.float32)
return x
def _normalize_dim(x: np.ndarray, feature_dim: int) -> np.ndarray:
if x.shape[1] == feature_dim:
return x
if x.shape[1] > feature_dim:
return x[:, :feature_dim]
pad = np.zeros((x.shape[0], feature_dim - x.shape[1]), dtype=np.float32)
return np.concatenate([x, pad], axis=-1)
def _collect_input_files(input_path: Path, glob_pattern: str) -> list[Path]:
if input_path.is_file():
return [input_path]
if input_path.is_dir():
files = sorted([Path(p) for p in glob(str(input_path / glob_pattern))])
if files:
return files
raise FileNotFoundError(f"No source files found for input={input_path} pattern={glob_pattern}")
def _build_clip_weights(clip_names: list[str], weight_mode: str) -> np.ndarray:
if len(clip_names) == 0:
return np.zeros((0,), dtype=np.float32)
if weight_mode == "uniform":
return np.ones((len(clip_names),), dtype=np.float32)
if weight_mode == "git_getup_focus":
weights = np.array([1.0 if GIT_GETUP_FOCUS_CLIP in n else 0.0 for n in clip_names], dtype=np.float32)
if float(np.sum(weights)) <= 0.0:
return np.ones((len(clip_names),), dtype=np.float32)
return weights
return np.ones((len(clip_names),), dtype=np.float32)
def main():
parser = argparse.ArgumentParser(
description=(
"Template converter: migrate legged_lab (or other) motion/expert data "
"to AMP expert_features.pt for rl_game/get_up."
)
)
parser.add_argument("--input", required=True, type=str, help="Path to source motion/expert file or directory.")
parser.add_argument("--output", required=True, type=str, help="Output path for expert_features.pt.")
parser.add_argument(
"--input_key",
type=str,
default="",
help="Optional key to locate tensor inside input payload dict.",
)
parser.add_argument(
"--feature_dim",
type=int,
default=55,
help="Target AMP feature dimension. For current get_up config, default=55.",
)
parser.add_argument(
"--input_glob",
type=str,
default="*.pkl",
help="Glob used when --input is a directory.",
)
parser.add_argument(
"--target_dof",
type=int,
default=23,
help="Target dof count when converting legged_lab pkl dof_pos.",
)
parser.add_argument(
"--clip_weight_mode",
type=str,
default="git_getup_focus",
choices=["git_getup_focus", "uniform"],
help="Clip sampling weight mode for expert sequence training.",
)
parser.add_argument("--repeat", type=int, default=1, help="Repeat samples for small datasets.")
args = parser.parse_args()
in_path = Path(args.input).expanduser().resolve()
input_files = _collect_input_files(in_path, args.input_glob)
clip_arrays: list[np.ndarray] = []
clip_names: list[str] = []
for f in input_files:
payload = _load_payload(f)
if isinstance(payload, dict) and "dof_pos" in payload and f.suffix.lower() == ".pkl":
x = _from_legged_lab_motion_dict(payload, target_dof=int(args.target_dof))
else:
x = _to_numpy(payload, key_hint=args.input_key)
x = _ensure_2d(x)
x = _normalize_dim(x.astype(np.float32), int(args.feature_dim))
clip_arrays.append(x)
clip_names.append(f.stem)
repeat = max(int(args.repeat), 1)
if repeat > 1:
clip_arrays = clip_arrays * repeat
clip_names = clip_names * repeat
clip_weights = _build_clip_weights(clip_names, weight_mode=args.clip_weight_mode)
x = np.concatenate(clip_arrays, axis=0).astype(np.float32)
clip_lengths = [int(c.shape[0]) for c in clip_arrays]
clip_offsets: list[int] = []
cursor = 0
for length in clip_lengths:
clip_offsets.append(cursor)
cursor += length
repeat = max(int(args.repeat), 1)
out_path = Path(args.output).expanduser().resolve()
out_path.parent.mkdir(parents=True, exist_ok=True)
payload = {
"expert_features": x,
"expert_clips": clip_arrays,
"clip_names": clip_names,
"clip_weights": clip_weights,
"clip_lengths": clip_lengths,
"clip_offsets": clip_offsets,
"meta": {
"source": str(in_path),
"source_count": len(input_files),
"input_key": args.input_key,
"feature_dim": int(args.feature_dim),
"target_dof": int(args.target_dof),
"input_glob": args.input_glob,
"clip_weight_mode": args.clip_weight_mode,
"git_getup_focus_clip": GIT_GETUP_FOCUS_CLIP,
"repeat": int(repeat),
},
}
if torch is not None:
torch.save(payload, str(out_path))
else:
with out_path.open("wb") as f:
pickle.dump(payload, f)
print(f"[OK] Saved AMP expert features -> {out_path} shape={tuple(x.shape)}")
if __name__ == "__main__":
main()