"""AMP style-prior reward term with an optional online-trained discriminator."""

from __future__ import annotations

import logging
import pickle
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import torch
import torch.nn as nn

try:
    import joblib
except ImportError:  # joblib is only used as a last-resort loader for expert files
    joblib = None

if TYPE_CHECKING:
    # Only referenced in annotations (lazy thanks to ``from __future__ import annotations``).
    from isaaclab.envs import ManagerBasedRLEnv

_LOGGER = logging.getLogger(__name__)


def _safe_tensor(x: torch.Tensor, nan: float = 0.0, pos: float = 1e3, neg: float = -1e3) -> torch.Tensor:
    """Replace NaN / +inf / -inf entries of ``x`` with ``nan`` / ``pos`` / ``neg``."""
    return torch.nan_to_num(x, nan=nan, posinf=pos, neginf=neg)


class AMPDiscriminator(nn.Module):
    """Lightweight discriminator used by online AMP updates.

    An MLP made of ``Linear -> LayerNorm -> SiLU`` stages (one per entry of
    ``hidden_dims``) followed by a final ``Linear`` producing a single logit.
    """

    def __init__(self, input_dim: int, hidden_dims: tuple[int, ...]):
        super().__init__()
        layers: list[nn.Module] = []
        in_dim = input_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.LayerNorm(h_dim))
            layers.append(nn.SiLU())
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one logit per input row (trailing dimension of size 1)."""
        return self.net(x)


def _to_tensor_2d(value) -> torch.Tensor | None:
    """Convert ``value`` to a float32 2-D tensor, or return None if impossible.

    Tensors, numpy arrays, lists and tuples are accepted. 1-D input becomes a
    single row; anything that does not end up 2-D is rejected.
    """
    if isinstance(value, torch.Tensor):
        t = value.float()
    elif isinstance(value, (np.ndarray, list, tuple)):
        try:
            t = torch.as_tensor(value, dtype=torch.float32)
        except (TypeError, ValueError, RuntimeError):
            # Ragged lists, object-dtype arrays, non-numeric content, ...
            return None
    else:
        return None
    if t.ndim == 1:
        t = t.unsqueeze(0)
    if t.ndim != 2:
        return None
    return t


def _normalize_feature_dim(x: torch.Tensor, feature_dim: int) -> torch.Tensor:
    """Truncate or zero-pad the last dimension of the 2-D tensor ``x`` to ``feature_dim``."""
    width = x.shape[-1]
    if width == feature_dim:
        return x
    if width > feature_dim:
        return x[:, :feature_dim]
    # Allocate the padding on x's device so the concatenation also works off-CPU.
    pad = torch.zeros((x.shape[0], feature_dim - width), dtype=x.dtype, device=x.device)
    return torch.cat([x, pad], dim=-1)


def _extract_expert_bank_from_payload(payload, feature_dim: int) -> dict | None:
    """Build an expert clip bank from a loaded payload.

    Recognised layouts:
      * dict with ``"expert_clips"`` (list of per-clip arrays) and optional
        ``"clip_names"`` / ``"clip_weights"``;
      * dict with a flat feature array under ``"expert_features"``,
        ``"features"`` or ``"obs"`` (first usable key wins, added as one clip);
      * a bare array-like, treated as a single clip.

    Clips with fewer than 2 rows are dropped. Names and weights are filtered
    together with the clips so that index ``i`` always refers to the same clip
    in ``clips``, ``clip_names`` and ``clip_weights``.

    Returns:
        dict with ``flat_features``, ``clips``, ``clip_names`` and
        ``clip_weights`` (non-negative, not normalised), or None when no usable
        clip was found.
    """
    clip_tensors: list[torch.Tensor] = []
    clip_names: list[str] = []
    clip_weights_tensor: torch.Tensor | None = None

    if isinstance(payload, dict):
        clip_values = payload.get("expert_clips", None)
        if isinstance(clip_values, list):
            for i, clip_value in enumerate(clip_values):
                clip_tensor = _to_tensor_2d(clip_value)
                if clip_tensor is None:
                    continue
                clip_tensors.append(_normalize_feature_dim(clip_tensor, feature_dim))
                clip_names.append(f"clip_{i}")
            raw_names = payload.get("clip_names", None)
            if isinstance(raw_names, list) and len(raw_names) == len(clip_tensors):
                clip_names = [str(n) for n in raw_names]
            raw_weights = payload.get("clip_weights", None)
            if raw_weights is not None:
                clip_weights_tensor = _to_tensor_2d(raw_weights)
            if clip_weights_tensor is not None:
                clip_weights_tensor = clip_weights_tensor.reshape(-1)
        for key in ("expert_features", "features", "obs"):
            tensor = _to_tensor_2d(payload.get(key, None))
            if tensor is not None:
                clip_tensors.append(_normalize_feature_dim(tensor, feature_dim))
                clip_names.append(key)
                break
    else:
        tensor = _to_tensor_2d(payload)
        if tensor is not None:
            clip_tensors.append(_normalize_feature_dim(tensor, feature_dim))
            clip_names.append("expert_features")

    # Weights are only meaningful if they describe every collected clip
    # (the flat-feature fallback above may have added one they do not cover).
    if clip_weights_tensor is not None and clip_weights_tensor.shape[0] != len(clip_tensors):
        clip_weights_tensor = None

    # Drop too-short clips, keeping names and weights aligned with what is left.
    keep = [i for i, clip in enumerate(clip_tensors) if clip.shape[0] >= 2]
    if not keep:
        return None
    clip_tensors = [clip_tensors[i] for i in keep]
    clip_names = [clip_names[i] for i in keep]
    if clip_weights_tensor is not None:
        keep_ids = torch.as_tensor(keep, dtype=torch.long, device=clip_weights_tensor.device)
        clip_weights_tensor = torch.clamp(clip_weights_tensor.float()[keep_ids], min=0.0)
        if float(torch.sum(clip_weights_tensor).item()) <= 0.0:
            clip_weights_tensor = None
    if clip_weights_tensor is None:
        clip_weights_tensor = torch.ones(len(clip_tensors), dtype=torch.float32)

    return {
        "flat_features": torch.cat(clip_tensors, dim=0),
        "clips": clip_tensors,
        "clip_names": clip_names,
        "clip_weights": clip_weights_tensor,
    }


def _read_expert_payload(path: Path):
    """Try torch, pickle, then joblib on ``path``.

    Returns ``(True, payload)`` from the first loader that succeeds, otherwise
    ``(False, None)``.
    """

    # SECURITY: every loader below is pickle-based and can execute arbitrary
    # code embedded in the file. Only point this at trusted expert files.
    def _pickle_load():
        with path.open("rb") as f:
            return pickle.load(f)

    loaders = [
        ("torch.load", lambda: torch.load(str(path), map_location="cpu")),
        ("pickle", _pickle_load),
    ]
    if joblib is not None:
        loaders.append(("joblib", lambda: joblib.load(str(path))))

    for loader_name, loader in loaders:
        try:
            return True, loader()
        except Exception as err:  # deliberate best-effort: formats fail in different ways
            _LOGGER.debug("AMP expert file %s: %s failed (%s)", path, loader_name, err)
    _LOGGER.warning("AMP expert file %s could not be read by any supported loader.", path)
    return False, None


def _load_amp_expert_features(
    expert_features_path: str,
    device: str,
    feature_dim: int,
    fallback_samples: int,
) -> dict | None:
    """Load the expert AMP feature bank and move it to ``device``.

    ``flat_features`` is tiled until it has at least ``fallback_samples`` rows
    and ``clip_weights`` is normalised to sum to 1.

    Returns None when the path is empty, the file is missing or unreadable, or
    it contains no usable clip.
    """
    if not expert_features_path:
        return None
    p = Path(expert_features_path).expanduser()
    if not p.is_file():
        return None

    loaded, payload = _read_expert_payload(p)
    if not loaded:
        return None
    bank = _extract_expert_bank_from_payload(payload, feature_dim=feature_dim)
    if bank is None:
        return None

    flat_features = bank["flat_features"].float()
    row_count = flat_features.shape[0]
    if row_count < fallback_samples:
        reps = int((fallback_samples + row_count - 1) // row_count)  # ceil division
        flat_features = flat_features.repeat(reps, 1)

    clip_weights = torch.clamp(bank["clip_weights"].float(), min=0.0)
    if float(torch.sum(clip_weights).item()) <= 0.0:
        clip_weights = torch.ones_like(clip_weights)
    clip_weights = clip_weights / torch.sum(clip_weights)

    return {
        "flat_features": flat_features.to(device=device),
        "clips": [clip.float().to(device=device) for clip in bank["clips"]],
        "clip_names": bank["clip_names"],
        "clip_weights": clip_weights.to(device=device),
    }


def _policy_sequence_features(
    env: ManagerBasedRLEnv,
    current_features: torch.Tensor,
    history_steps: int,
) -> torch.Tensor:
    """Build rolling policy sequence features [N, H, D].

    The history lives in ``env.extras["amp_policy_hist_cache"]["hist"]``. It is
    (re)initialised by repeating ``current_features`` whenever the cached
    tensor does not match the expected shape, dtype or device; otherwise it is
    shifted by one step and the newest frame is written to the last slot.

    The returned tensor IS the cached buffer and is mutated in place by the
    next call.
    """
    history_steps = max(int(history_steps), 1)
    cache_key = "amp_policy_hist_cache"
    cache = env.extras.get(cache_key, None)
    if not isinstance(cache, dict):
        cache = {}

    hist = cache.get("hist", None)
    stale = (
        not isinstance(hist, torch.Tensor)
        or hist.ndim != 3
        or hist.shape[0] != env.num_envs
        or hist.shape[1] != history_steps
        or hist.shape[2] != current_features.shape[1]
        or hist.dtype != current_features.dtype
        or hist.device != current_features.device
    )
    if stale:
        # Every slot (including the newest) starts as the current frame.
        hist = current_features.unsqueeze(1).repeat(1, history_steps, 1)
    else:
        if history_steps > 1:
            # clone() avoids reading and writing overlapping memory.
            hist[:, :-1] = hist[:, 1:].clone()
        hist[:, -1] = current_features

    cache["hist"] = hist
    env.extras[cache_key] = cache
    return hist


def _sample_expert_sequence_batch(
    expert_bank: dict,
    batch_size: int,
    history_steps: int,
    device: str,
) -> torch.Tensor:
    """Sample expert sequence batch [B, H, D] with clip-weighted sampling.

    Clips are drawn with replacement according to ``expert_bank["clip_weights"]``.
    A window of ``history_steps`` consecutive rows is taken at a uniformly
    random start; clips shorter than the window are padded by repeating their
    last row.
    """
    clips = expert_bank["clips"]
    clip_weights = expert_bank["clip_weights"]
    history_steps = max(int(history_steps), 1)
    batch_size = int(batch_size)
    if len(clips) <= 0:
        return torch.empty((0, history_steps, 0), device=device)
    if batch_size <= 0:
        return torch.empty((0, history_steps, int(clips[0].shape[-1])), device=device)

    clip_ids = torch.multinomial(clip_weights, num_samples=batch_size, replacement=True).tolist()
    # One batched CPU draw for all start offsets instead of a device-side
    # randint + .item() (a host/device sync) per sample.
    start_fractions = torch.rand(batch_size).tolist()

    seq_list: list[torch.Tensor] = []
    for clip_id, fraction in zip(clip_ids, start_fractions):
        clip = clips[int(clip_id)]
        clip_len = int(clip.shape[0])
        if clip_len >= history_steps:
            max_start = clip_len - history_steps
            # fraction is in [0, 1); min() guards against float rounding.
            start = min(int(fraction * (max_start + 1)), max_start)
            seq = clip[start : start + history_steps]
        else:
            pad = clip[-1:].repeat(history_steps - clip_len, 1)
            seq = torch.cat([clip, pad], dim=0)
        seq_list.append(seq)
    return torch.stack(seq_list, dim=0)


def _get_amp_state(
    env: ManagerBasedRLEnv,
    amp_enabled: bool,
    amp_model_path: str,
    amp_train_enabled: bool,
    amp_expert_features_path: str,
    feature_dim: int,
    disc_hidden_dim: int,
    disc_hidden_layers: int,
    disc_lr: float,
    disc_weight_decay: float,
    disc_min_expert_samples: int,
    disc_history_steps: int,
):
    """Get cached AMP state (frozen jit or trainable discriminator).

    The state dict is stored in ``env.extras["amp_state_cache"]`` and rebuilt
    whenever the configuration signature changes. ``state["mode"]`` is one of:
      * ``"trainable"``: ``amp_train_enabled`` and an expert bank was loaded;
      * ``"jit"``: training is off, ``amp_enabled`` and the TorchScript file loaded;
      * ``"disabled"``: otherwise (``state["model"]`` is None).
    """
    cache_key = "amp_state_cache"
    hidden_layers = max(int(disc_hidden_layers), 1)
    hidden_dim = max(int(disc_hidden_dim), 16)
    history_steps = max(int(disc_history_steps), 1)
    state_sig = (
        bool(amp_enabled),
        str(amp_model_path),
        bool(amp_train_enabled),
        str(amp_expert_features_path),
        int(feature_dim),
        history_steps,
        hidden_dim,
        hidden_layers,
        float(disc_lr),
        float(disc_weight_decay),
    )
    cached = env.extras.get(cache_key, None)
    if isinstance(cached, dict) and cached.get("sig") == state_sig:
        return cached

    state = {
        "sig": state_sig,
        "mode": "disabled",
        "model": None,
        "optimizer": None,
        "expert_bank": None,
        "disc_history_steps": history_steps,
        "step": 0,
        "last_loss": 0.0,
        "last_acc_policy": 0.0,
        "last_acc_expert": 0.0,
        # Set once a scripted model rejects the temporal input (see reward term).
        "jit_single_frame": False,
    }

    if amp_train_enabled:
        expert_bank = _load_amp_expert_features(
            amp_expert_features_path,
            device=env.device,
            feature_dim=feature_dim,
            fallback_samples=max(disc_min_expert_samples, 512),
        )
        if expert_bank is not None:
            # NOTE(review): this may be reached from inside torch.inference_mode()
            # (RL runners commonly step the env that way - confirm). Parameters
            # created as inference tensors could never be updated in place by the
            # optimizer, so build the model with inference mode switched off.
            with torch.inference_mode(False):
                model = AMPDiscriminator(
                    input_dim=feature_dim * history_steps,
                    hidden_dims=(hidden_dim,) * hidden_layers,
                ).to(env.device)
                optimizer = torch.optim.AdamW(
                    model.parameters(), lr=float(disc_lr), weight_decay=float(disc_weight_decay)
                )
            state["mode"] = "trainable"
            state["model"] = model
            state["optimizer"] = optimizer
            state["expert_bank"] = expert_bank
        else:
            _LOGGER.warning(
                "AMP training requested but no expert features could be loaded from %r; AMP reward is disabled.",
                amp_expert_features_path,
            )
    elif amp_enabled and amp_model_path:
        model_path = Path(amp_model_path).expanduser()
        if model_path.is_file():
            try:
                model = torch.jit.load(str(model_path), map_location=env.device)
                model.eval()
            except Exception as err:  # best-effort: a broken model disables the reward
                _LOGGER.warning("Failed to load AMP TorchScript model %s: %s", model_path, err)
            else:
                state["mode"] = "jit"
                state["model"] = model

    env.extras[cache_key] = state
    return state


def _build_amp_features(env: ManagerBasedRLEnv, feature_clip: float = 8.0) -> torch.Tensor:
    """Build AMP-style discriminator features from robot kinematics.

    Concatenates, along the last dimension, the ``env.scene["robot"].data``
    fields: joint positions relative to their defaults, joint velocities, root
    linear and angular velocity (``*_w`` fields) and projected gravity
    (``*_b`` field). Non-finite values are replaced and the result is clamped
    to ``[-feature_clip, feature_clip]``.
    """
    robot_data = env.scene["robot"].data
    joint_pos_rel = robot_data.joint_pos - robot_data.default_joint_pos
    amp_features = torch.cat(
        [
            joint_pos_rel,
            robot_data.joint_vel,
            robot_data.root_lin_vel_w,
            robot_data.root_ang_vel_w,
            robot_data.projected_gravity_b,
        ],
        dim=-1,
    )
    amp_features = _safe_tensor(amp_features, nan=0.0, pos=feature_clip, neg=-feature_clip)
    return torch.clamp(amp_features, min=-feature_clip, max=feature_clip)


def _update_discriminator(
    amp_state: dict,
    discriminator: nn.Module,
    policy_seq_flat: torch.Tensor,
    history_steps: int,
    update_interval: int,
    batch_size: int,
    device: str,
) -> tuple[float, float, float]:
    """Advance the step counter and, when due, run one discriminator update.

    Returns ``(loss, policy_accuracy, expert_accuracy)``; when no update is
    performed on this call the values of the last update are returned.
    """
    optimizer = amp_state.get("optimizer", None)
    expert_bank = amp_state.get("expert_bank", None)
    amp_state["step"] = int(amp_state.get("step", 0)) + 1
    last_metrics = (
        float(amp_state.get("last_loss", 0.0)),
        float(amp_state.get("last_acc_policy", 0.0)),
        float(amp_state.get("last_acc_expert", 0.0)),
    )
    due = amp_state["step"] % update_interval == 0
    if optimizer is None or not isinstance(expert_bank, dict) or not due:
        return last_metrics

    # NOTE(review): the reward term may be evaluated under torch.inference_mode()
    # or torch.no_grad() - confirm with the runner in use. backward() needs grad
    # mode on and inputs that are not inference tensors, hence the explicit
    # context and the clone() below.
    with torch.inference_mode(False), torch.enable_grad():
        policy_features = policy_seq_flat.detach()
        policy_count = policy_features.shape[0]
        if policy_count > batch_size:
            policy_ids = torch.randint(0, policy_count, (batch_size,), device=device)
            policy_batch = policy_features.index_select(0, policy_ids)
        else:
            policy_batch = policy_features.clone()

        expert_seq = _sample_expert_sequence_batch(
            expert_bank=expert_bank,
            batch_size=policy_batch.shape[0],
            history_steps=history_steps,
            device=device,
        )
        if expert_seq.shape[0] == 0:
            return last_metrics
        expert_batch = expert_seq.reshape(expert_seq.shape[0], -1)

        discriminator.train()
        optimizer.zero_grad(set_to_none=True)
        logits_expert = discriminator(expert_batch).squeeze(-1)
        logits_policy = discriminator(policy_batch).squeeze(-1)
        # Expert samples are labelled 1, policy samples 0.
        loss_expert = nn.functional.binary_cross_entropy_with_logits(logits_expert, torch.ones_like(logits_expert))
        loss_policy = nn.functional.binary_cross_entropy_with_logits(logits_policy, torch.zeros_like(logits_policy))
        loss = 0.5 * (loss_expert + loss_policy)
        loss.backward()
        nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
        optimizer.step()
        discriminator.eval()

        # Metrics use the logits computed before the optimizer step.
        with torch.no_grad():
            disc_loss = float(loss.detach().item())
            disc_acc_expert = float((torch.sigmoid(logits_expert) > 0.5).float().mean().item())
            disc_acc_policy = float((torch.sigmoid(logits_policy) < 0.5).float().mean().item())

    amp_state["last_loss"] = disc_loss
    amp_state["last_acc_expert"] = disc_acc_expert
    amp_state["last_acc_policy"] = disc_acc_policy
    return disc_loss, disc_acc_policy, disc_acc_expert


def amp_style_prior_reward(
    env: ManagerBasedRLEnv,
    amp_enabled: bool = False,
    amp_model_path: str = "",
    amp_train_enabled: bool = False,
    amp_expert_features_path: str = "",
    disc_hidden_dim: int = 256,
    disc_hidden_layers: int = 2,
    disc_lr: float = 3e-4,
    disc_weight_decay: float = 1e-6,
    disc_update_interval: int = 4,
    disc_batch_size: int = 1024,
    disc_min_expert_samples: int = 2048,
    disc_history_steps: int = 4,
    feature_clip: float = 8.0,
    logit_scale: float = 1.0,
    amp_reward_gain: float = 1.0,
    internal_reward_scale: float = 1.0,
) -> torch.Tensor:
    """AMP style prior reward with optional online discriminator training.

    The per-env score is ``sigmoid(logit_scale * logits)`` of the discriminator
    applied to the flattened feature history; the reward is
    ``internal_reward_scale * amp_reward_gain * score``. When no discriminator
    is available the reward is all zeros.

    Args:
        env: Environment providing ``num_envs``, ``device``, ``extras`` and
            ``scene["robot"].data``.
        amp_enabled: Use a frozen TorchScript model from ``amp_model_path``
            (only consulted when ``amp_train_enabled`` is False).
        amp_model_path: Path of the TorchScript discriminator.
        amp_train_enabled: Train a discriminator online against expert features.
        amp_expert_features_path: File holding the expert feature bank.
        disc_hidden_dim: Hidden width of the trainable discriminator (min 16).
        disc_hidden_layers: Number of hidden layers (min 1).
        disc_lr: AdamW learning rate.
        disc_weight_decay: AdamW weight decay.
        disc_update_interval: Calls between discriminator updates (min 1).
        disc_batch_size: Max policy samples per update (min 32).
        disc_min_expert_samples: Lower bound (together with 512) on the rows of
            the tiled flat expert feature tensor.
        disc_history_steps: Number of stacked frames fed to the discriminator (min 1).
        feature_clip: Absolute clamp applied to the raw features.
        logit_scale: Multiplier applied to logits before the sigmoid.
        amp_reward_gain: Multiplier applied to the score.
        internal_reward_scale: Final multiplier applied to the reward.

    Returns:
        Reward tensor of shape ``[num_envs]``.

    Side effects:
        Caches state in ``env.extras`` and writes ``amp_*`` scalars to
        ``env.extras["log"]`` when that entry is a dict (or absent).
    """
    amp_score = torch.zeros(env.num_envs, device=env.device)
    model_loaded = 0.0
    amp_train_active = 0.0
    disc_loss = 0.0
    disc_acc_policy = 0.0
    disc_acc_expert = 0.0

    amp_features = _build_amp_features(env, feature_clip=feature_clip)
    amp_state = _get_amp_state(
        env=env,
        amp_enabled=amp_enabled,
        amp_model_path=amp_model_path,
        amp_train_enabled=amp_train_enabled,
        amp_expert_features_path=amp_expert_features_path,
        feature_dim=int(amp_features.shape[-1]),
        disc_hidden_dim=disc_hidden_dim,
        disc_hidden_layers=disc_hidden_layers,
        disc_lr=disc_lr,
        disc_weight_decay=disc_weight_decay,
        disc_min_expert_samples=disc_min_expert_samples,
        disc_history_steps=disc_history_steps,
    )
    discriminator = amp_state.get("model", None)
    history_steps = max(int(amp_state.get("disc_history_steps", disc_history_steps)), 1)
    policy_seq = _policy_sequence_features(env, amp_features, history_steps=history_steps)
    policy_seq_flat = policy_seq.reshape(policy_seq.shape[0], -1)

    # Stays None when no model is available, so the zero score is kept.
    logits = None
    if discriminator is not None:
        model_loaded = 1.0
        if amp_state.get("mode") == "trainable":
            amp_train_active = 1.0
            disc_loss, disc_acc_policy, disc_acc_expert = _update_discriminator(
                amp_state=amp_state,
                discriminator=discriminator,
                policy_seq_flat=policy_seq_flat,
                history_steps=history_steps,
                update_interval=max(int(disc_update_interval), 1),
                batch_size=max(int(disc_batch_size), 32),
                device=env.device,
            )
            discriminator.eval()
            with torch.no_grad():
                logits = discriminator(policy_seq_flat)
        else:
            # For external scripted models, try temporal then fallback to single frame.
            with torch.no_grad():
                if amp_state.get("jit_single_frame", False):
                    logits = discriminator(amp_features)
                else:
                    try:
                        logits = discriminator(policy_seq_flat)
                    except Exception as err:  # typically an input-size mismatch
                        _LOGGER.warning(
                            "AMP scripted model rejected the temporal input (%s); using single-frame features.",
                            err,
                        )
                        # Remember the outcome so the failing call is not repeated every step.
                        amp_state["jit_single_frame"] = True
                        logits = discriminator(amp_features)

    if logits is not None:
        if isinstance(logits, (tuple, list)):
            logits = logits[0]
        if logits.ndim > 1:
            logits = logits.squeeze(-1)
        logits = _safe_tensor(logits, nan=0.0, pos=20.0, neg=-20.0)
        amp_score = torch.sigmoid(logit_scale * logits)

    amp_score = _safe_tensor(amp_score, nan=0.0, pos=1.0, neg=0.0)
    amp_reward = _safe_tensor(amp_reward_gain * amp_score, nan=0.0, pos=10.0, neg=0.0)

    log_dict = env.extras.get("log", {})
    if isinstance(log_dict, dict):
        log_dict["amp_score_mean"] = torch.mean(amp_score).detach().item()
        log_dict["amp_reward_mean"] = torch.mean(amp_reward).detach().item()
        log_dict["amp_model_loaded_mean"] = model_loaded
        log_dict["amp_train_active_mean"] = amp_train_active
        log_dict["amp_disc_loss_mean"] = disc_loss
        log_dict["amp_disc_acc_policy_mean"] = disc_acc_policy
        log_dict["amp_disc_acc_expert_mean"] = disc_acc_expert
        log_dict["amp_disc_history_steps"] = float(history_steps)
        env.extras["log"] = log_dict

    return internal_reward_scale * amp_reward