use amp strategy
This commit is contained in:
@@ -13,6 +13,36 @@ from isaaclab.app import AppLauncher
|
||||
parser = argparse.ArgumentParser(description="Train T1 get-up policy.")
|
||||
parser.add_argument("--num_envs", type=int, default=8192, help="Number of parallel environments")
|
||||
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--amp_model",
|
||||
type=str,
|
||||
default="",
|
||||
help="TorchScript AMP discriminator path (.pt/.jit). Empty disables AMP reward.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--amp_reward_weight",
|
||||
type=float,
|
||||
default=0.6,
|
||||
help="Reward term weight for AMP style prior when --amp_model is provided.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--amp_expert_features",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to torch file containing expert AMP features (N, D) for online discriminator training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--amp_train_discriminator",
|
||||
action="store_true",
|
||||
help="Enable online AMP discriminator updates using --amp_expert_features.",
|
||||
)
|
||||
parser.add_argument("--amp_disc_hidden_dim", type=int, default=256, help="Hidden width for AMP discriminator.")
|
||||
parser.add_argument("--amp_disc_hidden_layers", type=int, default=2, help="Hidden layer count for AMP discriminator.")
|
||||
parser.add_argument("--amp_disc_lr", type=float, default=3e-4, help="Learning rate for AMP discriminator.")
|
||||
parser.add_argument("--amp_disc_weight_decay", type=float, default=1e-6, help="Weight decay for AMP discriminator.")
|
||||
parser.add_argument("--amp_disc_update_interval", type=int, default=4, help="Train discriminator every N reward calls.")
|
||||
parser.add_argument("--amp_disc_batch_size", type=int, default=1024, help="Discriminator train batch size.")
|
||||
parser.add_argument("--amp_logit_scale", type=float, default=1.0, help="Scale before sigmoid(logits) for AMP score.")
|
||||
AppLauncher.add_app_launcher_args(parser)
|
||||
args_cli = parser.parse_args()
|
||||
|
||||
@@ -22,12 +52,113 @@ simulation_app = app_launcher.app
|
||||
import gymnasium as gym
|
||||
import yaml
|
||||
from isaaclab_rl.rl_games import RlGamesVecEnvWrapper
|
||||
from rl_games.common.algo_observer import DefaultAlgoObserver
|
||||
from rl_games.torch_runner import Runner
|
||||
from rl_games.common import env_configurations, vecenv
|
||||
|
||||
from rl_game.get_up.config.t1_env_cfg import T1EnvCfg
|
||||
|
||||
|
||||
class T1MetricObserver(DefaultAlgoObserver):
|
||||
"""Collect custom env metrics and print to terminal."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._tracked = (
|
||||
"upright_mean",
|
||||
"foot_support_ratio_mean",
|
||||
"arm_support_ratio_mean",
|
||||
"hip_roll_mean",
|
||||
"stand_core_mean",
|
||||
"amp_score_mean",
|
||||
"amp_reward_mean",
|
||||
"amp_model_loaded_mean",
|
||||
"amp_train_active_mean",
|
||||
"amp_disc_loss_mean",
|
||||
"amp_disc_acc_policy_mean",
|
||||
"amp_disc_acc_expert_mean",
|
||||
)
|
||||
self._metric_sums: dict[str, float] = {}
|
||||
self._metric_counts: dict[str, int] = {}
|
||||
|
||||
@staticmethod
|
||||
def _to_float(value):
|
||||
"""Best-effort conversion for scalars/tensors/arrays."""
|
||||
if value is None:
|
||||
return None
|
||||
# Handles torch tensors and numpy arrays without importing either package.
|
||||
if hasattr(value, "detach"):
|
||||
value = value.detach()
|
||||
if hasattr(value, "cpu"):
|
||||
value = value.cpu()
|
||||
if hasattr(value, "numel") and callable(value.numel):
|
||||
if value.numel() == 0:
|
||||
return None
|
||||
if value.numel() == 1:
|
||||
if hasattr(value, "item"):
|
||||
return float(value.item())
|
||||
return float(value)
|
||||
if hasattr(value, "float") and callable(value.float):
|
||||
return float(value.float().mean().item())
|
||||
if hasattr(value, "mean") and callable(value.mean):
|
||||
try:
|
||||
return float(value.mean())
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _collect_from_dict(self, data: dict):
|
||||
for key in self._tracked:
|
||||
val = self._to_float(data.get(key))
|
||||
if val is None:
|
||||
continue
|
||||
self._metric_sums[key] = self._metric_sums.get(key, 0.0) + val
|
||||
self._metric_counts[key] = self._metric_counts.get(key, 0) + 1
|
||||
|
||||
def process_infos(self, infos, done_indices):
|
||||
# Keep default score handling.
|
||||
super().process_infos(infos, done_indices)
|
||||
if not infos:
|
||||
return
|
||||
if isinstance(infos, dict):
|
||||
self._collect_from_dict(infos)
|
||||
log = infos.get("log")
|
||||
if isinstance(log, dict):
|
||||
self._collect_from_dict(log)
|
||||
episode = infos.get("episode")
|
||||
if isinstance(episode, dict):
|
||||
self._collect_from_dict(episode)
|
||||
return
|
||||
if isinstance(infos, (list, tuple)):
|
||||
for item in infos:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
self._collect_from_dict(item)
|
||||
log = item.get("log")
|
||||
if isinstance(log, dict):
|
||||
self._collect_from_dict(log)
|
||||
episode = item.get("episode")
|
||||
if isinstance(episode, dict):
|
||||
self._collect_from_dict(episode)
|
||||
|
||||
def after_print_stats(self, frame, epoch_num, total_time):
|
||||
super().after_print_stats(frame, epoch_num, total_time)
|
||||
parts = []
|
||||
for key in self._tracked:
|
||||
count = self._metric_counts.get(key, 0)
|
||||
if count <= 0:
|
||||
continue
|
||||
mean_val = self._metric_sums[key] / count
|
||||
parts.append(f"{key}={mean_val:.4f}")
|
||||
if parts:
|
||||
print(f"[CUSTOM][epoch={epoch_num} frame={frame}] " + " ".join(parts))
|
||||
self._metric_sums.clear()
|
||||
self._metric_counts.clear()
|
||||
|
||||
|
||||
def _parse_reward_from_last_ckpt(path: str) -> float:
|
||||
"""Extract reward value from checkpoint name like '..._rew_123.45.pth'."""
|
||||
match = re.search(r"_rew_(-?\d+(?:\.\d+)?)\.pth$", os.path.basename(path))
|
||||
@@ -69,11 +200,40 @@ def _find_best_resume_checkpoint(log_dir: str, run_name: str) -> str | None:
|
||||
|
||||
def main():
|
||||
task_id = "Isaac-T1-GetUp-v0"
|
||||
env_cfg = T1EnvCfg()
|
||||
|
||||
amp_cfg = env_cfg.rewards.amp_style_prior
|
||||
amp_cfg.params["logit_scale"] = float(args_cli.amp_logit_scale)
|
||||
if args_cli.amp_train_discriminator:
|
||||
expert_path = os.path.abspath(os.path.expanduser(args_cli.amp_expert_features)) if args_cli.amp_expert_features else ""
|
||||
amp_cfg.weight = float(args_cli.amp_reward_weight)
|
||||
amp_cfg.params["amp_train_enabled"] = True
|
||||
amp_cfg.params["amp_enabled"] = False
|
||||
amp_cfg.params["amp_expert_features_path"] = expert_path
|
||||
amp_cfg.params["disc_hidden_dim"] = int(args_cli.amp_disc_hidden_dim)
|
||||
amp_cfg.params["disc_hidden_layers"] = int(args_cli.amp_disc_hidden_layers)
|
||||
amp_cfg.params["disc_lr"] = float(args_cli.amp_disc_lr)
|
||||
amp_cfg.params["disc_weight_decay"] = float(args_cli.amp_disc_weight_decay)
|
||||
amp_cfg.params["disc_update_interval"] = int(args_cli.amp_disc_update_interval)
|
||||
amp_cfg.params["disc_batch_size"] = int(args_cli.amp_disc_batch_size)
|
||||
print(f"[INFO]: AMP online discriminator enabled, expert_features={expert_path or '<missing>'}")
|
||||
print(f"[INFO]: AMP reward weight={amp_cfg.weight:.3f}")
|
||||
elif args_cli.amp_model:
|
||||
amp_model_path = os.path.abspath(os.path.expanduser(args_cli.amp_model))
|
||||
amp_cfg.weight = float(args_cli.amp_reward_weight)
|
||||
amp_cfg.params["amp_enabled"] = True
|
||||
amp_cfg.params["amp_model_path"] = amp_model_path
|
||||
amp_cfg.params["amp_train_enabled"] = False
|
||||
print(f"[INFO]: AMP inference enabled, discriminator={amp_model_path}")
|
||||
print(f"[INFO]: AMP reward weight={amp_cfg.weight:.3f}")
|
||||
else:
|
||||
print("[INFO]: AMP disabled (use --amp_model to enable)")
|
||||
|
||||
if task_id not in gym.registry:
|
||||
gym.register(
|
||||
id=task_id,
|
||||
entry_point="isaaclab.envs:ManagerBasedRLEnv",
|
||||
kwargs={"cfg": T1EnvCfg()},
|
||||
kwargs={"cfg": env_cfg},
|
||||
)
|
||||
|
||||
env = gym.make(task_id, num_envs=args_cli.num_envs, disable_env_checker=True)
|
||||
@@ -92,14 +252,14 @@ def main():
|
||||
rl_config["params"]["config"]["name"] = run_name
|
||||
rl_config["params"]["config"]["env_name"] = "rlgym"
|
||||
|
||||
checkpoint_path = None #_find_best_resume_checkpoint(log_dir, run_name)
|
||||
checkpoint_path = _find_best_resume_checkpoint(log_dir, run_name)
|
||||
if checkpoint_path is not None:
|
||||
print(f"[INFO]: resume from checkpoint: {checkpoint_path}")
|
||||
rl_config["params"]["config"]["load_path"] = checkpoint_path
|
||||
else:
|
||||
print("[INFO]: no checkpoint found, train from scratch")
|
||||
|
||||
runner = Runner()
|
||||
runner = Runner(algo_observer=T1MetricObserver())
|
||||
runner.load(rl_config)
|
||||
try:
|
||||
runner.run({"train": True, "play": False, "checkpoint": checkpoint_path, "vec_env": wrapped_env})
|
||||
|
||||
Reference in New Issue
Block a user