Add AMP get-up pipeline with sequence discriminator and git-sourced expert data
This commit is contained in:
457
rl_game/get_up/amp/amp_rewards.py
Normal file
457
rl_game/get_up/amp/amp_rewards.py
Normal file
@@ -0,0 +1,457 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from isaaclab.envs import ManagerBasedRLEnv
|
||||
|
||||
|
||||
def _safe_tensor(x: torch.Tensor, nan: float = 0.0, pos: float = 1e3, neg: float = -1e3) -> torch.Tensor:
|
||||
return torch.nan_to_num(x, nan=nan, posinf=pos, neginf=neg)
|
||||
|
||||
|
||||
class AMPDiscriminator(nn.Module):
|
||||
"""Lightweight discriminator used by online AMP updates."""
|
||||
|
||||
def __init__(self, input_dim: int, hidden_dims: tuple[int, ...]):
|
||||
super().__init__()
|
||||
layers: list[nn.Module] = []
|
||||
in_dim = input_dim
|
||||
for h_dim in hidden_dims:
|
||||
layers.append(nn.Linear(in_dim, h_dim))
|
||||
layers.append(nn.LayerNorm(h_dim))
|
||||
layers.append(nn.SiLU())
|
||||
in_dim = h_dim
|
||||
layers.append(nn.Linear(in_dim, 1))
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def _to_tensor_2d(value) -> torch.Tensor | None:
|
||||
if isinstance(value, torch.Tensor):
|
||||
t = value.float()
|
||||
elif isinstance(value, np.ndarray):
|
||||
t = torch.as_tensor(value, dtype=torch.float32)
|
||||
elif isinstance(value, list):
|
||||
try:
|
||||
t = torch.as_tensor(value, dtype=torch.float32)
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
if t.ndim == 1:
|
||||
t = t.unsqueeze(0)
|
||||
if t.ndim != 2:
|
||||
return None
|
||||
return t
|
||||
|
||||
|
||||
def _normalize_feature_dim(x: torch.Tensor, feature_dim: int) -> torch.Tensor:
|
||||
if x.shape[-1] == feature_dim:
|
||||
return x
|
||||
if x.shape[-1] > feature_dim:
|
||||
return x[:, :feature_dim]
|
||||
pad = torch.zeros((x.shape[0], feature_dim - x.shape[1]), dtype=x.dtype)
|
||||
return torch.cat([x, pad], dim=-1)
|
||||
|
||||
|
||||
def _extract_expert_bank_from_payload(payload, feature_dim: int) -> dict | None:
|
||||
clip_tensors: list[torch.Tensor] = []
|
||||
clip_names: list[str] = []
|
||||
clip_weights_tensor: torch.Tensor | None = None
|
||||
|
||||
if isinstance(payload, dict):
|
||||
clip_values = payload.get("expert_clips", None)
|
||||
if isinstance(clip_values, list):
|
||||
for i, clip_value in enumerate(clip_values):
|
||||
clip_tensor = _to_tensor_2d(clip_value)
|
||||
if clip_tensor is None:
|
||||
continue
|
||||
clip_tensors.append(_normalize_feature_dim(clip_tensor, feature_dim))
|
||||
clip_names.append(f"clip_{i}")
|
||||
raw_names = payload.get("clip_names", None)
|
||||
if isinstance(raw_names, list) and len(raw_names) == len(clip_tensors):
|
||||
clip_names = [str(n) for n in raw_names]
|
||||
raw_weights = payload.get("clip_weights", None)
|
||||
clip_weights_tensor = _to_tensor_2d(raw_weights) if raw_weights is not None else None
|
||||
if clip_weights_tensor is not None:
|
||||
clip_weights_tensor = clip_weights_tensor.reshape(-1)
|
||||
if clip_weights_tensor.shape[0] != len(clip_tensors):
|
||||
clip_weights_tensor = None
|
||||
|
||||
for key in ("expert_features", "features", "obs"):
|
||||
value = payload.get(key, None)
|
||||
tensor = _to_tensor_2d(value)
|
||||
if tensor is not None:
|
||||
clip_tensors.append(_normalize_feature_dim(tensor, feature_dim))
|
||||
clip_names.append(key)
|
||||
break
|
||||
else:
|
||||
tensor = _to_tensor_2d(payload)
|
||||
if tensor is not None:
|
||||
clip_tensors.append(_normalize_feature_dim(tensor, feature_dim))
|
||||
clip_names.append("expert_features")
|
||||
|
||||
clip_tensors = [c for c in clip_tensors if c.shape[0] >= 2]
|
||||
if len(clip_tensors) == 0:
|
||||
return None
|
||||
|
||||
if clip_weights_tensor is None:
|
||||
clip_weights_tensor = torch.ones(len(clip_tensors), dtype=torch.float32)
|
||||
else:
|
||||
clip_weights_tensor = torch.clamp(clip_weights_tensor.float(), min=0.0)
|
||||
if float(torch.sum(clip_weights_tensor).item()) <= 0.0:
|
||||
clip_weights_tensor = torch.ones(len(clip_tensors), dtype=torch.float32)
|
||||
|
||||
flat_features = torch.cat(clip_tensors, dim=0)
|
||||
return {
|
||||
"flat_features": flat_features,
|
||||
"clips": clip_tensors,
|
||||
"clip_names": clip_names,
|
||||
"clip_weights": clip_weights_tensor,
|
||||
}
|
||||
|
||||
|
||||
def _load_amp_expert_features(
|
||||
expert_features_path: str,
|
||||
device: str,
|
||||
feature_dim: int,
|
||||
fallback_samples: int,
|
||||
) -> dict | None:
|
||||
"""Load expert AMP features bank. Returns None when file is unavailable."""
|
||||
if not expert_features_path:
|
||||
return None
|
||||
p = Path(expert_features_path).expanduser()
|
||||
if not p.is_file():
|
||||
return None
|
||||
try:
|
||||
payload = torch.load(str(p), map_location="cpu")
|
||||
except Exception:
|
||||
try:
|
||||
with p.open("rb") as f:
|
||||
payload = pickle.load(f)
|
||||
except Exception:
|
||||
try:
|
||||
payload = joblib.load(str(p))
|
||||
except Exception:
|
||||
return None
|
||||
bank = _extract_expert_bank_from_payload(payload, feature_dim=feature_dim)
|
||||
if bank is None:
|
||||
return None
|
||||
flat_features = bank["flat_features"].float()
|
||||
if flat_features.shape[0] < fallback_samples:
|
||||
reps = int((fallback_samples + flat_features.shape[0] - 1) // flat_features.shape[0])
|
||||
flat_features = flat_features.repeat(reps, 1)
|
||||
|
||||
clip_tensors = [clip.float() for clip in bank["clips"]]
|
||||
clip_weights = bank["clip_weights"].float()
|
||||
clip_weights = torch.clamp(clip_weights, min=0.0)
|
||||
if float(torch.sum(clip_weights).item()) <= 0.0:
|
||||
clip_weights = torch.ones_like(clip_weights)
|
||||
clip_weights = clip_weights / torch.sum(clip_weights)
|
||||
|
||||
return {
|
||||
"flat_features": flat_features.to(device=device),
|
||||
"clips": [c.to(device=device) for c in clip_tensors],
|
||||
"clip_names": bank["clip_names"],
|
||||
"clip_weights": clip_weights.to(device=device),
|
||||
}
|
||||
|
||||
|
||||
def _policy_sequence_features(
|
||||
env: ManagerBasedRLEnv,
|
||||
current_features: torch.Tensor,
|
||||
history_steps: int,
|
||||
) -> torch.Tensor:
|
||||
"""Build rolling policy sequence features [N, H, D]."""
|
||||
history_steps = max(int(history_steps), 1)
|
||||
cache_key = "amp_policy_hist_cache"
|
||||
cache = env.extras.get(cache_key, None)
|
||||
if not isinstance(cache, dict):
|
||||
cache = {}
|
||||
env.extras[cache_key] = cache
|
||||
|
||||
hist = cache.get("hist", None)
|
||||
if (
|
||||
not isinstance(hist, torch.Tensor)
|
||||
or hist.shape[0] != env.num_envs
|
||||
or hist.shape[1] != history_steps
|
||||
or hist.shape[2] != current_features.shape[1]
|
||||
):
|
||||
hist = current_features.unsqueeze(1).repeat(1, history_steps, 1)
|
||||
|
||||
if history_steps > 1:
|
||||
hist[:, :-1] = hist[:, 1:].clone()
|
||||
hist[:, -1] = current_features
|
||||
cache["hist"] = hist
|
||||
env.extras[cache_key] = cache
|
||||
return hist
|
||||
|
||||
|
||||
def _sample_expert_sequence_batch(
|
||||
expert_bank: dict,
|
||||
batch_size: int,
|
||||
history_steps: int,
|
||||
device: str,
|
||||
) -> torch.Tensor:
|
||||
"""Sample expert sequence batch [B, H, D] with clip-weighted sampling."""
|
||||
clips = expert_bank["clips"]
|
||||
clip_weights = expert_bank["clip_weights"]
|
||||
clip_count = len(clips)
|
||||
if clip_count <= 0:
|
||||
return torch.empty((0, history_steps, 0), device=device)
|
||||
|
||||
clip_ids = torch.multinomial(clip_weights, num_samples=batch_size, replacement=True)
|
||||
seq_list: list[torch.Tensor] = []
|
||||
for clip_id in clip_ids.tolist():
|
||||
clip = clips[int(clip_id)]
|
||||
clip_len = int(clip.shape[0])
|
||||
if clip_len >= history_steps:
|
||||
max_start = clip_len - history_steps
|
||||
if max_start > 0:
|
||||
start = int(torch.randint(0, max_start + 1, (1,), device=device).item())
|
||||
else:
|
||||
start = 0
|
||||
seq = clip[start : start + history_steps]
|
||||
else:
|
||||
pad = clip[-1:].repeat(history_steps - clip_len, 1)
|
||||
seq = torch.cat([clip, pad], dim=0)
|
||||
seq_list.append(seq)
|
||||
return torch.stack(seq_list, dim=0)
|
||||
|
||||
|
||||
def _get_amp_state(
|
||||
env: ManagerBasedRLEnv,
|
||||
amp_enabled: bool,
|
||||
amp_model_path: str,
|
||||
amp_train_enabled: bool,
|
||||
amp_expert_features_path: str,
|
||||
feature_dim: int,
|
||||
disc_hidden_dim: int,
|
||||
disc_hidden_layers: int,
|
||||
disc_lr: float,
|
||||
disc_weight_decay: float,
|
||||
disc_min_expert_samples: int,
|
||||
disc_history_steps: int,
|
||||
):
|
||||
"""Get cached AMP state (frozen jit or trainable discriminator)."""
|
||||
cache_key = "amp_state_cache"
|
||||
hidden_layers = max(int(disc_hidden_layers), 1)
|
||||
hidden_dim = max(int(disc_hidden_dim), 16)
|
||||
history_steps = max(int(disc_history_steps), 1)
|
||||
state_sig = (
|
||||
bool(amp_enabled),
|
||||
str(amp_model_path),
|
||||
bool(amp_train_enabled),
|
||||
str(amp_expert_features_path),
|
||||
int(feature_dim),
|
||||
history_steps,
|
||||
hidden_dim,
|
||||
hidden_layers,
|
||||
float(disc_lr),
|
||||
float(disc_weight_decay),
|
||||
)
|
||||
cached = env.extras.get(cache_key, None)
|
||||
if isinstance(cached, dict) and cached.get("sig") == state_sig:
|
||||
return cached
|
||||
|
||||
state = {
|
||||
"sig": state_sig,
|
||||
"mode": "disabled",
|
||||
"model": None,
|
||||
"optimizer": None,
|
||||
"expert_bank": None,
|
||||
"disc_history_steps": history_steps,
|
||||
"step": 0,
|
||||
"last_loss": 0.0,
|
||||
"last_acc_policy": 0.0,
|
||||
"last_acc_expert": 0.0,
|
||||
}
|
||||
|
||||
if amp_train_enabled:
|
||||
expert_bank = _load_amp_expert_features(
|
||||
amp_expert_features_path,
|
||||
device=env.device,
|
||||
feature_dim=feature_dim,
|
||||
fallback_samples=max(disc_min_expert_samples, 512),
|
||||
)
|
||||
if expert_bank is not None:
|
||||
model = AMPDiscriminator(input_dim=feature_dim * history_steps, hidden_dims=tuple([hidden_dim] * hidden_layers)).to(env.device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=float(disc_lr), weight_decay=float(disc_weight_decay))
|
||||
state["mode"] = "trainable"
|
||||
state["model"] = model
|
||||
state["optimizer"] = optimizer
|
||||
state["expert_bank"] = expert_bank
|
||||
elif amp_enabled and amp_model_path:
|
||||
model_path = Path(amp_model_path).expanduser()
|
||||
if model_path.is_file():
|
||||
try:
|
||||
model = torch.jit.load(str(model_path), map_location=env.device)
|
||||
model.eval()
|
||||
state["mode"] = "jit"
|
||||
state["model"] = model
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
env.extras[cache_key] = state
|
||||
return state
|
||||
|
||||
|
||||
def _build_amp_features(env: ManagerBasedRLEnv, feature_clip: float = 8.0) -> torch.Tensor:
|
||||
"""Build AMP-style discriminator features from robot kinematics."""
|
||||
robot_data = env.scene["robot"].data
|
||||
joint_pos_rel = robot_data.joint_pos - robot_data.default_joint_pos
|
||||
joint_vel = robot_data.joint_vel
|
||||
root_lin_vel = robot_data.root_lin_vel_w
|
||||
root_ang_vel = robot_data.root_ang_vel_w
|
||||
projected_gravity = robot_data.projected_gravity_b
|
||||
amp_features = torch.cat([joint_pos_rel, joint_vel, root_lin_vel, root_ang_vel, projected_gravity], dim=-1)
|
||||
amp_features = _safe_tensor(amp_features, nan=0.0, pos=feature_clip, neg=-feature_clip)
|
||||
return torch.clamp(amp_features, min=-feature_clip, max=feature_clip)
|
||||
|
||||
|
||||
def amp_style_prior_reward(
|
||||
env: ManagerBasedRLEnv,
|
||||
amp_enabled: bool = False,
|
||||
amp_model_path: str = "",
|
||||
amp_train_enabled: bool = False,
|
||||
amp_expert_features_path: str = "",
|
||||
disc_hidden_dim: int = 256,
|
||||
disc_hidden_layers: int = 2,
|
||||
disc_lr: float = 3e-4,
|
||||
disc_weight_decay: float = 1e-6,
|
||||
disc_update_interval: int = 4,
|
||||
disc_batch_size: int = 1024,
|
||||
disc_min_expert_samples: int = 2048,
|
||||
disc_history_steps: int = 4,
|
||||
feature_clip: float = 8.0,
|
||||
logit_scale: float = 1.0,
|
||||
amp_reward_gain: float = 1.0,
|
||||
internal_reward_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""AMP style prior reward with optional online discriminator training."""
|
||||
zeros = torch.zeros(env.num_envs, device=env.device)
|
||||
amp_score = zeros
|
||||
model_loaded = 0.0
|
||||
amp_train_active = 0.0
|
||||
disc_loss = 0.0
|
||||
disc_acc_policy = 0.0
|
||||
disc_acc_expert = 0.0
|
||||
|
||||
amp_features = _build_amp_features(env, feature_clip=feature_clip)
|
||||
amp_state = _get_amp_state(
|
||||
env=env,
|
||||
amp_enabled=amp_enabled,
|
||||
amp_model_path=amp_model_path,
|
||||
amp_train_enabled=amp_train_enabled,
|
||||
amp_expert_features_path=amp_expert_features_path,
|
||||
feature_dim=int(amp_features.shape[-1]),
|
||||
disc_hidden_dim=disc_hidden_dim,
|
||||
disc_hidden_layers=disc_hidden_layers,
|
||||
disc_lr=disc_lr,
|
||||
disc_weight_decay=disc_weight_decay,
|
||||
disc_min_expert_samples=disc_min_expert_samples,
|
||||
disc_history_steps=disc_history_steps,
|
||||
)
|
||||
discriminator = amp_state.get("model", None)
|
||||
history_steps = max(int(amp_state.get("disc_history_steps", disc_history_steps)), 1)
|
||||
policy_seq = _policy_sequence_features(env, amp_features, history_steps=history_steps)
|
||||
policy_seq_flat = policy_seq.reshape(policy_seq.shape[0], -1)
|
||||
if discriminator is not None:
|
||||
model_loaded = 1.0
|
||||
|
||||
if amp_state.get("mode") == "trainable" and discriminator is not None:
|
||||
amp_train_active = 1.0
|
||||
optimizer = amp_state.get("optimizer", None)
|
||||
expert_bank = amp_state.get("expert_bank", None)
|
||||
amp_state["step"] = int(amp_state.get("step", 0)) + 1
|
||||
update_interval = max(int(disc_update_interval), 1)
|
||||
batch_size = max(int(disc_batch_size), 32)
|
||||
|
||||
if optimizer is not None and isinstance(expert_bank, dict) and amp_state["step"] % update_interval == 0:
|
||||
policy_features = policy_seq_flat.detach()
|
||||
policy_count = policy_features.shape[0]
|
||||
if policy_count > batch_size:
|
||||
policy_ids = torch.randint(0, policy_count, (batch_size,), device=env.device)
|
||||
policy_batch = policy_features.index_select(0, policy_ids)
|
||||
else:
|
||||
policy_batch = policy_features
|
||||
|
||||
expert_seq = _sample_expert_sequence_batch(
|
||||
expert_bank=expert_bank,
|
||||
batch_size=policy_batch.shape[0],
|
||||
history_steps=history_steps,
|
||||
device=env.device,
|
||||
)
|
||||
expert_batch = expert_seq.reshape(expert_seq.shape[0], -1)
|
||||
|
||||
discriminator.train()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
logits_expert = discriminator(expert_batch).squeeze(-1)
|
||||
logits_policy = discriminator(policy_batch).squeeze(-1)
|
||||
loss_expert = nn.functional.binary_cross_entropy_with_logits(logits_expert, torch.ones_like(logits_expert))
|
||||
loss_policy = nn.functional.binary_cross_entropy_with_logits(logits_policy, torch.zeros_like(logits_policy))
|
||||
loss = 0.5 * (loss_expert + loss_policy)
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
|
||||
optimizer.step()
|
||||
discriminator.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
disc_loss = float(loss.detach().item())
|
||||
disc_acc_expert = float((torch.sigmoid(logits_expert) > 0.5).float().mean().item())
|
||||
disc_acc_policy = float((torch.sigmoid(logits_policy) < 0.5).float().mean().item())
|
||||
amp_state["last_loss"] = disc_loss
|
||||
amp_state["last_acc_expert"] = disc_acc_expert
|
||||
amp_state["last_acc_policy"] = disc_acc_policy
|
||||
else:
|
||||
disc_loss = float(amp_state.get("last_loss", 0.0))
|
||||
disc_acc_expert = float(amp_state.get("last_acc_expert", 0.0))
|
||||
disc_acc_policy = float(amp_state.get("last_acc_policy", 0.0))
|
||||
|
||||
if discriminator is not None:
|
||||
discriminator.eval()
|
||||
with torch.no_grad():
|
||||
logits = discriminator(policy_seq_flat)
|
||||
if isinstance(logits, (tuple, list)):
|
||||
logits = logits[0]
|
||||
if logits.ndim > 1:
|
||||
logits = logits.squeeze(-1)
|
||||
elif amp_enabled and amp_model_path:
|
||||
# For external scripted models, try temporal then fallback to single frame.
|
||||
model = amp_state.get("model", None)
|
||||
if model is not None:
|
||||
with torch.no_grad():
|
||||
try:
|
||||
logits = model(policy_seq_flat)
|
||||
except Exception:
|
||||
logits = model(amp_features)
|
||||
if isinstance(logits, (tuple, list)):
|
||||
logits = logits[0]
|
||||
if logits.ndim > 1:
|
||||
logits = logits.squeeze(-1)
|
||||
logits = _safe_tensor(logits, nan=0.0, pos=20.0, neg=-20.0)
|
||||
amp_score = torch.sigmoid(logit_scale * logits)
|
||||
amp_score = _safe_tensor(amp_score, nan=0.0, pos=1.0, neg=0.0)
|
||||
|
||||
amp_reward = _safe_tensor(amp_reward_gain * amp_score, nan=0.0, pos=10.0, neg=0.0)
|
||||
|
||||
log_dict = env.extras.get("log", {})
|
||||
if isinstance(log_dict, dict):
|
||||
log_dict["amp_score_mean"] = torch.mean(amp_score).detach().item()
|
||||
log_dict["amp_reward_mean"] = torch.mean(amp_reward).detach().item()
|
||||
log_dict["amp_model_loaded_mean"] = model_loaded
|
||||
log_dict["amp_train_active_mean"] = amp_train_active
|
||||
log_dict["amp_disc_loss_mean"] = disc_loss
|
||||
log_dict["amp_disc_acc_policy_mean"] = disc_acc_policy
|
||||
log_dict["amp_disc_acc_expert_mean"] = disc_acc_expert
|
||||
log_dict["amp_disc_history_steps"] = float(history_steps)
|
||||
env.extras["log"] = log_dict
|
||||
|
||||
return internal_reward_scale * amp_reward
|
||||
Reference in New Issue
Block a user