import sys
import os
import argparse
import glob
import re

PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from isaaclab.app import AppLauncher

parser = argparse.ArgumentParser(description="Train T1 get-up policy.")
parser.add_argument("--num_envs", type=int, default=8192, help="Number of parallel environments")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument(
    "--amp_model",
    type=str,
    default="",
    help="TorchScript AMP discriminator path (.pt/.jit). Empty disables AMP reward.",
)
parser.add_argument(
    "--amp_reward_weight",
    type=float,
    default=0.6,
    help="Reward term weight for AMP style prior when --amp_model is provided.",
)
parser.add_argument(
    "--amp_expert_features",
    type=str,
    default="",
    help="Path to torch file containing expert AMP features (N, D) for online discriminator training.",
)
parser.add_argument(
    "--amp_train_discriminator",
    action="store_true",
    help="Enable online AMP discriminator updates using --amp_expert_features.",
)
parser.add_argument("--amp_disc_hidden_dim", type=int, default=256, help="Hidden width for AMP discriminator.")
parser.add_argument("--amp_disc_hidden_layers", type=int, default=2, help="Hidden layer count for AMP discriminator.")
parser.add_argument("--amp_disc_lr", type=float, default=3e-4, help="Learning rate for AMP discriminator.")
parser.add_argument("--amp_disc_weight_decay", type=float, default=1e-6, help="Weight decay for AMP discriminator.")
parser.add_argument("--amp_disc_update_interval", type=int, default=4, help="Train discriminator every N reward calls.")
parser.add_argument("--amp_disc_batch_size", type=int, default=1024, help="Discriminator train batch size.")
parser.add_argument("--amp_logit_scale", type=float, default=1.0, help="Scale before sigmoid(logits) for AMP score.")
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()

# Launch the Omniverse app first; the Isaac Lab / rl_games modules below must be imported afterwards.
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app

import gymnasium as gym
import yaml

from isaaclab_rl.rl_games import RlGamesVecEnvWrapper
from rl_games.common.algo_observer import DefaultAlgoObserver
from rl_games.torch_runner import Runner
from rl_games.common import env_configurations, vecenv

from rl_game.get_up.config.t1_env_cfg import T1EnvCfg


class T1MetricObserver(DefaultAlgoObserver):
    """Collect custom env metrics and print them to the terminal."""

    def __init__(self):
        super().__init__()
        self._tracked = (
            "upright_mean",
            "foot_support_ratio_mean",
            "arm_support_ratio_mean",
            "hip_roll_mean",
            "stand_core_mean",
            "amp_score_mean",
            "amp_reward_mean",
            "amp_model_loaded_mean",
            "amp_train_active_mean",
            "amp_disc_loss_mean",
            "amp_disc_acc_policy_mean",
            "amp_disc_acc_expert_mean",
        )
        self._metric_sums: dict[str, float] = {}
        self._metric_counts: dict[str, int] = {}

    @staticmethod
    def _to_float(value):
        """Best-effort conversion for scalars/tensors/arrays."""
        if value is None:
            return None
        # Handles torch tensors and numpy arrays without importing either package.
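        # Behavior sketch of the duck-typed cascade below (assuming torch-like semantics):
        #   single-element tensor -> item(); empty tensor -> None;
        #   multi-element tensor/array -> mean; plain numbers fall through to float(value).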
if hasattr(value, "detach"): value = value.detach() if hasattr(value, "cpu"): value = value.cpu() if hasattr(value, "numel") and callable(value.numel): if value.numel() == 0: return None if value.numel() == 1: if hasattr(value, "item"): return float(value.item()) return float(value) if hasattr(value, "float") and callable(value.float): return float(value.float().mean().item()) if hasattr(value, "mean") and callable(value.mean): try: return float(value.mean()) except Exception: pass try: return float(value) except Exception: return None def _collect_from_dict(self, data: dict): for key in self._tracked: val = self._to_float(data.get(key)) if val is None: continue self._metric_sums[key] = self._metric_sums.get(key, 0.0) + val self._metric_counts[key] = self._metric_counts.get(key, 0) + 1 def process_infos(self, infos, done_indices): # Keep default score handling. super().process_infos(infos, done_indices) if not infos: return if isinstance(infos, dict): self._collect_from_dict(infos) log = infos.get("log") if isinstance(log, dict): self._collect_from_dict(log) episode = infos.get("episode") if isinstance(episode, dict): self._collect_from_dict(episode) return if isinstance(infos, (list, tuple)): for item in infos: if not isinstance(item, dict): continue self._collect_from_dict(item) log = item.get("log") if isinstance(log, dict): self._collect_from_dict(log) episode = item.get("episode") if isinstance(episode, dict): self._collect_from_dict(episode) def after_print_stats(self, frame, epoch_num, total_time): super().after_print_stats(frame, epoch_num, total_time) parts = [] for key in self._tracked: count = self._metric_counts.get(key, 0) if count <= 0: continue mean_val = self._metric_sums[key] / count parts.append(f"{key}={mean_val:.4f}") if parts: print(f"[CUSTOM][epoch={epoch_num} frame={frame}] " + " ".join(parts)) self._metric_sums.clear() self._metric_counts.clear() def _parse_reward_from_last_ckpt(path: str) -> float: """Extract reward value from checkpoint name like '..._rew_123.45.pth'.""" match = re.search(r"_rew_(-?\d+(?:\.\d+)?)\.pth$", os.path.basename(path)) if match is None: return float("-inf") return float(match.group(1)) def _find_best_resume_checkpoint(log_dir: str, run_name: str) -> str | None: """Find previous best checkpoint across historical runs.""" run_dirs = sorted( [ p for p in glob.glob(os.path.join(log_dir, f"{run_name}_*")) if os.path.isdir(p) ], key=os.path.getmtime, reverse=True, ) # Priority 1: canonical best checkpoint from latest available run. for run_dir in run_dirs: best_ckpt = os.path.join(run_dir, "nn", f"{run_name}.pth") if os.path.exists(best_ckpt): return best_ckpt # Priority 2: best "last_*_rew_*.pth" among all runs (highest reward). 
    candidates: list[tuple[float, str]] = []
    for run_dir in run_dirs:
        pattern = os.path.join(run_dir, "nn", f"last_{run_name}_ep_*_rew_*.pth")
        for ckpt in glob.glob(pattern):
            candidates.append((_parse_reward_from_last_ckpt(ckpt), ckpt))
    if candidates:
        candidates.sort(key=lambda x: x[0], reverse=True)
        return candidates[0][1]
    return None


def main():
    task_id = "Isaac-T1-GetUp-v0"
    env_cfg = T1EnvCfg()

    amp_cfg = env_cfg.rewards.amp_style_prior
    amp_cfg.params["logit_scale"] = float(args_cli.amp_logit_scale)
    if args_cli.amp_train_discriminator:
        expert_path = (
            os.path.abspath(os.path.expanduser(args_cli.amp_expert_features))
            if args_cli.amp_expert_features
            else ""
        )
        amp_cfg.weight = float(args_cli.amp_reward_weight)
        amp_cfg.params["amp_train_enabled"] = True
        amp_cfg.params["amp_enabled"] = False
        amp_cfg.params["amp_expert_features_path"] = expert_path
        amp_cfg.params["disc_hidden_dim"] = int(args_cli.amp_disc_hidden_dim)
        amp_cfg.params["disc_hidden_layers"] = int(args_cli.amp_disc_hidden_layers)
        amp_cfg.params["disc_lr"] = float(args_cli.amp_disc_lr)
        amp_cfg.params["disc_weight_decay"] = float(args_cli.amp_disc_weight_decay)
        amp_cfg.params["disc_update_interval"] = int(args_cli.amp_disc_update_interval)
        amp_cfg.params["disc_batch_size"] = int(args_cli.amp_disc_batch_size)
        print(f"[INFO]: AMP online discriminator enabled, expert_features={expert_path}")
        print(f"[INFO]: AMP reward weight={amp_cfg.weight:.3f}")
    elif args_cli.amp_model:
        amp_model_path = os.path.abspath(os.path.expanduser(args_cli.amp_model))
        amp_cfg.weight = float(args_cli.amp_reward_weight)
        amp_cfg.params["amp_enabled"] = True
        amp_cfg.params["amp_model_path"] = amp_model_path
        amp_cfg.params["amp_train_enabled"] = False
        print(f"[INFO]: AMP inference enabled, discriminator={amp_model_path}")
        print(f"[INFO]: AMP reward weight={amp_cfg.weight:.3f}")
    else:
        print("[INFO]: AMP disabled (use --amp_model to enable)")

    if task_id not in gym.registry:
        gym.register(
            id=task_id,
            entry_point="isaaclab.envs:ManagerBasedRLEnv",
            kwargs={"cfg": env_cfg},
        )

    env = gym.make(task_id, num_envs=args_cli.num_envs, disable_env_checker=True)
    wrapped_env = RlGamesVecEnvWrapper(env, rl_device=args_cli.device, clip_obs=5.0, clip_actions=1.0)

    vecenv.register("as_is", lambda config_name, num_actors, **kwargs: wrapped_env)
    env_configurations.register("rlgym", {"vecenv_type": "as_is", "env_creator": lambda **kwargs: wrapped_env})

    config_path = os.path.join(os.path.dirname(__file__), "config", "ppo_cfg.yaml")
    with open(config_path, "r") as f:
        rl_config = yaml.safe_load(f)

    run_name = "T1_GetUp"
    log_dir = os.path.join(os.path.dirname(__file__), "logs")
    rl_config["params"]["config"]["train_dir"] = log_dir
    rl_config["params"]["config"]["name"] = run_name
    rl_config["params"]["config"]["env_name"] = "rlgym"
    # Propagate the CLI seed to rl_games; without this, --seed is parsed but never used.
    rl_config["params"]["seed"] = int(args_cli.seed)

    checkpoint_path = _find_best_resume_checkpoint(log_dir, run_name)
    if checkpoint_path is not None:
        print(f"[INFO]: resume from checkpoint: {checkpoint_path}")
        rl_config["params"]["config"]["load_path"] = checkpoint_path
    else:
        print("[INFO]: no checkpoint found, train from scratch")

    runner = Runner(algo_observer=T1MetricObserver())
    runner.load(rl_config)
    try:
        runner.run({"train": True, "play": False, "checkpoint": checkpoint_path, "vec_env": wrapped_env})
    finally:
        wrapped_env.close()
        simulation_app.close()


if __name__ == "__main__":
    main()
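# Usage sketch (illustrative: the actual script filename may differ from "train.py",
# and AMP checkpoint/feature paths are placeholders; --headless and --device come from
# AppLauncher.add_app_launcher_args):
#   python train.py --num_envs 4096 --headless
#   python train.py --amp_model runs/amp_disc.pt --amp_reward_weight 0.5 --headless
#   python train.py --amp_train_discriminator --amp_expert_features data/expert_features.pt --headless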