112 lines
3.8 KiB
Python
112 lines
3.8 KiB
Python
import sys
|
|
import os
|
|
import argparse
|
|
import glob
|
|
import re
|
|
|
|
# Put the repository root (two directories above this script) first on the
# import path so the local `rl_game` package can be imported.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
|
|
|
|
from isaaclab.app import AppLauncher
|
|
|
|
# Command-line interface: this script's own options plus the standard
# AppLauncher flags (args_cli.device, used further down, is expected to come
# from those flags).
parser = argparse.ArgumentParser(description="Train T1 get-up policy.")
parser.add_argument(
    "--num_envs",
    type=int,
    default=8192,
    help="Number of parallel environments",
)
parser.add_argument(
    "--seed",
    type=int,
    default=42,
    help="Random seed",
)
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()

# Launch the simulation app now; the remaining imports of this script are
# deliberately placed after this point (Isaac Lab expects the app to be running
# before simulator-dependent modules are imported).
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
|
|
|
|
import gymnasium as gym
|
|
import yaml
|
|
from isaaclab_rl.rl_games import RlGamesVecEnvWrapper
|
|
from rl_games.torch_runner import Runner
|
|
from rl_games.common import env_configurations, vecenv
|
|
|
|
from rl_game.get_up.config.t1_env_cfg import T1EnvCfg
|
|
|
|
|
|
def _parse_reward_from_last_ckpt(path: str) -> float:
|
|
"""Extract reward value from checkpoint name like '..._rew_123.45.pth'."""
|
|
match = re.search(r"_rew_(-?\d+(?:\.\d+)?)\.pth$", os.path.basename(path))
|
|
if match is None:
|
|
return float("-inf")
|
|
return float(match.group(1))
|
|
|
|
|
|
def _find_best_resume_checkpoint(log_dir: str, run_name: str) -> str | None:
|
|
"""Find previous best checkpoint across historical runs."""
|
|
run_dirs = sorted(
|
|
[
|
|
p
|
|
for p in glob.glob(os.path.join(log_dir, f"{run_name}_*"))
|
|
if os.path.isdir(p)
|
|
],
|
|
key=os.path.getmtime,
|
|
reverse=True,
|
|
)
|
|
|
|
# Priority 1: canonical best checkpoint from latest available run.
|
|
for run_dir in run_dirs:
|
|
best_ckpt = os.path.join(run_dir, "nn", f"{run_name}.pth")
|
|
if os.path.exists(best_ckpt):
|
|
return best_ckpt
|
|
|
|
# Priority 2: best "last_*_rew_*.pth" among all runs (highest reward).
|
|
candidates: list[tuple[float, str]] = []
|
|
for run_dir in run_dirs:
|
|
pattern = os.path.join(run_dir, "nn", f"last_{run_name}_ep_*_rew_*.pth")
|
|
for ckpt in glob.glob(pattern):
|
|
candidates.append((_parse_reward_from_last_ckpt(ckpt), ckpt))
|
|
if candidates:
|
|
candidates.sort(key=lambda x: x[0], reverse=True)
|
|
return candidates[0][1]
|
|
|
|
return None
|
|
|
|
|
|
def _build_env_cfg():
    """Return a ``T1EnvCfg`` with the ``--num_envs`` / ``--seed`` / ``--device`` CLI overrides applied."""
    # The overrides are written into the config itself. ManagerBasedRLEnv sizes
    # the scene from cfg.scene.num_envs, and a bare num_envs= keyword given to
    # gym.make is not applied to the scene. NOTE(review): assumes T1EnvCfg is a
    # ManagerBasedRLEnvCfg (it is used with the ManagerBasedRLEnv entry point);
    # confirm against the installed Isaac Lab version.
    env_cfg = T1EnvCfg()
    env_cfg.scene.num_envs = args_cli.num_envs
    env_cfg.seed = args_cli.seed
    if args_cli.device is not None:
        env_cfg.sim.device = args_cli.device
    return env_cfg


def _load_rl_config(log_dir: str, run_name: str) -> dict:
    """Load ``config/ppo_cfg.yaml`` and point it at this script's seed, log dir, run name and env."""
    config_path = os.path.join(os.path.dirname(__file__), "config", "ppo_cfg.yaml")
    with open(config_path, "r", encoding="utf-8") as f:
        rl_config = yaml.safe_load(f)

    params = rl_config["params"]
    # --seed controls the rl_games runner as well as the environment.
    params["seed"] = args_cli.seed
    params["config"]["train_dir"] = log_dir
    params["config"]["name"] = run_name
    # Must match the name registered with env_configurations in main().
    params["config"]["env_name"] = "rlgym"
    return rl_config


def main():
    """Train the T1 get-up policy with rl_games PPO inside Isaac Lab.

    Steps:

    1. Load the PPO config, writing logs to ``<script dir>/logs`` under the
       run name ``T1_GetUp``.
    2. Register the ``Isaac-T1-GetUp-v0`` gym task if needed, then create it
       with a ``T1EnvCfg`` carrying the CLI overrides.
    3. Wrap the env for rl_games, register it under ``as_is`` / ``rlgym``,
       and run training from scratch (auto-resume is disabled).

    The wrapped env and the simulation app are closed on exit, including when
    config loading, env creation or training raises.
    """
    task_id = "Isaac-T1-GetUp-v0"
    run_name = "T1_GetUp"
    log_dir = os.path.join(os.path.dirname(__file__), "logs")

    try:
        # Loaded first so a missing/invalid config fails before the slow env build.
        rl_config = _load_rl_config(log_dir, run_name)

        if task_id not in gym.registry:
            gym.register(
                id=task_id,
                entry_point="isaaclab.envs:ManagerBasedRLEnv",
                kwargs={"cfg": T1EnvCfg()},
            )
        # cfg= takes precedence over the cfg stored in the registration kwargs.
        env = gym.make(task_id, cfg=_build_env_cfg(), disable_env_checker=True)
        wrapped_env = RlGamesVecEnvWrapper(env, rl_device=args_cli.device, clip_obs=5.0, clip_actions=1.0)

        try:
            # Both factories hand back the already-wrapped env. "rlgym" is the
            # env_name written into the PPO config.
            vecenv.register("as_is", lambda config_name, num_actors, **kwargs: wrapped_env)
            env_configurations.register("rlgym", {"vecenv_type": "as_is", "env_creator": lambda **kwargs: wrapped_env})

            # Auto-resume is disabled. Set this to
            # _find_best_resume_checkpoint(log_dir, run_name) to continue from earlier runs.
            checkpoint_path: str | None = None
            if checkpoint_path is not None:
                print(f"[INFO]: resume from checkpoint: {checkpoint_path}")
                rl_config["params"]["config"]["load_path"] = checkpoint_path
            else:
                print("[INFO]: no checkpoint found, train from scratch")

            runner = Runner()
            runner.load(rl_config)
            runner.run({"train": True, "play": False, "checkpoint": checkpoint_path, "vec_env": wrapped_env})
        finally:
            wrapped_env.close()
    finally:
        simulation_app.close()
|
|
|
|
|
|
if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()