Add AMP get-up pipeline with sequence discriminator and git-sourced expert data
This commit is contained in:
62
rl_game/get_up/amp/README.md
Normal file
62
rl_game/get_up/amp/README.md
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# AMP Tools for `get_up`
|
||||||
|
|
||||||
|
This folder contains all AMP-related code for `rl_game/get_up`.
|
||||||
|
|
||||||
|
## Files
|
||||||
|
|
||||||
|
- `amp_rewards.py`: AMP discriminator + reward function used by training config.
|
||||||
|
- `amp_motion.py`: Build AMP expert features from local get-up keyframe YAML files.
|
||||||
|
- `migrate_legged_lab_expert_template.py`: Template converter for migrating external expert data (for example legged_lab outputs) to `expert_features.pt`.
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
Generate expert features from current local keyframes:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python rl_game/get_up/train.py --amp_from_keyframes --headless
|
||||||
|
```
|
||||||
|
|
||||||
|
Convert external motion/expert file to AMP template:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python rl_game/get_up/amp/migrate_legged_lab_expert_template.py \
|
||||||
|
--input /path/to/source_data.pt \
|
||||||
|
--output rl_game/get_up/amp/expert_features.pt \
|
||||||
|
--input_key expert_features \
|
||||||
|
--feature_dim 55 \
|
||||||
|
--repeat 4
|
||||||
|
```
|
||||||
|
|
||||||
|
Convert downloaded `legged_lab` motion pickles directly:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python rl_game/get_up/amp/migrate_legged_lab_expert_template.py \
|
||||||
|
--input third_party/legged_lab/source/legged_lab/legged_lab/data/MotionData/g1_29dof/amp/walk_and_run \
|
||||||
|
--input_glob "*.pkl" \
|
||||||
|
--target_dof 23 \
|
||||||
|
--feature_dim 55 \
|
||||||
|
--clip_weight_mode uniform \
|
||||||
|
--output rl_game/get_up/amp/expert_features.pt
|
||||||
|
```
|
||||||
|
|
||||||
|
For get-up data from gitee legged_lab, use git-like focused clip sampling:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python rl_game/get_up/amp/migrate_legged_lab_expert_template.py \
|
||||||
|
--input third_party/legged_lab_gitee/source/legged_lab/legged_lab/data/MotionData/g1_29dof/amp/get_up \
|
||||||
|
--input_glob "*.pkl" \
|
||||||
|
--target_dof 23 \
|
||||||
|
--feature_dim 55 \
|
||||||
|
--clip_weight_mode git_getup_focus \
|
||||||
|
--output rl_game/get_up/amp/expert_features.pt
|
||||||
|
```
|
||||||
|
|
||||||
|
Then train with online AMP discriminator:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python rl_game/get_up/train.py \
|
||||||
|
--amp_train_discriminator \
|
||||||
|
--amp_expert_features rl_game/get_up/amp/expert_features.pt \
|
||||||
|
--amp_reward_weight 0.6 \
|
||||||
|
--headless
|
||||||
|
```
|
||||||
7
rl_game/get_up/amp/__init__.py
Normal file
7
rl_game/get_up/amp/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from .amp_motion import build_amp_expert_features_from_getup_keyframes
|
||||||
|
from .amp_rewards import amp_style_prior_reward
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"amp_style_prior_reward",
|
||||||
|
"build_amp_expert_features_from_getup_keyframes",
|
||||||
|
]
|
||||||
186
rl_game/get_up/amp/amp_motion.py
Normal file
186
rl_game/get_up/amp/amp_motion.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
RIGHT_JOINT_SIGN_FLIP = {
|
||||||
|
"Shoulder_Roll",
|
||||||
|
"Elbow_Yaw",
|
||||||
|
"Hip_Roll",
|
||||||
|
"Hip_Yaw",
|
||||||
|
"Ankle_Roll",
|
||||||
|
}
|
||||||
|
|
||||||
|
JOINT_NAME_ALIAS = {
|
||||||
|
"AAHead_yaw": "Head_yaw",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_load_yaml(path: Path) -> dict[str, Any]:
|
||||||
|
with path.open("r", encoding="utf-8") as f:
|
||||||
|
payload = yaml.safe_load(f) or {}
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
raise ValueError(f"Invalid keyframe yaml: {path}")
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _build_interpolated_motor_table(
|
||||||
|
keyframes: list[dict[str, Any]],
|
||||||
|
sample_dt: float,
|
||||||
|
) -> tuple[torch.Tensor, list[str], torch.Tensor]:
|
||||||
|
if len(keyframes) == 0:
|
||||||
|
raise ValueError("No keyframes in yaml")
|
||||||
|
|
||||||
|
durations: list[float] = []
|
||||||
|
motor_names: set[str] = set()
|
||||||
|
for frame in keyframes:
|
||||||
|
durations.append(float(frame.get("delta", 0.1)))
|
||||||
|
motors = frame.get("motor_positions", {})
|
||||||
|
if isinstance(motors, dict):
|
||||||
|
motor_names.update(motors.keys())
|
||||||
|
names = sorted(motor_names)
|
||||||
|
if len(names) == 0:
|
||||||
|
raise ValueError("No motor_positions found in keyframes")
|
||||||
|
|
||||||
|
frame_count = len(keyframes)
|
||||||
|
motor_count = len(names)
|
||||||
|
times = torch.zeros(frame_count, dtype=torch.float32)
|
||||||
|
values = torch.full((frame_count, motor_count), float("nan"), dtype=torch.float32)
|
||||||
|
|
||||||
|
elapsed = 0.0
|
||||||
|
name_to_idx = {name: idx for idx, name in enumerate(names)}
|
||||||
|
for i, frame in enumerate(keyframes):
|
||||||
|
elapsed += max(float(frame.get("delta", 0.1)), 1e-4)
|
||||||
|
times[i] = elapsed
|
||||||
|
motors = frame.get("motor_positions", {})
|
||||||
|
if not isinstance(motors, dict):
|
||||||
|
continue
|
||||||
|
for motor_name, motor_deg in motors.items():
|
||||||
|
if motor_name not in name_to_idx:
|
||||||
|
continue
|
||||||
|
values[i, name_to_idx[motor_name]] = float(motor_deg) * (torch.pi / 180.0)
|
||||||
|
|
||||||
|
first_valid = torch.nan_to_num(values[0], nan=0.0)
|
||||||
|
values[0] = first_valid
|
||||||
|
for i in range(1, frame_count):
|
||||||
|
cur = values[i]
|
||||||
|
prev = values[i - 1]
|
||||||
|
values[i] = torch.where(torch.isnan(cur), prev, cur)
|
||||||
|
|
||||||
|
sample_dt = max(float(sample_dt), 1e-3)
|
||||||
|
sample_times = torch.arange(0.0, float(times[-1]) + 1e-6, sample_dt, dtype=torch.float32)
|
||||||
|
sample_times[0] = max(sample_times[0], 1e-6)
|
||||||
|
sample_times = torch.clamp(sample_times, min=times[0], max=times[-1])
|
||||||
|
|
||||||
|
upper = torch.searchsorted(times, sample_times, right=True)
|
||||||
|
upper = torch.clamp(upper, min=1, max=frame_count - 1)
|
||||||
|
lower = upper - 1
|
||||||
|
|
||||||
|
t0 = times.index_select(0, lower)
|
||||||
|
t1 = times.index_select(0, upper)
|
||||||
|
v0 = values.index_select(0, lower)
|
||||||
|
v1 = values.index_select(0, upper)
|
||||||
|
alpha = ((sample_times - t0) / (t1 - t0 + 1e-6)).unsqueeze(-1)
|
||||||
|
interpolated = v0 + alpha * (v1 - v0)
|
||||||
|
|
||||||
|
return interpolated, names, sample_times
|
||||||
|
|
||||||
|
|
||||||
|
def _map_motors_to_joint_targets(
|
||||||
|
motor_table: torch.Tensor,
|
||||||
|
motor_names: list[str],
|
||||||
|
joint_names: list[str],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
motor_to_idx = {name: idx for idx, name in enumerate(motor_names)}
|
||||||
|
joint_targets = torch.zeros((motor_table.shape[0], len(joint_names)), dtype=torch.float32)
|
||||||
|
|
||||||
|
for joint_idx, joint_name in enumerate(joint_names):
|
||||||
|
alias_name = JOINT_NAME_ALIAS.get(joint_name, joint_name)
|
||||||
|
base_name = alias_name.replace("Left_", "").replace("Right_", "")
|
||||||
|
source_name = None
|
||||||
|
if alias_name in motor_to_idx:
|
||||||
|
source_name = alias_name
|
||||||
|
elif base_name in motor_to_idx:
|
||||||
|
source_name = base_name
|
||||||
|
if source_name is None:
|
||||||
|
continue
|
||||||
|
sign = 1.0
|
||||||
|
if alias_name.startswith("Right_") and base_name in RIGHT_JOINT_SIGN_FLIP:
|
||||||
|
sign = -1.0
|
||||||
|
joint_targets[:, joint_idx] = sign * motor_table[:, motor_to_idx[source_name]]
|
||||||
|
|
||||||
|
return joint_targets
|
||||||
|
|
||||||
|
|
||||||
|
def _features_from_keyframes(
|
||||||
|
keyframe_yaml_path: str,
|
||||||
|
joint_names: list[str],
|
||||||
|
sample_dt: float,
|
||||||
|
repeat_count: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
payload = _safe_load_yaml(Path(keyframe_yaml_path).expanduser().resolve())
|
||||||
|
keyframes = payload.get("keyframes", [])
|
||||||
|
if not isinstance(keyframes, list):
|
||||||
|
raise ValueError(f"Invalid keyframe list in {keyframe_yaml_path}")
|
||||||
|
|
||||||
|
motor_table, motor_names, _ = _build_interpolated_motor_table(keyframes, sample_dt=sample_dt)
|
||||||
|
joint_pos = _map_motors_to_joint_targets(motor_table, motor_names, joint_names)
|
||||||
|
joint_vel = torch.zeros_like(joint_pos)
|
||||||
|
if joint_pos.shape[0] > 1:
|
||||||
|
joint_vel[1:] = (joint_pos[1:] - joint_pos[:-1]) / max(sample_dt, 1e-3)
|
||||||
|
joint_vel[0] = joint_vel[1]
|
||||||
|
root_lin = torch.zeros((joint_pos.shape[0], 3), dtype=torch.float32)
|
||||||
|
root_ang = torch.zeros((joint_pos.shape[0], 3), dtype=torch.float32)
|
||||||
|
projected_gravity = torch.zeros((joint_pos.shape[0], 3), dtype=torch.float32)
|
||||||
|
projected_gravity[:, 2] = -1.0
|
||||||
|
features = torch.cat([joint_pos, joint_vel, root_lin, root_ang, projected_gravity], dim=-1)
|
||||||
|
repeat_count = max(int(repeat_count), 1)
|
||||||
|
if repeat_count > 1:
|
||||||
|
features = features.repeat(repeat_count, 1)
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def build_amp_expert_features_from_getup_keyframes(
|
||||||
|
*,
|
||||||
|
front_yaml_path: str,
|
||||||
|
back_yaml_path: str,
|
||||||
|
joint_names: list[str],
|
||||||
|
output_path: str,
|
||||||
|
sample_dt: float = 0.04,
|
||||||
|
repeat_count: int = 16,
|
||||||
|
) -> tuple[str, tuple[int, int]]:
|
||||||
|
"""Generate AMP expert feature tensor from front/back get-up keyframe yaml files."""
|
||||||
|
front_features = _features_from_keyframes(
|
||||||
|
keyframe_yaml_path=front_yaml_path,
|
||||||
|
joint_names=joint_names,
|
||||||
|
sample_dt=sample_dt,
|
||||||
|
repeat_count=repeat_count,
|
||||||
|
)
|
||||||
|
back_features = _features_from_keyframes(
|
||||||
|
keyframe_yaml_path=back_yaml_path,
|
||||||
|
joint_names=joint_names,
|
||||||
|
sample_dt=sample_dt,
|
||||||
|
repeat_count=repeat_count,
|
||||||
|
)
|
||||||
|
expert_features = torch.cat([front_features, back_features], dim=0).contiguous()
|
||||||
|
|
||||||
|
out_path = Path(output_path).expanduser().resolve()
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
torch.save(
|
||||||
|
{
|
||||||
|
"expert_features": expert_features,
|
||||||
|
"meta": {
|
||||||
|
"front_yaml_path": str(Path(front_yaml_path).expanduser().resolve()),
|
||||||
|
"back_yaml_path": str(Path(back_yaml_path).expanduser().resolve()),
|
||||||
|
"sample_dt": float(sample_dt),
|
||||||
|
"repeat_count": int(repeat_count),
|
||||||
|
"feature_dim": int(expert_features.shape[-1]),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
str(out_path),
|
||||||
|
)
|
||||||
|
return str(out_path), (int(expert_features.shape[0]), int(expert_features.shape[1]))
|
||||||
457
rl_game/get_up/amp/amp_rewards.py
Normal file
457
rl_game/get_up/amp/amp_rewards.py
Normal file
@@ -0,0 +1,457 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import joblib
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from isaaclab.envs import ManagerBasedRLEnv
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_tensor(x: torch.Tensor, nan: float = 0.0, pos: float = 1e3, neg: float = -1e3) -> torch.Tensor:
|
||||||
|
return torch.nan_to_num(x, nan=nan, posinf=pos, neginf=neg)
|
||||||
|
|
||||||
|
|
||||||
|
class AMPDiscriminator(nn.Module):
|
||||||
|
"""Lightweight discriminator used by online AMP updates."""
|
||||||
|
|
||||||
|
def __init__(self, input_dim: int, hidden_dims: tuple[int, ...]):
|
||||||
|
super().__init__()
|
||||||
|
layers: list[nn.Module] = []
|
||||||
|
in_dim = input_dim
|
||||||
|
for h_dim in hidden_dims:
|
||||||
|
layers.append(nn.Linear(in_dim, h_dim))
|
||||||
|
layers.append(nn.LayerNorm(h_dim))
|
||||||
|
layers.append(nn.SiLU())
|
||||||
|
in_dim = h_dim
|
||||||
|
layers.append(nn.Linear(in_dim, 1))
|
||||||
|
self.net = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_tensor_2d(value) -> torch.Tensor | None:
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
t = value.float()
|
||||||
|
elif isinstance(value, np.ndarray):
|
||||||
|
t = torch.as_tensor(value, dtype=torch.float32)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
try:
|
||||||
|
t = torch.as_tensor(value, dtype=torch.float32)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
if t.ndim == 1:
|
||||||
|
t = t.unsqueeze(0)
|
||||||
|
if t.ndim != 2:
|
||||||
|
return None
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_feature_dim(x: torch.Tensor, feature_dim: int) -> torch.Tensor:
|
||||||
|
if x.shape[-1] == feature_dim:
|
||||||
|
return x
|
||||||
|
if x.shape[-1] > feature_dim:
|
||||||
|
return x[:, :feature_dim]
|
||||||
|
pad = torch.zeros((x.shape[0], feature_dim - x.shape[1]), dtype=x.dtype)
|
||||||
|
return torch.cat([x, pad], dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_expert_bank_from_payload(payload, feature_dim: int) -> dict | None:
|
||||||
|
clip_tensors: list[torch.Tensor] = []
|
||||||
|
clip_names: list[str] = []
|
||||||
|
clip_weights_tensor: torch.Tensor | None = None
|
||||||
|
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
clip_values = payload.get("expert_clips", None)
|
||||||
|
if isinstance(clip_values, list):
|
||||||
|
for i, clip_value in enumerate(clip_values):
|
||||||
|
clip_tensor = _to_tensor_2d(clip_value)
|
||||||
|
if clip_tensor is None:
|
||||||
|
continue
|
||||||
|
clip_tensors.append(_normalize_feature_dim(clip_tensor, feature_dim))
|
||||||
|
clip_names.append(f"clip_{i}")
|
||||||
|
raw_names = payload.get("clip_names", None)
|
||||||
|
if isinstance(raw_names, list) and len(raw_names) == len(clip_tensors):
|
||||||
|
clip_names = [str(n) for n in raw_names]
|
||||||
|
raw_weights = payload.get("clip_weights", None)
|
||||||
|
clip_weights_tensor = _to_tensor_2d(raw_weights) if raw_weights is not None else None
|
||||||
|
if clip_weights_tensor is not None:
|
||||||
|
clip_weights_tensor = clip_weights_tensor.reshape(-1)
|
||||||
|
if clip_weights_tensor.shape[0] != len(clip_tensors):
|
||||||
|
clip_weights_tensor = None
|
||||||
|
|
||||||
|
for key in ("expert_features", "features", "obs"):
|
||||||
|
value = payload.get(key, None)
|
||||||
|
tensor = _to_tensor_2d(value)
|
||||||
|
if tensor is not None:
|
||||||
|
clip_tensors.append(_normalize_feature_dim(tensor, feature_dim))
|
||||||
|
clip_names.append(key)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
tensor = _to_tensor_2d(payload)
|
||||||
|
if tensor is not None:
|
||||||
|
clip_tensors.append(_normalize_feature_dim(tensor, feature_dim))
|
||||||
|
clip_names.append("expert_features")
|
||||||
|
|
||||||
|
clip_tensors = [c for c in clip_tensors if c.shape[0] >= 2]
|
||||||
|
if len(clip_tensors) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if clip_weights_tensor is None:
|
||||||
|
clip_weights_tensor = torch.ones(len(clip_tensors), dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
clip_weights_tensor = torch.clamp(clip_weights_tensor.float(), min=0.0)
|
||||||
|
if float(torch.sum(clip_weights_tensor).item()) <= 0.0:
|
||||||
|
clip_weights_tensor = torch.ones(len(clip_tensors), dtype=torch.float32)
|
||||||
|
|
||||||
|
flat_features = torch.cat(clip_tensors, dim=0)
|
||||||
|
return {
|
||||||
|
"flat_features": flat_features,
|
||||||
|
"clips": clip_tensors,
|
||||||
|
"clip_names": clip_names,
|
||||||
|
"clip_weights": clip_weights_tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _load_amp_expert_features(
|
||||||
|
expert_features_path: str,
|
||||||
|
device: str,
|
||||||
|
feature_dim: int,
|
||||||
|
fallback_samples: int,
|
||||||
|
) -> dict | None:
|
||||||
|
"""Load expert AMP features bank. Returns None when file is unavailable."""
|
||||||
|
if not expert_features_path:
|
||||||
|
return None
|
||||||
|
p = Path(expert_features_path).expanduser()
|
||||||
|
if not p.is_file():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
payload = torch.load(str(p), map_location="cpu")
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
with p.open("rb") as f:
|
||||||
|
payload = pickle.load(f)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
payload = joblib.load(str(p))
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
bank = _extract_expert_bank_from_payload(payload, feature_dim=feature_dim)
|
||||||
|
if bank is None:
|
||||||
|
return None
|
||||||
|
flat_features = bank["flat_features"].float()
|
||||||
|
if flat_features.shape[0] < fallback_samples:
|
||||||
|
reps = int((fallback_samples + flat_features.shape[0] - 1) // flat_features.shape[0])
|
||||||
|
flat_features = flat_features.repeat(reps, 1)
|
||||||
|
|
||||||
|
clip_tensors = [clip.float() for clip in bank["clips"]]
|
||||||
|
clip_weights = bank["clip_weights"].float()
|
||||||
|
clip_weights = torch.clamp(clip_weights, min=0.0)
|
||||||
|
if float(torch.sum(clip_weights).item()) <= 0.0:
|
||||||
|
clip_weights = torch.ones_like(clip_weights)
|
||||||
|
clip_weights = clip_weights / torch.sum(clip_weights)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"flat_features": flat_features.to(device=device),
|
||||||
|
"clips": [c.to(device=device) for c in clip_tensors],
|
||||||
|
"clip_names": bank["clip_names"],
|
||||||
|
"clip_weights": clip_weights.to(device=device),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _policy_sequence_features(
|
||||||
|
env: ManagerBasedRLEnv,
|
||||||
|
current_features: torch.Tensor,
|
||||||
|
history_steps: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Build rolling policy sequence features [N, H, D]."""
|
||||||
|
history_steps = max(int(history_steps), 1)
|
||||||
|
cache_key = "amp_policy_hist_cache"
|
||||||
|
cache = env.extras.get(cache_key, None)
|
||||||
|
if not isinstance(cache, dict):
|
||||||
|
cache = {}
|
||||||
|
env.extras[cache_key] = cache
|
||||||
|
|
||||||
|
hist = cache.get("hist", None)
|
||||||
|
if (
|
||||||
|
not isinstance(hist, torch.Tensor)
|
||||||
|
or hist.shape[0] != env.num_envs
|
||||||
|
or hist.shape[1] != history_steps
|
||||||
|
or hist.shape[2] != current_features.shape[1]
|
||||||
|
):
|
||||||
|
hist = current_features.unsqueeze(1).repeat(1, history_steps, 1)
|
||||||
|
|
||||||
|
if history_steps > 1:
|
||||||
|
hist[:, :-1] = hist[:, 1:].clone()
|
||||||
|
hist[:, -1] = current_features
|
||||||
|
cache["hist"] = hist
|
||||||
|
env.extras[cache_key] = cache
|
||||||
|
return hist
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_expert_sequence_batch(
|
||||||
|
expert_bank: dict,
|
||||||
|
batch_size: int,
|
||||||
|
history_steps: int,
|
||||||
|
device: str,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Sample expert sequence batch [B, H, D] with clip-weighted sampling."""
|
||||||
|
clips = expert_bank["clips"]
|
||||||
|
clip_weights = expert_bank["clip_weights"]
|
||||||
|
clip_count = len(clips)
|
||||||
|
if clip_count <= 0:
|
||||||
|
return torch.empty((0, history_steps, 0), device=device)
|
||||||
|
|
||||||
|
clip_ids = torch.multinomial(clip_weights, num_samples=batch_size, replacement=True)
|
||||||
|
seq_list: list[torch.Tensor] = []
|
||||||
|
for clip_id in clip_ids.tolist():
|
||||||
|
clip = clips[int(clip_id)]
|
||||||
|
clip_len = int(clip.shape[0])
|
||||||
|
if clip_len >= history_steps:
|
||||||
|
max_start = clip_len - history_steps
|
||||||
|
if max_start > 0:
|
||||||
|
start = int(torch.randint(0, max_start + 1, (1,), device=device).item())
|
||||||
|
else:
|
||||||
|
start = 0
|
||||||
|
seq = clip[start : start + history_steps]
|
||||||
|
else:
|
||||||
|
pad = clip[-1:].repeat(history_steps - clip_len, 1)
|
||||||
|
seq = torch.cat([clip, pad], dim=0)
|
||||||
|
seq_list.append(seq)
|
||||||
|
return torch.stack(seq_list, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_amp_state(
|
||||||
|
env: ManagerBasedRLEnv,
|
||||||
|
amp_enabled: bool,
|
||||||
|
amp_model_path: str,
|
||||||
|
amp_train_enabled: bool,
|
||||||
|
amp_expert_features_path: str,
|
||||||
|
feature_dim: int,
|
||||||
|
disc_hidden_dim: int,
|
||||||
|
disc_hidden_layers: int,
|
||||||
|
disc_lr: float,
|
||||||
|
disc_weight_decay: float,
|
||||||
|
disc_min_expert_samples: int,
|
||||||
|
disc_history_steps: int,
|
||||||
|
):
|
||||||
|
"""Get cached AMP state (frozen jit or trainable discriminator)."""
|
||||||
|
cache_key = "amp_state_cache"
|
||||||
|
hidden_layers = max(int(disc_hidden_layers), 1)
|
||||||
|
hidden_dim = max(int(disc_hidden_dim), 16)
|
||||||
|
history_steps = max(int(disc_history_steps), 1)
|
||||||
|
state_sig = (
|
||||||
|
bool(amp_enabled),
|
||||||
|
str(amp_model_path),
|
||||||
|
bool(amp_train_enabled),
|
||||||
|
str(amp_expert_features_path),
|
||||||
|
int(feature_dim),
|
||||||
|
history_steps,
|
||||||
|
hidden_dim,
|
||||||
|
hidden_layers,
|
||||||
|
float(disc_lr),
|
||||||
|
float(disc_weight_decay),
|
||||||
|
)
|
||||||
|
cached = env.extras.get(cache_key, None)
|
||||||
|
if isinstance(cached, dict) and cached.get("sig") == state_sig:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
state = {
|
||||||
|
"sig": state_sig,
|
||||||
|
"mode": "disabled",
|
||||||
|
"model": None,
|
||||||
|
"optimizer": None,
|
||||||
|
"expert_bank": None,
|
||||||
|
"disc_history_steps": history_steps,
|
||||||
|
"step": 0,
|
||||||
|
"last_loss": 0.0,
|
||||||
|
"last_acc_policy": 0.0,
|
||||||
|
"last_acc_expert": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if amp_train_enabled:
|
||||||
|
expert_bank = _load_amp_expert_features(
|
||||||
|
amp_expert_features_path,
|
||||||
|
device=env.device,
|
||||||
|
feature_dim=feature_dim,
|
||||||
|
fallback_samples=max(disc_min_expert_samples, 512),
|
||||||
|
)
|
||||||
|
if expert_bank is not None:
|
||||||
|
model = AMPDiscriminator(input_dim=feature_dim * history_steps, hidden_dims=tuple([hidden_dim] * hidden_layers)).to(env.device)
|
||||||
|
optimizer = torch.optim.AdamW(model.parameters(), lr=float(disc_lr), weight_decay=float(disc_weight_decay))
|
||||||
|
state["mode"] = "trainable"
|
||||||
|
state["model"] = model
|
||||||
|
state["optimizer"] = optimizer
|
||||||
|
state["expert_bank"] = expert_bank
|
||||||
|
elif amp_enabled and amp_model_path:
|
||||||
|
model_path = Path(amp_model_path).expanduser()
|
||||||
|
if model_path.is_file():
|
||||||
|
try:
|
||||||
|
model = torch.jit.load(str(model_path), map_location=env.device)
|
||||||
|
model.eval()
|
||||||
|
state["mode"] = "jit"
|
||||||
|
state["model"] = model
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
env.extras[cache_key] = state
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def _build_amp_features(env: ManagerBasedRLEnv, feature_clip: float = 8.0) -> torch.Tensor:
|
||||||
|
"""Build AMP-style discriminator features from robot kinematics."""
|
||||||
|
robot_data = env.scene["robot"].data
|
||||||
|
joint_pos_rel = robot_data.joint_pos - robot_data.default_joint_pos
|
||||||
|
joint_vel = robot_data.joint_vel
|
||||||
|
root_lin_vel = robot_data.root_lin_vel_w
|
||||||
|
root_ang_vel = robot_data.root_ang_vel_w
|
||||||
|
projected_gravity = robot_data.projected_gravity_b
|
||||||
|
amp_features = torch.cat([joint_pos_rel, joint_vel, root_lin_vel, root_ang_vel, projected_gravity], dim=-1)
|
||||||
|
amp_features = _safe_tensor(amp_features, nan=0.0, pos=feature_clip, neg=-feature_clip)
|
||||||
|
return torch.clamp(amp_features, min=-feature_clip, max=feature_clip)
|
||||||
|
|
||||||
|
|
||||||
|
def amp_style_prior_reward(
|
||||||
|
env: ManagerBasedRLEnv,
|
||||||
|
amp_enabled: bool = False,
|
||||||
|
amp_model_path: str = "",
|
||||||
|
amp_train_enabled: bool = False,
|
||||||
|
amp_expert_features_path: str = "",
|
||||||
|
disc_hidden_dim: int = 256,
|
||||||
|
disc_hidden_layers: int = 2,
|
||||||
|
disc_lr: float = 3e-4,
|
||||||
|
disc_weight_decay: float = 1e-6,
|
||||||
|
disc_update_interval: int = 4,
|
||||||
|
disc_batch_size: int = 1024,
|
||||||
|
disc_min_expert_samples: int = 2048,
|
||||||
|
disc_history_steps: int = 4,
|
||||||
|
feature_clip: float = 8.0,
|
||||||
|
logit_scale: float = 1.0,
|
||||||
|
amp_reward_gain: float = 1.0,
|
||||||
|
internal_reward_scale: float = 1.0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""AMP style prior reward with optional online discriminator training."""
|
||||||
|
zeros = torch.zeros(env.num_envs, device=env.device)
|
||||||
|
amp_score = zeros
|
||||||
|
model_loaded = 0.0
|
||||||
|
amp_train_active = 0.0
|
||||||
|
disc_loss = 0.0
|
||||||
|
disc_acc_policy = 0.0
|
||||||
|
disc_acc_expert = 0.0
|
||||||
|
|
||||||
|
amp_features = _build_amp_features(env, feature_clip=feature_clip)
|
||||||
|
amp_state = _get_amp_state(
|
||||||
|
env=env,
|
||||||
|
amp_enabled=amp_enabled,
|
||||||
|
amp_model_path=amp_model_path,
|
||||||
|
amp_train_enabled=amp_train_enabled,
|
||||||
|
amp_expert_features_path=amp_expert_features_path,
|
||||||
|
feature_dim=int(amp_features.shape[-1]),
|
||||||
|
disc_hidden_dim=disc_hidden_dim,
|
||||||
|
disc_hidden_layers=disc_hidden_layers,
|
||||||
|
disc_lr=disc_lr,
|
||||||
|
disc_weight_decay=disc_weight_decay,
|
||||||
|
disc_min_expert_samples=disc_min_expert_samples,
|
||||||
|
disc_history_steps=disc_history_steps,
|
||||||
|
)
|
||||||
|
discriminator = amp_state.get("model", None)
|
||||||
|
history_steps = max(int(amp_state.get("disc_history_steps", disc_history_steps)), 1)
|
||||||
|
policy_seq = _policy_sequence_features(env, amp_features, history_steps=history_steps)
|
||||||
|
policy_seq_flat = policy_seq.reshape(policy_seq.shape[0], -1)
|
||||||
|
if discriminator is not None:
|
||||||
|
model_loaded = 1.0
|
||||||
|
|
||||||
|
if amp_state.get("mode") == "trainable" and discriminator is not None:
|
||||||
|
amp_train_active = 1.0
|
||||||
|
optimizer = amp_state.get("optimizer", None)
|
||||||
|
expert_bank = amp_state.get("expert_bank", None)
|
||||||
|
amp_state["step"] = int(amp_state.get("step", 0)) + 1
|
||||||
|
update_interval = max(int(disc_update_interval), 1)
|
||||||
|
batch_size = max(int(disc_batch_size), 32)
|
||||||
|
|
||||||
|
if optimizer is not None and isinstance(expert_bank, dict) and amp_state["step"] % update_interval == 0:
|
||||||
|
policy_features = policy_seq_flat.detach()
|
||||||
|
policy_count = policy_features.shape[0]
|
||||||
|
if policy_count > batch_size:
|
||||||
|
policy_ids = torch.randint(0, policy_count, (batch_size,), device=env.device)
|
||||||
|
policy_batch = policy_features.index_select(0, policy_ids)
|
||||||
|
else:
|
||||||
|
policy_batch = policy_features
|
||||||
|
|
||||||
|
expert_seq = _sample_expert_sequence_batch(
|
||||||
|
expert_bank=expert_bank,
|
||||||
|
batch_size=policy_batch.shape[0],
|
||||||
|
history_steps=history_steps,
|
||||||
|
device=env.device,
|
||||||
|
)
|
||||||
|
expert_batch = expert_seq.reshape(expert_seq.shape[0], -1)
|
||||||
|
|
||||||
|
discriminator.train()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
logits_expert = discriminator(expert_batch).squeeze(-1)
|
||||||
|
logits_policy = discriminator(policy_batch).squeeze(-1)
|
||||||
|
loss_expert = nn.functional.binary_cross_entropy_with_logits(logits_expert, torch.ones_like(logits_expert))
|
||||||
|
loss_policy = nn.functional.binary_cross_entropy_with_logits(logits_policy, torch.zeros_like(logits_policy))
|
||||||
|
loss = 0.5 * (loss_expert + loss_policy)
|
||||||
|
loss.backward()
|
||||||
|
nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
|
||||||
|
optimizer.step()
|
||||||
|
discriminator.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
disc_loss = float(loss.detach().item())
|
||||||
|
disc_acc_expert = float((torch.sigmoid(logits_expert) > 0.5).float().mean().item())
|
||||||
|
disc_acc_policy = float((torch.sigmoid(logits_policy) < 0.5).float().mean().item())
|
||||||
|
amp_state["last_loss"] = disc_loss
|
||||||
|
amp_state["last_acc_expert"] = disc_acc_expert
|
||||||
|
amp_state["last_acc_policy"] = disc_acc_policy
|
||||||
|
else:
|
||||||
|
disc_loss = float(amp_state.get("last_loss", 0.0))
|
||||||
|
disc_acc_expert = float(amp_state.get("last_acc_expert", 0.0))
|
||||||
|
disc_acc_policy = float(amp_state.get("last_acc_policy", 0.0))
|
||||||
|
|
||||||
|
if discriminator is not None:
|
||||||
|
discriminator.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = discriminator(policy_seq_flat)
|
||||||
|
if isinstance(logits, (tuple, list)):
|
||||||
|
logits = logits[0]
|
||||||
|
if logits.ndim > 1:
|
||||||
|
logits = logits.squeeze(-1)
|
||||||
|
elif amp_enabled and amp_model_path:
|
||||||
|
# For external scripted models, try temporal then fallback to single frame.
|
||||||
|
model = amp_state.get("model", None)
|
||||||
|
if model is not None:
|
||||||
|
with torch.no_grad():
|
||||||
|
try:
|
||||||
|
logits = model(policy_seq_flat)
|
||||||
|
except Exception:
|
||||||
|
logits = model(amp_features)
|
||||||
|
if isinstance(logits, (tuple, list)):
|
||||||
|
logits = logits[0]
|
||||||
|
if logits.ndim > 1:
|
||||||
|
logits = logits.squeeze(-1)
|
||||||
|
logits = _safe_tensor(logits, nan=0.0, pos=20.0, neg=-20.0)
|
||||||
|
amp_score = torch.sigmoid(logit_scale * logits)
|
||||||
|
amp_score = _safe_tensor(amp_score, nan=0.0, pos=1.0, neg=0.0)
|
||||||
|
|
||||||
|
amp_reward = _safe_tensor(amp_reward_gain * amp_score, nan=0.0, pos=10.0, neg=0.0)
|
||||||
|
|
||||||
|
log_dict = env.extras.get("log", {})
|
||||||
|
if isinstance(log_dict, dict):
|
||||||
|
log_dict["amp_score_mean"] = torch.mean(amp_score).detach().item()
|
||||||
|
log_dict["amp_reward_mean"] = torch.mean(amp_reward).detach().item()
|
||||||
|
log_dict["amp_model_loaded_mean"] = model_loaded
|
||||||
|
log_dict["amp_train_active_mean"] = amp_train_active
|
||||||
|
log_dict["amp_disc_loss_mean"] = disc_loss
|
||||||
|
log_dict["amp_disc_acc_policy_mean"] = disc_acc_policy
|
||||||
|
log_dict["amp_disc_acc_expert_mean"] = disc_acc_expert
|
||||||
|
log_dict["amp_disc_history_steps"] = float(history_steps)
|
||||||
|
env.extras["log"] = log_dict
|
||||||
|
|
||||||
|
return internal_reward_scale * amp_reward
|
||||||
Binary file not shown.
258
rl_game/get_up/amp/migrate_legged_lab_expert_template.py
Normal file
258
rl_game/get_up/amp/migrate_legged_lab_expert_template.py
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
from glob import glob
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import joblib
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
except Exception:
|
||||||
|
torch = None
|
||||||
|
|
||||||
|
GIT_GETUP_FOCUS_CLIP = "fallAndGetUp2_subject2_1200_1370"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_payload(path: Path):
|
||||||
|
if path.suffix.lower() in (".pt", ".pth"):
|
||||||
|
if torch is None:
|
||||||
|
raise RuntimeError("Loading .pt/.pth requires torch to be installed.")
|
||||||
|
return torch.load(str(path), map_location="cpu")
|
||||||
|
if path.suffix.lower() in (".pkl",):
|
||||||
|
try:
|
||||||
|
with path.open("rb") as f:
|
||||||
|
return pickle.load(f)
|
||||||
|
except Exception:
|
||||||
|
return joblib.load(str(path))
|
||||||
|
if path.suffix.lower() in (".json",):
|
||||||
|
with path.open("r", encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
raise ValueError(f"Unsupported file type: {path.suffix}")
|
||||||
|
|
||||||
|
|
||||||
|
def _to_numpy(payload, key_hint: str = "") -> np.ndarray:
|
||||||
|
if torch is not None and isinstance(payload, torch.Tensor):
|
||||||
|
return payload.detach().cpu().numpy().astype(np.float32)
|
||||||
|
if isinstance(payload, np.ndarray):
|
||||||
|
return payload.astype(np.float32)
|
||||||
|
if isinstance(payload, list):
|
||||||
|
return np.asarray(payload, dtype=np.float32)
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
candidate_keys = [key_hint] if key_hint else []
|
||||||
|
candidate_keys += [
|
||||||
|
"expert_features",
|
||||||
|
"features",
|
||||||
|
"observations",
|
||||||
|
"obs",
|
||||||
|
"joint_pos",
|
||||||
|
"motion",
|
||||||
|
"data",
|
||||||
|
]
|
||||||
|
for key in candidate_keys:
|
||||||
|
if key and key in payload:
|
||||||
|
value = payload[key]
|
||||||
|
if torch is not None and isinstance(value, torch.Tensor):
|
||||||
|
return value.detach().cpu().numpy().astype(np.float32)
|
||||||
|
if isinstance(value, np.ndarray):
|
||||||
|
return value.astype(np.float32)
|
||||||
|
if isinstance(value, list):
|
||||||
|
return np.asarray(value, dtype=np.float32)
|
||||||
|
for value in payload.values():
|
||||||
|
if torch is not None and isinstance(value, torch.Tensor):
|
||||||
|
return value.detach().cpu().numpy().astype(np.float32)
|
||||||
|
if isinstance(value, np.ndarray):
|
||||||
|
return value.astype(np.float32)
|
||||||
|
if isinstance(value, list):
|
||||||
|
return np.asarray(value, dtype=np.float32)
|
||||||
|
raise ValueError("Could not locate tensor-like payload. Provide --input_key when needed.")
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_2d(x: np.ndarray) -> np.ndarray:
|
||||||
|
if x.ndim == 1:
|
||||||
|
return x[None, :]
|
||||||
|
if x.ndim != 2:
|
||||||
|
raise ValueError(f"Expected 2D array [N, D], got shape={tuple(x.shape)}")
|
||||||
|
return x.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def _from_legged_lab_motion_dict(payload: dict, target_dof: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Convert legged_lab motion pickle payload to AMP feature tensor [N, D].
|
||||||
|
Expected keys: fps, root_pos, dof_pos.
|
||||||
|
"""
|
||||||
|
dof_pos = np.asarray(payload.get("dof_pos", None))
|
||||||
|
root_pos = np.asarray(payload.get("root_pos", None))
|
||||||
|
fps = float(payload.get("fps", 30.0))
|
||||||
|
if dof_pos.ndim != 2:
|
||||||
|
raise ValueError("legged_lab payload missing valid dof_pos [N, M].")
|
||||||
|
if root_pos.ndim != 2 or root_pos.shape[0] != dof_pos.shape[0]:
|
||||||
|
root_pos = np.zeros((dof_pos.shape[0], 3), dtype=np.float32)
|
||||||
|
|
||||||
|
dt = 1.0 / max(fps, 1e-3)
|
||||||
|
dof_pos = dof_pos.astype(np.float32)
|
||||||
|
root_pos = root_pos.astype(np.float32)
|
||||||
|
|
||||||
|
# Current get_up AMP feature dim expects 23 dof by default.
|
||||||
|
target_dof = int(target_dof)
|
||||||
|
if dof_pos.shape[1] >= target_dof:
|
||||||
|
dof_pos = dof_pos[:, :target_dof]
|
||||||
|
else:
|
||||||
|
dof_pos = np.pad(dof_pos, ((0, 0), (0, target_dof - dof_pos.shape[1])), mode="constant")
|
||||||
|
|
||||||
|
dof_vel = np.zeros_like(dof_pos, dtype=np.float32)
|
||||||
|
if dof_pos.shape[0] > 1:
|
||||||
|
dof_vel[1:] = (dof_pos[1:] - dof_pos[:-1]) / dt
|
||||||
|
dof_vel[0] = dof_vel[1]
|
||||||
|
|
||||||
|
root_lin = np.zeros((dof_pos.shape[0], 3), dtype=np.float32)
|
||||||
|
if root_pos.shape[0] > 1:
|
||||||
|
root_lin[1:] = (root_pos[1:] - root_pos[:-1]) / dt
|
||||||
|
root_lin[0] = root_lin[1]
|
||||||
|
root_ang = np.zeros((dof_pos.shape[0], 3), dtype=np.float32)
|
||||||
|
gravity = np.zeros((dof_pos.shape[0], 3), dtype=np.float32)
|
||||||
|
gravity[:, 2] = -1.0
|
||||||
|
|
||||||
|
x = np.concatenate([dof_pos, dof_vel, root_lin, root_ang, gravity], axis=-1).astype(np.float32)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_dim(x: np.ndarray, feature_dim: int) -> np.ndarray:
|
||||||
|
if x.shape[1] == feature_dim:
|
||||||
|
return x
|
||||||
|
if x.shape[1] > feature_dim:
|
||||||
|
return x[:, :feature_dim]
|
||||||
|
pad = np.zeros((x.shape[0], feature_dim - x.shape[1]), dtype=np.float32)
|
||||||
|
return np.concatenate([x, pad], axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_input_files(input_path: Path, glob_pattern: str) -> list[Path]:
|
||||||
|
if input_path.is_file():
|
||||||
|
return [input_path]
|
||||||
|
if input_path.is_dir():
|
||||||
|
files = sorted([Path(p) for p in glob(str(input_path / glob_pattern))])
|
||||||
|
if files:
|
||||||
|
return files
|
||||||
|
raise FileNotFoundError(f"No source files found for input={input_path} pattern={glob_pattern}")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_clip_weights(clip_names: list[str], weight_mode: str) -> np.ndarray:
|
||||||
|
if len(clip_names) == 0:
|
||||||
|
return np.zeros((0,), dtype=np.float32)
|
||||||
|
if weight_mode == "uniform":
|
||||||
|
return np.ones((len(clip_names),), dtype=np.float32)
|
||||||
|
if weight_mode == "git_getup_focus":
|
||||||
|
weights = np.array([1.0 if GIT_GETUP_FOCUS_CLIP in n else 0.0 for n in clip_names], dtype=np.float32)
|
||||||
|
if float(np.sum(weights)) <= 0.0:
|
||||||
|
return np.ones((len(clip_names),), dtype=np.float32)
|
||||||
|
return weights
|
||||||
|
return np.ones((len(clip_names),), dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=(
|
||||||
|
"Template converter: migrate legged_lab (or other) motion/expert data "
|
||||||
|
"to AMP expert_features.pt for rl_game/get_up."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
parser.add_argument("--input", required=True, type=str, help="Path to source motion/expert file or directory.")
|
||||||
|
parser.add_argument("--output", required=True, type=str, help="Output path for expert_features.pt.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--input_key",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="Optional key to locate tensor inside input payload dict.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--feature_dim",
|
||||||
|
type=int,
|
||||||
|
default=55,
|
||||||
|
help="Target AMP feature dimension. For current get_up config, default=55.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--input_glob",
|
||||||
|
type=str,
|
||||||
|
default="*.pkl",
|
||||||
|
help="Glob used when --input is a directory.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--target_dof",
|
||||||
|
type=int,
|
||||||
|
default=23,
|
||||||
|
help="Target dof count when converting legged_lab pkl dof_pos.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--clip_weight_mode",
|
||||||
|
type=str,
|
||||||
|
default="git_getup_focus",
|
||||||
|
choices=["git_getup_focus", "uniform"],
|
||||||
|
help="Clip sampling weight mode for expert sequence training.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--repeat", type=int, default=1, help="Repeat samples for small datasets.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
in_path = Path(args.input).expanduser().resolve()
|
||||||
|
input_files = _collect_input_files(in_path, args.input_glob)
|
||||||
|
clip_arrays: list[np.ndarray] = []
|
||||||
|
clip_names: list[str] = []
|
||||||
|
for f in input_files:
|
||||||
|
payload = _load_payload(f)
|
||||||
|
if isinstance(payload, dict) and "dof_pos" in payload and f.suffix.lower() == ".pkl":
|
||||||
|
x = _from_legged_lab_motion_dict(payload, target_dof=int(args.target_dof))
|
||||||
|
else:
|
||||||
|
x = _to_numpy(payload, key_hint=args.input_key)
|
||||||
|
x = _ensure_2d(x)
|
||||||
|
x = _normalize_dim(x.astype(np.float32), int(args.feature_dim))
|
||||||
|
clip_arrays.append(x)
|
||||||
|
clip_names.append(f.stem)
|
||||||
|
|
||||||
|
repeat = max(int(args.repeat), 1)
|
||||||
|
if repeat > 1:
|
||||||
|
clip_arrays = clip_arrays * repeat
|
||||||
|
clip_names = clip_names * repeat
|
||||||
|
|
||||||
|
clip_weights = _build_clip_weights(clip_names, weight_mode=args.clip_weight_mode)
|
||||||
|
x = np.concatenate(clip_arrays, axis=0).astype(np.float32)
|
||||||
|
clip_lengths = [int(c.shape[0]) for c in clip_arrays]
|
||||||
|
clip_offsets: list[int] = []
|
||||||
|
cursor = 0
|
||||||
|
for length in clip_lengths:
|
||||||
|
clip_offsets.append(cursor)
|
||||||
|
cursor += length
|
||||||
|
repeat = max(int(args.repeat), 1)
|
||||||
|
|
||||||
|
out_path = Path(args.output).expanduser().resolve()
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
payload = {
|
||||||
|
"expert_features": x,
|
||||||
|
"expert_clips": clip_arrays,
|
||||||
|
"clip_names": clip_names,
|
||||||
|
"clip_weights": clip_weights,
|
||||||
|
"clip_lengths": clip_lengths,
|
||||||
|
"clip_offsets": clip_offsets,
|
||||||
|
"meta": {
|
||||||
|
"source": str(in_path),
|
||||||
|
"source_count": len(input_files),
|
||||||
|
"input_key": args.input_key,
|
||||||
|
"feature_dim": int(args.feature_dim),
|
||||||
|
"target_dof": int(args.target_dof),
|
||||||
|
"input_glob": args.input_glob,
|
||||||
|
"clip_weight_mode": args.clip_weight_mode,
|
||||||
|
"git_getup_focus_clip": GIT_GETUP_FOCUS_CLIP,
|
||||||
|
"repeat": int(repeat),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if torch is not None:
|
||||||
|
torch.save(payload, str(out_path))
|
||||||
|
else:
|
||||||
|
with out_path.open("wb") as f:
|
||||||
|
pickle.dump(payload, f)
|
||||||
|
print(f"[OK] Saved AMP expert features -> {out_path} shape={tuple(x.shape)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,204 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import math
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
# AMP feature order must match `_build_amp_features` in `t1_env_cfg.py`:
|
|
||||||
# [joint_pos_rel(23), joint_vel(23), root_lin_vel(3), root_ang_vel(3), projected_gravity(3)].
|
|
||||||
T1_JOINT_NAMES = [
|
|
||||||
"AAHead_yaw",
|
|
||||||
"Head_pitch",
|
|
||||||
"Left_Shoulder_Pitch",
|
|
||||||
"Left_Shoulder_Roll",
|
|
||||||
"Left_Elbow_Pitch",
|
|
||||||
"Left_Elbow_Yaw",
|
|
||||||
"Right_Shoulder_Pitch",
|
|
||||||
"Right_Shoulder_Roll",
|
|
||||||
"Right_Elbow_Pitch",
|
|
||||||
"Right_Elbow_Yaw",
|
|
||||||
"Waist",
|
|
||||||
"Left_Hip_Pitch",
|
|
||||||
"Left_Hip_Roll",
|
|
||||||
"Left_Hip_Yaw",
|
|
||||||
"Left_Knee_Pitch",
|
|
||||||
"Left_Ankle_Pitch",
|
|
||||||
"Left_Ankle_Roll",
|
|
||||||
"Right_Hip_Pitch",
|
|
||||||
"Right_Hip_Roll",
|
|
||||||
"Right_Hip_Yaw",
|
|
||||||
"Right_Knee_Pitch",
|
|
||||||
"Right_Ankle_Pitch",
|
|
||||||
"Right_Ankle_Roll",
|
|
||||||
]
|
|
||||||
|
|
||||||
JOINT_TO_IDX = {name: i for i, name in enumerate(T1_JOINT_NAMES)}
|
|
||||||
|
|
||||||
# Mirror rules aligned with `behaviors/custom/keyframe/keyframe.py`.
|
|
||||||
MOTOR_SYMMETRY = {
|
|
||||||
"Head_yaw": (("Head_yaw",), False),
|
|
||||||
"Head_pitch": (("Head_pitch",), False),
|
|
||||||
"Shoulder_Pitch": (("Left_Shoulder_Pitch", "Right_Shoulder_Pitch"), False),
|
|
||||||
"Shoulder_Roll": (("Left_Shoulder_Roll", "Right_Shoulder_Roll"), True),
|
|
||||||
"Elbow_Pitch": (("Left_Elbow_Pitch", "Right_Elbow_Pitch"), False),
|
|
||||||
"Elbow_Yaw": (("Left_Elbow_Yaw", "Right_Elbow_Yaw"), True),
|
|
||||||
"Waist": (("Waist",), False),
|
|
||||||
"Hip_Pitch": (("Left_Hip_Pitch", "Right_Hip_Pitch"), False),
|
|
||||||
"Hip_Roll": (("Left_Hip_Roll", "Right_Hip_Roll"), True),
|
|
||||||
"Hip_Yaw": (("Left_Hip_Yaw", "Right_Hip_Yaw"), True),
|
|
||||||
"Knee_Pitch": (("Left_Knee_Pitch", "Right_Knee_Pitch"), False),
|
|
||||||
"Ankle_Pitch": (("Left_Ankle_Pitch", "Right_Ankle_Pitch"), False),
|
|
||||||
"Ankle_Roll": (("Left_Ankle_Roll", "Right_Ankle_Roll"), True),
|
|
||||||
}
|
|
||||||
READABLE_TO_POLICY = {"Head_yaw": "AAHead_yaw"}
|
|
||||||
|
|
||||||
|
|
||||||
def decode_keyframe_motor_positions(raw_motor_positions: dict[str, float]) -> dict[str, float]:
|
|
||||||
"""Decode one keyframe into per-joint radians."""
|
|
||||||
out: dict[str, float] = {}
|
|
||||||
deg_to_rad = math.pi / 180.0
|
|
||||||
for readable_name, position_deg in raw_motor_positions.items():
|
|
||||||
if readable_name in MOTOR_SYMMETRY:
|
|
||||||
motor_names, is_inverse_direction = MOTOR_SYMMETRY[readable_name]
|
|
||||||
invert_state = bool(is_inverse_direction)
|
|
||||||
for motor_name in motor_names:
|
|
||||||
signed_deg = position_deg if invert_state else -position_deg
|
|
||||||
invert_state = False
|
|
||||||
out_name = READABLE_TO_POLICY.get(motor_name, motor_name)
|
|
||||||
out[out_name] = float(signed_deg) * deg_to_rad
|
|
||||||
else:
|
|
||||||
out_name = READABLE_TO_POLICY.get(readable_name, readable_name)
|
|
||||||
out[out_name] = float(position_deg) * deg_to_rad
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def load_sequence(yaml_path: Path) -> list[tuple[float, torch.Tensor]]:
|
|
||||||
"""Load yaml keyframes -> list[(delta_seconds, joint_pos_vec)]."""
|
|
||||||
with yaml_path.open("r", encoding="utf-8") as f:
|
|
||||||
desc = yaml.safe_load(f) or {}
|
|
||||||
out: list[tuple[float, torch.Tensor]] = []
|
|
||||||
for keyframe in desc.get("keyframes", []):
|
|
||||||
delta_s = max(float(keyframe.get("delta", 0.1)), 1e-3)
|
|
||||||
raw = keyframe.get("motor_positions", {}) or {}
|
|
||||||
decoded = decode_keyframe_motor_positions(raw)
|
|
||||||
joint_pos = torch.zeros(len(T1_JOINT_NAMES), dtype=torch.float32)
|
|
||||||
for j_name, j_val in decoded.items():
|
|
||||||
idx = JOINT_TO_IDX.get(j_name, None)
|
|
||||||
if idx is not None:
|
|
||||||
joint_pos[idx] = float(j_val)
|
|
||||||
out.append((delta_s, joint_pos))
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def sequence_to_amp_features(
|
|
||||||
sequence: list[tuple[float, torch.Tensor]],
|
|
||||||
sample_fps: float,
|
|
||||||
projected_gravity: tuple[float, float, float],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Convert decoded sequence into AMP features tensor (N, 55)."""
|
|
||||||
if len(sequence) == 0:
|
|
||||||
raise ValueError("Empty keyframe sequence.")
|
|
||||||
dt = 1.0 / max(sample_fps, 1e-6)
|
|
||||||
grav = torch.tensor(projected_gravity, dtype=torch.float32)
|
|
||||||
|
|
||||||
frames_joint_pos: list[torch.Tensor] = []
|
|
||||||
for delta_s, joint_pos in sequence:
|
|
||||||
repeat = max(int(round(delta_s / dt)), 1)
|
|
||||||
for _ in range(repeat):
|
|
||||||
frames_joint_pos.append(joint_pos.clone())
|
|
||||||
if len(frames_joint_pos) < 2:
|
|
||||||
frames_joint_pos.append(frames_joint_pos[0].clone())
|
|
||||||
|
|
||||||
pos = torch.stack(frames_joint_pos, dim=0)
|
|
||||||
vel = torch.zeros_like(pos)
|
|
||||||
vel[1:] = (pos[1:] - pos[:-1]) / dt
|
|
||||||
vel[0] = vel[1]
|
|
||||||
|
|
||||||
root_lin = torch.zeros((pos.shape[0], 3), dtype=torch.float32)
|
|
||||||
root_ang = torch.zeros((pos.shape[0], 3), dtype=torch.float32)
|
|
||||||
grav_batch = grav.unsqueeze(0).repeat(pos.shape[0], 1)
|
|
||||||
return torch.cat([pos, vel, root_lin, root_ang, grav_batch], dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Build AMP expert features from get_up keyframe YAML files.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--front_yaml",
|
|
||||||
type=str,
|
|
||||||
default="behaviors/custom/keyframe/get_up/get_up_front.yaml",
|
|
||||||
help="Path to front get-up YAML.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--back_yaml",
|
|
||||||
type=str,
|
|
||||||
default="behaviors/custom/keyframe/get_up/get_up_back.yaml",
|
|
||||||
help="Path to back get-up YAML.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--sample_fps",
|
|
||||||
type=float,
|
|
||||||
default=50.0,
|
|
||||||
help="Sampling fps when expanding keyframe durations.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--repeat_cycles",
|
|
||||||
type=int,
|
|
||||||
default=200,
|
|
||||||
help="How many times to repeat front+back sequences to enlarge dataset.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--projected_gravity",
|
|
||||||
type=float,
|
|
||||||
nargs=3,
|
|
||||||
default=(0.0, 0.0, -1.0),
|
|
||||||
help="Projected gravity feature used for synthesized expert data.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output",
|
|
||||||
type=str,
|
|
||||||
default="rl_game/get_up/amp/expert_features.pt",
|
|
||||||
help="Output expert feature file path.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
front_path = Path(args.front_yaml).expanduser().resolve()
|
|
||||||
back_path = Path(args.back_yaml).expanduser().resolve()
|
|
||||||
if not front_path.is_file():
|
|
||||||
raise FileNotFoundError(f"Front YAML not found: {front_path}")
|
|
||||||
if not back_path.is_file():
|
|
||||||
raise FileNotFoundError(f"Back YAML not found: {back_path}")
|
|
||||||
|
|
||||||
front_seq = load_sequence(front_path)
|
|
||||||
back_seq = load_sequence(back_path)
|
|
||||||
front_feat = sequence_to_amp_features(front_seq, args.sample_fps, tuple(args.projected_gravity))
|
|
||||||
back_feat = sequence_to_amp_features(back_seq, args.sample_fps, tuple(args.projected_gravity))
|
|
||||||
base_feat = torch.cat([front_feat, back_feat], dim=0)
|
|
||||||
|
|
||||||
repeat_cycles = max(int(args.repeat_cycles), 1)
|
|
||||||
expert_features = base_feat.repeat(repeat_cycles, 1).contiguous()
|
|
||||||
|
|
||||||
out_path = Path(args.output).expanduser()
|
|
||||||
if not out_path.is_absolute():
|
|
||||||
out_path = Path.cwd() / out_path
|
|
||||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
torch.save(
|
|
||||||
{
|
|
||||||
"expert_features": expert_features,
|
|
||||||
"feature_dim": int(expert_features.shape[1]),
|
|
||||||
"num_samples": int(expert_features.shape[0]),
|
|
||||||
"source": "get_up_keyframe_yaml",
|
|
||||||
"front_yaml": str(front_path),
|
|
||||||
"back_yaml": str(back_path),
|
|
||||||
"sample_fps": float(args.sample_fps),
|
|
||||||
"repeat_cycles": repeat_cycles,
|
|
||||||
"projected_gravity": [float(v) for v in args.projected_gravity],
|
|
||||||
},
|
|
||||||
str(out_path),
|
|
||||||
)
|
|
||||||
print(f"[INFO]: saved expert features -> {out_path}")
|
|
||||||
print(f"[INFO]: shape={tuple(expert_features.shape)}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -27,7 +27,7 @@ params:
|
|||||||
name: default
|
name: default
|
||||||
|
|
||||||
config:
|
config:
|
||||||
name: T1_Walking
|
name: T1_GetUp
|
||||||
env_name: rlgym # Isaac Lab 包装器
|
env_name: rlgym # Isaac Lab 包装器
|
||||||
multi_gpu: False
|
multi_gpu: False
|
||||||
ppo: True
|
ppo: True
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import yaml
|
||||||
import isaaclab.envs.mdp as mdp
|
import isaaclab.envs.mdp as mdp
|
||||||
from isaaclab.envs import ManagerBasedRLEnvCfg, ManagerBasedRLEnv
|
from isaaclab.envs import ManagerBasedRLEnvCfg, ManagerBasedRLEnv
|
||||||
from isaaclab.managers import ObservationGroupCfg as ObsGroup
|
from isaaclab.managers import ObservationGroupCfg as ObsGroup
|
||||||
@@ -11,8 +11,13 @@ from isaaclab.managers import EventTermCfg as EventTerm
|
|||||||
from isaaclab.envs.mdp import JointPositionActionCfg
|
from isaaclab.envs.mdp import JointPositionActionCfg
|
||||||
from isaaclab.managers import SceneEntityCfg
|
from isaaclab.managers import SceneEntityCfg
|
||||||
from isaaclab.utils import configclass
|
from isaaclab.utils import configclass
|
||||||
|
from rl_game.get_up.amp.amp_rewards import amp_style_prior_reward
|
||||||
from rl_game.get_up.env.t1_env import T1SceneCfg
|
from rl_game.get_up.env.t1_env import T1SceneCfg
|
||||||
|
|
||||||
|
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||||
|
_DEFAULT_FRONT_KEYFRAME_YAML = str(_PROJECT_ROOT / "behaviors" / "custom" / "keyframe" / "get_up" / "get_up_front.yaml")
|
||||||
|
_DEFAULT_BACK_KEYFRAME_YAML = str(_PROJECT_ROOT / "behaviors" / "custom" / "keyframe" / "get_up" / "get_up_back.yaml")
|
||||||
|
|
||||||
def _contact_force_z(env: ManagerBasedRLEnv, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
|
def _contact_force_z(env: ManagerBasedRLEnv, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
|
||||||
"""Sum positive vertical contact force on selected bodies."""
|
"""Sum positive vertical contact force on selected bodies."""
|
||||||
sensor = env.scene.sensors.get(sensor_cfg.name)
|
sensor = env.scene.sensors.get(sensor_cfg.name)
|
||||||
@@ -30,272 +35,216 @@ def _safe_tensor(x: torch.Tensor, nan: float = 0.0, pos: float = 1e3, neg: float
|
|||||||
return torch.nan_to_num(x, nan=nan, posinf=pos, neginf=neg)
|
return torch.nan_to_num(x, nan=nan, posinf=pos, neginf=neg)
|
||||||
|
|
||||||
|
|
||||||
class AMPDiscriminator(nn.Module):
|
def _resolve_path(path_like: str) -> Path:
|
||||||
"""Lightweight discriminator used by online AMP updates."""
|
path = Path(path_like).expanduser()
|
||||||
|
if path.is_absolute():
|
||||||
def __init__(self, input_dim: int, hidden_dims: tuple[int, ...]):
|
return path
|
||||||
super().__init__()
|
return (_PROJECT_ROOT / path).resolve()
|
||||||
layers: list[nn.Module] = []
|
|
||||||
in_dim = input_dim
|
|
||||||
for h_dim in hidden_dims:
|
|
||||||
layers.append(nn.Linear(in_dim, h_dim))
|
|
||||||
layers.append(nn.LayerNorm(h_dim))
|
|
||||||
layers.append(nn.SiLU())
|
|
||||||
in_dim = h_dim
|
|
||||||
layers.append(nn.Linear(in_dim, 1))
|
|
||||||
self.net = nn.Sequential(*layers)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.net(x)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_tensor_from_amp_payload(payload) -> torch.Tensor | None:
|
def _interpolate_keyframes(
|
||||||
if isinstance(payload, torch.Tensor):
|
keyframes: list[dict],
|
||||||
return payload
|
sample_dt: float,
|
||||||
if isinstance(payload, dict):
|
) -> tuple[torch.Tensor, list[str]]:
|
||||||
for key in ("expert_features", "features", "obs"):
|
"""Interpolate sparse keyframes into dense [T, M] motor angle table in radians."""
|
||||||
value = payload.get(key, None)
|
if len(keyframes) == 0:
|
||||||
if isinstance(value, torch.Tensor):
|
raise ValueError("No keyframes in motion yaml")
|
||||||
return value
|
|
||||||
return None
|
motor_names: set[str] = set()
|
||||||
|
for frame in keyframes:
|
||||||
|
motors = frame.get("motor_positions", {})
|
||||||
|
if isinstance(motors, dict):
|
||||||
|
motor_names.update(motors.keys())
|
||||||
|
names = sorted(motor_names)
|
||||||
|
if len(names) == 0:
|
||||||
|
raise ValueError("No motor_positions found in keyframes")
|
||||||
|
|
||||||
|
frame_count = len(keyframes)
|
||||||
|
times = torch.zeros(frame_count, dtype=torch.float32)
|
||||||
|
values = torch.full((frame_count, len(names)), float("nan"), dtype=torch.float32)
|
||||||
|
name_to_idx = {name: idx for idx, name in enumerate(names)}
|
||||||
|
|
||||||
|
elapsed = 0.0
|
||||||
|
for i, frame in enumerate(keyframes):
|
||||||
|
elapsed += max(float(frame.get("delta", 0.1)), 1e-4)
|
||||||
|
times[i] = elapsed
|
||||||
|
motors = frame.get("motor_positions", {})
|
||||||
|
if not isinstance(motors, dict):
|
||||||
|
continue
|
||||||
|
for motor_name, motor_deg in motors.items():
|
||||||
|
if motor_name not in name_to_idx:
|
||||||
|
continue
|
||||||
|
values[i, name_to_idx[motor_name]] = float(motor_deg) * (torch.pi / 180.0)
|
||||||
|
|
||||||
|
values[0] = torch.nan_to_num(values[0], nan=0.0)
|
||||||
|
for i in range(1, frame_count):
|
||||||
|
values[i] = torch.where(torch.isnan(values[i]), values[i - 1], values[i])
|
||||||
|
|
||||||
|
sample_dt = max(float(sample_dt), 1e-3)
|
||||||
|
sample_times = torch.arange(0.0, float(times[-1]) + 1e-6, sample_dt, dtype=torch.float32)
|
||||||
|
sample_times[0] = max(sample_times[0], 1e-6)
|
||||||
|
sample_times = torch.clamp(sample_times, min=times[0], max=times[-1])
|
||||||
|
|
||||||
|
upper = torch.searchsorted(times, sample_times, right=True)
|
||||||
|
upper = torch.clamp(upper, min=1, max=frame_count - 1)
|
||||||
|
lower = upper - 1
|
||||||
|
|
||||||
|
t0 = times.index_select(0, lower)
|
||||||
|
t1 = times.index_select(0, upper)
|
||||||
|
v0 = values.index_select(0, lower)
|
||||||
|
v1 = values.index_select(0, upper)
|
||||||
|
alpha = ((sample_times - t0) / (t1 - t0 + 1e-6)).unsqueeze(-1)
|
||||||
|
interp = v0 + alpha * (v1 - v0)
|
||||||
|
return interp, names
|
||||||
|
|
||||||
|
|
||||||
def _load_amp_expert_features(
|
def _motion_table_from_yaml(
|
||||||
expert_features_path: str,
|
yaml_path: str,
|
||||||
device: str,
|
joint_names: list[str],
|
||||||
feature_dim: int,
|
sample_dt: float,
|
||||||
fallback_samples: int,
|
) -> tuple[torch.Tensor, float]:
|
||||||
) -> torch.Tensor | None:
|
"""
|
||||||
"""Load expert AMP features. Returns None when file is unavailable."""
|
Build [T, J] target joint motion from get-up keyframes.
|
||||||
if not expert_features_path:
|
Unknown joints default to 0.0 (neutral).
|
||||||
return None
|
"""
|
||||||
p = Path(expert_features_path).expanduser()
|
path = _resolve_path(yaml_path)
|
||||||
if not p.is_file():
|
if not path.is_file():
|
||||||
return None
|
raise FileNotFoundError(f"Motion yaml not found: {path}")
|
||||||
try:
|
with path.open("r", encoding="utf-8") as f:
|
||||||
payload = torch.load(str(p), map_location="cpu")
|
payload = yaml.safe_load(f) or {}
|
||||||
except Exception:
|
keyframes = payload.get("keyframes", [])
|
||||||
return None
|
if not isinstance(keyframes, list):
|
||||||
expert = _extract_tensor_from_amp_payload(payload)
|
raise ValueError(f"Invalid keyframes in yaml: {path}")
|
||||||
if expert is None:
|
|
||||||
return None
|
motor_table, motor_names = _interpolate_keyframes(keyframes, sample_dt=sample_dt)
|
||||||
expert = expert.float()
|
motor_to_idx = {name: idx for idx, name in enumerate(motor_names)}
|
||||||
if expert.ndim == 1:
|
joint_table = torch.zeros((motor_table.shape[0], len(joint_names)), dtype=torch.float32)
|
||||||
expert = expert.unsqueeze(0)
|
|
||||||
if expert.ndim != 2:
|
sign_flip_bases = {"Shoulder_Roll", "Elbow_Yaw", "Hip_Roll", "Hip_Yaw", "Ankle_Roll"}
|
||||||
return None
|
for j, joint_name in enumerate(joint_names):
|
||||||
if expert.shape[1] != feature_dim:
|
alias = "Head_yaw" if joint_name == "AAHead_yaw" else joint_name
|
||||||
return None
|
base = alias.replace("Left_", "").replace("Right_", "")
|
||||||
if expert.shape[0] < 2:
|
src = None
|
||||||
return None
|
if alias in motor_to_idx:
|
||||||
if expert.shape[0] < fallback_samples:
|
src = alias
|
||||||
reps = int((fallback_samples + expert.shape[0] - 1) // expert.shape[0])
|
elif base in motor_to_idx:
|
||||||
expert = expert.repeat(reps, 1)
|
src = base
|
||||||
return expert.to(device=device)
|
if src is None:
|
||||||
|
continue
|
||||||
|
sign = -1.0 if alias.startswith("Right_") and base in sign_flip_bases else 1.0
|
||||||
|
joint_table[:, j] = sign * motor_table[:, motor_to_idx[src]]
|
||||||
|
|
||||||
|
duration_s = max(float(motor_table.shape[0] - 1) * sample_dt, sample_dt)
|
||||||
|
return joint_table, duration_s
|
||||||
|
|
||||||
|
|
||||||
def _get_amp_state(
|
def _get_keyframe_motion_cache(
|
||||||
env: ManagerBasedRLEnv,
|
env: ManagerBasedRLEnv,
|
||||||
amp_enabled: bool,
|
front_motion_path: str,
|
||||||
amp_model_path: str,
|
back_motion_path: str,
|
||||||
amp_train_enabled: bool,
|
sample_dt: float,
|
||||||
amp_expert_features_path: str,
|
|
||||||
feature_dim: int,
|
|
||||||
disc_hidden_dim: int,
|
|
||||||
disc_hidden_layers: int,
|
|
||||||
disc_lr: float,
|
|
||||||
disc_weight_decay: float,
|
|
||||||
disc_min_expert_samples: int,
|
|
||||||
):
|
):
|
||||||
"""Get cached AMP state (frozen jit or trainable discriminator)."""
|
"""Cache interpolated front/back get-up motion priors on env device."""
|
||||||
cache_key = "amp_state_cache"
|
cache_key = "getup_keyframe_motion_cache"
|
||||||
hidden_layers = max(int(disc_hidden_layers), 1)
|
sig = (str(front_motion_path), str(back_motion_path), float(sample_dt))
|
||||||
hidden_dim = max(int(disc_hidden_dim), 16)
|
|
||||||
state_sig = (
|
|
||||||
bool(amp_enabled),
|
|
||||||
str(amp_model_path),
|
|
||||||
bool(amp_train_enabled),
|
|
||||||
str(amp_expert_features_path),
|
|
||||||
int(feature_dim),
|
|
||||||
hidden_dim,
|
|
||||||
hidden_layers,
|
|
||||||
float(disc_lr),
|
|
||||||
float(disc_weight_decay),
|
|
||||||
)
|
|
||||||
cached = env.extras.get(cache_key, None)
|
cached = env.extras.get(cache_key, None)
|
||||||
if isinstance(cached, dict) and cached.get("sig") == state_sig:
|
if isinstance(cached, dict) and cached.get("sig") == sig:
|
||||||
return cached
|
return cached
|
||||||
|
|
||||||
state = {
|
front_table, front_duration = _motion_table_from_yaml(front_motion_path, T1_JOINT_NAMES, sample_dt)
|
||||||
"sig": state_sig,
|
back_table, back_duration = _motion_table_from_yaml(back_motion_path, T1_JOINT_NAMES, sample_dt)
|
||||||
"mode": "disabled",
|
|
||||||
"model": None,
|
cache = {
|
||||||
"optimizer": None,
|
"sig": sig,
|
||||||
"expert_features": None,
|
"sample_dt": float(sample_dt),
|
||||||
"step": 0,
|
"front_motion": front_table.to(device=env.device),
|
||||||
"last_loss": 0.0,
|
"back_motion": back_table.to(device=env.device),
|
||||||
"last_acc_policy": 0.0,
|
"front_duration": float(front_duration),
|
||||||
"last_acc_expert": 0.0,
|
"back_duration": float(back_duration),
|
||||||
}
|
}
|
||||||
|
env.extras[cache_key] = cache
|
||||||
if amp_train_enabled:
|
return cache
|
||||||
expert_features = _load_amp_expert_features(
|
|
||||||
amp_expert_features_path,
|
|
||||||
device=env.device,
|
|
||||||
feature_dim=feature_dim,
|
|
||||||
fallback_samples=max(disc_min_expert_samples, 512),
|
|
||||||
)
|
|
||||||
if expert_features is not None:
|
|
||||||
model = AMPDiscriminator(input_dim=feature_dim, hidden_dims=tuple([hidden_dim] * hidden_layers)).to(env.device)
|
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=float(disc_lr), weight_decay=float(disc_weight_decay))
|
|
||||||
state["mode"] = "trainable"
|
|
||||||
state["model"] = model
|
|
||||||
state["optimizer"] = optimizer
|
|
||||||
state["expert_features"] = expert_features
|
|
||||||
elif amp_enabled and amp_model_path:
|
|
||||||
model_path = Path(amp_model_path).expanduser()
|
|
||||||
if model_path.is_file():
|
|
||||||
try:
|
|
||||||
model = torch.jit.load(str(model_path), map_location=env.device)
|
|
||||||
model.eval()
|
|
||||||
state["mode"] = "jit"
|
|
||||||
state["model"] = model
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
env.extras[cache_key] = state
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
def _build_amp_features(env: ManagerBasedRLEnv, feature_clip: float = 8.0) -> torch.Tensor:
|
def keyframe_motion_prior_reward(
|
||||||
"""Build AMP-style discriminator features from robot kinematics."""
|
|
||||||
robot_data = env.scene["robot"].data
|
|
||||||
joint_pos_rel = robot_data.joint_pos - robot_data.default_joint_pos
|
|
||||||
joint_vel = robot_data.joint_vel
|
|
||||||
root_lin_vel = robot_data.root_lin_vel_w
|
|
||||||
root_ang_vel = robot_data.root_ang_vel_w
|
|
||||||
projected_gravity = robot_data.projected_gravity_b
|
|
||||||
amp_features = torch.cat([joint_pos_rel, joint_vel, root_lin_vel, root_ang_vel, projected_gravity], dim=-1)
|
|
||||||
amp_features = _safe_tensor(amp_features, nan=0.0, pos=feature_clip, neg=-feature_clip)
|
|
||||||
return torch.clamp(amp_features, min=-feature_clip, max=feature_clip)
|
|
||||||
|
|
||||||
|
|
||||||
def amp_style_prior_reward(
|
|
||||||
env: ManagerBasedRLEnv,
|
env: ManagerBasedRLEnv,
|
||||||
amp_enabled: bool = False,
|
front_motion_path: str = _DEFAULT_FRONT_KEYFRAME_YAML,
|
||||||
amp_model_path: str = "",
|
back_motion_path: str = _DEFAULT_BACK_KEYFRAME_YAML,
|
||||||
amp_train_enabled: bool = False,
|
sample_dt: float = 0.04,
|
||||||
amp_expert_features_path: str = "",
|
pose_sigma: float = 0.42,
|
||||||
disc_hidden_dim: int = 256,
|
vel_sigma: float = 1.6,
|
||||||
disc_hidden_layers: int = 2,
|
joint_subset: str = "all",
|
||||||
disc_lr: float = 3e-4,
|
|
||||||
disc_weight_decay: float = 1e-6,
|
|
||||||
disc_update_interval: int = 4,
|
|
||||||
disc_batch_size: int = 1024,
|
|
||||||
disc_min_expert_samples: int = 2048,
|
|
||||||
feature_clip: float = 8.0,
|
|
||||||
logit_scale: float = 1.0,
|
|
||||||
amp_reward_gain: float = 1.0,
|
|
||||||
internal_reward_scale: float = 1.0,
|
internal_reward_scale: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""AMP style prior reward with optional online discriminator training."""
|
"""
|
||||||
zeros = torch.zeros(env.num_envs, device=env.device)
|
DeepMimic-style dense reward from keyframe get-up motions.
|
||||||
amp_score = zeros
|
- mode=1 uses front sequence
|
||||||
model_loaded = 0.0
|
- mode=0 uses back sequence
|
||||||
amp_train_active = 0.0
|
"""
|
||||||
disc_loss = 0.0
|
motion_cache = _get_keyframe_motion_cache(
|
||||||
disc_acc_policy = 0.0
|
|
||||||
disc_acc_expert = 0.0
|
|
||||||
|
|
||||||
amp_features = _build_amp_features(env, feature_clip=feature_clip)
|
|
||||||
amp_state = _get_amp_state(
|
|
||||||
env=env,
|
env=env,
|
||||||
amp_enabled=amp_enabled,
|
front_motion_path=front_motion_path,
|
||||||
amp_model_path=amp_model_path,
|
back_motion_path=back_motion_path,
|
||||||
amp_train_enabled=amp_train_enabled,
|
sample_dt=sample_dt,
|
||||||
amp_expert_features_path=amp_expert_features_path,
|
|
||||||
feature_dim=int(amp_features.shape[-1]),
|
|
||||||
disc_hidden_dim=disc_hidden_dim,
|
|
||||||
disc_hidden_layers=disc_hidden_layers,
|
|
||||||
disc_lr=disc_lr,
|
|
||||||
disc_weight_decay=disc_weight_decay,
|
|
||||||
disc_min_expert_samples=disc_min_expert_samples,
|
|
||||||
)
|
)
|
||||||
discriminator = amp_state.get("model", None)
|
|
||||||
if discriminator is not None:
|
|
||||||
model_loaded = 1.0
|
|
||||||
|
|
||||||
if amp_state.get("mode") == "trainable" and discriminator is not None:
|
# env.extras["getup_mode"]: 1=front, 0=back
|
||||||
amp_train_active = 1.0
|
getup_mode = env.extras.get("getup_mode", None)
|
||||||
optimizer = amp_state.get("optimizer", None)
|
if not isinstance(getup_mode, torch.Tensor) or getup_mode.shape[0] != env.num_envs:
|
||||||
expert_features = amp_state.get("expert_features", None)
|
getup_mode = torch.zeros(env.num_envs, device=env.device, dtype=torch.long)
|
||||||
amp_state["step"] = int(amp_state.get("step", 0)) + 1
|
env.extras["getup_mode"] = getup_mode
|
||||||
update_interval = max(int(disc_update_interval), 1)
|
getup_mode = getup_mode.to(dtype=torch.long)
|
||||||
batch_size = max(int(disc_batch_size), 32)
|
use_front = getup_mode == 1
|
||||||
|
|
||||||
if optimizer is not None and isinstance(expert_features, torch.Tensor) and amp_state["step"] % update_interval == 0:
|
step_dt = env.step_dt
|
||||||
policy_features = amp_features.detach()
|
phase_time = torch.clamp(env.episode_length_buf * step_dt, min=0.0)
|
||||||
policy_count = policy_features.shape[0]
|
|
||||||
if policy_count > batch_size:
|
|
||||||
policy_ids = torch.randint(0, policy_count, (batch_size,), device=env.device)
|
|
||||||
policy_batch = policy_features.index_select(0, policy_ids)
|
|
||||||
else:
|
|
||||||
policy_batch = policy_features
|
|
||||||
|
|
||||||
expert_count = expert_features.shape[0]
|
front_motion = motion_cache["front_motion"]
|
||||||
expert_ids = torch.randint(0, expert_count, (policy_batch.shape[0],), device=env.device)
|
back_motion = motion_cache["back_motion"]
|
||||||
expert_batch = expert_features.index_select(0, expert_ids)
|
front_idx = torch.clamp((phase_time / sample_dt).to(torch.long), min=0, max=front_motion.shape[0] - 1)
|
||||||
|
back_idx = torch.clamp((phase_time / sample_dt).to(torch.long), min=0, max=back_motion.shape[0] - 1)
|
||||||
|
|
||||||
discriminator.train()
|
target_front = front_motion.index_select(0, front_idx)
|
||||||
optimizer.zero_grad(set_to_none=True)
|
target_back = back_motion.index_select(0, back_idx)
|
||||||
logits_expert = discriminator(expert_batch).squeeze(-1)
|
target_pos = torch.where(use_front.unsqueeze(-1), target_front, target_back)
|
||||||
logits_policy = discriminator(policy_batch).squeeze(-1)
|
|
||||||
loss_expert = nn.functional.binary_cross_entropy_with_logits(logits_expert, torch.ones_like(logits_expert))
|
|
||||||
loss_policy = nn.functional.binary_cross_entropy_with_logits(logits_policy, torch.zeros_like(logits_policy))
|
|
||||||
loss = 0.5 * (loss_expert + loss_policy)
|
|
||||||
loss.backward()
|
|
||||||
nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
|
|
||||||
optimizer.step()
|
|
||||||
discriminator.eval()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
robot_data = env.scene["robot"].data
|
||||||
disc_loss = float(loss.detach().item())
|
current_pos = _safe_tensor(robot_data.joint_pos)
|
||||||
disc_acc_expert = float((torch.sigmoid(logits_expert) > 0.5).float().mean().item())
|
current_vel = _safe_tensor(robot_data.joint_vel)
|
||||||
disc_acc_policy = float((torch.sigmoid(logits_policy) < 0.5).float().mean().item())
|
|
||||||
amp_state["last_loss"] = disc_loss
|
|
||||||
amp_state["last_acc_expert"] = disc_acc_expert
|
|
||||||
amp_state["last_acc_policy"] = disc_acc_policy
|
|
||||||
else:
|
|
||||||
disc_loss = float(amp_state.get("last_loss", 0.0))
|
|
||||||
disc_acc_expert = float(amp_state.get("last_acc_expert", 0.0))
|
|
||||||
disc_acc_policy = float(amp_state.get("last_acc_policy", 0.0))
|
|
||||||
|
|
||||||
if discriminator is not None:
|
if joint_subset == "legs":
|
||||||
discriminator.eval()
|
legs_idx, _ = env.scene["robot"].find_joints(".*(Hip|Knee|Ankle).*")
|
||||||
with torch.no_grad():
|
if len(legs_idx) > 0:
|
||||||
logits = discriminator(amp_features)
|
ids = torch.tensor(legs_idx, device=env.device, dtype=torch.long)
|
||||||
if isinstance(logits, (tuple, list)):
|
current_pos = current_pos.index_select(1, ids)
|
||||||
logits = logits[0]
|
current_vel = current_vel.index_select(1, ids)
|
||||||
if logits.ndim > 1:
|
target_pos = target_pos.index_select(1, ids)
|
||||||
logits = logits.squeeze(-1)
|
elif joint_subset == "core":
|
||||||
logits = _safe_tensor(logits, nan=0.0, pos=20.0, neg=-20.0)
|
core_idx, _ = env.scene["robot"].find_joints(".*(Waist|Hip|Knee|Ankle).*")
|
||||||
amp_score = torch.sigmoid(logit_scale * logits)
|
if len(core_idx) > 0:
|
||||||
amp_score = _safe_tensor(amp_score, nan=0.0, pos=1.0, neg=0.0)
|
ids = torch.tensor(core_idx, device=env.device, dtype=torch.long)
|
||||||
|
current_pos = current_pos.index_select(1, ids)
|
||||||
|
current_vel = current_vel.index_select(1, ids)
|
||||||
|
target_pos = target_pos.index_select(1, ids)
|
||||||
|
|
||||||
amp_reward = _safe_tensor(amp_reward_gain * amp_score, nan=0.0, pos=10.0, neg=0.0)
|
target_vel = torch.zeros_like(target_pos)
|
||||||
|
pos_mse = torch.mean(torch.square(current_pos - target_pos), dim=-1)
|
||||||
|
vel_mse = torch.mean(torch.square(current_vel - target_vel), dim=-1)
|
||||||
|
|
||||||
|
pose_sigma = max(float(pose_sigma), 1e-3)
|
||||||
|
vel_sigma = max(float(vel_sigma), 1e-3)
|
||||||
|
pose_reward = torch.exp(-pos_mse / pose_sigma)
|
||||||
|
vel_reward = torch.exp(-vel_mse / vel_sigma)
|
||||||
|
prior_reward = 0.75 * pose_reward + 0.25 * vel_reward
|
||||||
|
prior_reward = _safe_tensor(prior_reward, nan=0.0, pos=1.0, neg=0.0)
|
||||||
|
|
||||||
log_dict = env.extras.get("log", {})
|
log_dict = env.extras.get("log", {})
|
||||||
if isinstance(log_dict, dict):
|
if isinstance(log_dict, dict):
|
||||||
log_dict["amp_score_mean"] = torch.mean(amp_score).detach().item()
|
log_dict["keyframe_prior_mean"] = torch.mean(prior_reward).detach().item()
|
||||||
log_dict["amp_reward_mean"] = torch.mean(amp_reward).detach().item()
|
log_dict["keyframe_front_ratio"] = torch.mean(use_front.float()).detach().item()
|
||||||
log_dict["amp_model_loaded_mean"] = model_loaded
|
|
||||||
log_dict["amp_train_active_mean"] = amp_train_active
|
|
||||||
log_dict["amp_disc_loss_mean"] = disc_loss
|
|
||||||
log_dict["amp_disc_acc_policy_mean"] = disc_acc_policy
|
|
||||||
log_dict["amp_disc_acc_expert_mean"] = disc_acc_expert
|
|
||||||
env.extras["log"] = log_dict
|
env.extras["log"] = log_dict
|
||||||
|
|
||||||
return internal_reward_scale * amp_reward
|
return internal_reward_scale * prior_reward
|
||||||
|
|
||||||
|
|
||||||
def root_height_obs(env: ManagerBasedRLEnv) -> torch.Tensor:
|
def root_height_obs(env: ManagerBasedRLEnv) -> torch.Tensor:
|
||||||
@@ -365,6 +314,10 @@ def reset_root_state_bimodal_lie_pose(
|
|||||||
-torch.ones(num_resets, device=env.device),
|
-torch.ones(num_resets, device=env.device),
|
||||||
)
|
)
|
||||||
euler_angles[:, 1] = pitch_mag * pitch_sign
|
euler_angles[:, 1] = pitch_mag * pitch_sign
|
||||||
|
# Cache get-up mode for motion priors: pitch<0 => front, pitch>=0 => back.
|
||||||
|
if "getup_mode" not in env.extras or not isinstance(env.extras.get("getup_mode"), torch.Tensor):
|
||||||
|
env.extras["getup_mode"] = torch.zeros(env.num_envs, device=env.device, dtype=torch.long)
|
||||||
|
env.extras["getup_mode"][env_ids] = (pitch_sign < 0.0).to(torch.long)
|
||||||
|
|
||||||
yaw_min, yaw_max = yaw_abs_range
|
yaw_min, yaw_max = yaw_abs_range
|
||||||
yaw_mag = yaw_min + torch.rand(num_resets, device=env.device) * (yaw_max - yaw_min)
|
yaw_mag = yaw_min + torch.rand(num_resets, device=env.device) * (yaw_max - yaw_min)
|
||||||
@@ -778,6 +731,19 @@ class T1GetUpRewardCfg:
|
|||||||
"timer_name": "reward_stable_timer",
|
"timer_name": "reward_stable_timer",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
keyframe_motion_prior = RewTerm(
|
||||||
|
func=keyframe_motion_prior_reward,
|
||||||
|
weight=0.0,
|
||||||
|
params={
|
||||||
|
"front_motion_path": _DEFAULT_FRONT_KEYFRAME_YAML,
|
||||||
|
"back_motion_path": _DEFAULT_BACK_KEYFRAME_YAML,
|
||||||
|
"sample_dt": 0.04,
|
||||||
|
"pose_sigma": 0.42,
|
||||||
|
"vel_sigma": 1.6,
|
||||||
|
"joint_subset": "all",
|
||||||
|
"internal_reward_scale": 1.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
# AMP reward is disabled by default until a discriminator model path is provided.
|
# AMP reward is disabled by default until a discriminator model path is provided.
|
||||||
amp_style_prior = RewTerm(
|
amp_style_prior = RewTerm(
|
||||||
func=amp_style_prior_reward,
|
func=amp_style_prior_reward,
|
||||||
@@ -794,6 +760,7 @@ class T1GetUpRewardCfg:
|
|||||||
"disc_update_interval": 4,
|
"disc_update_interval": 4,
|
||||||
"disc_batch_size": 1024,
|
"disc_batch_size": 1024,
|
||||||
"disc_min_expert_samples": 2048,
|
"disc_min_expert_samples": 2048,
|
||||||
|
"disc_history_steps": 4,
|
||||||
"feature_clip": 8.0,
|
"feature_clip": 8.0,
|
||||||
"logit_scale": 1.0,
|
"logit_scale": 1.0,
|
||||||
"amp_reward_gain": 1.0,
|
"amp_reward_gain": 1.0,
|
||||||
|
|||||||
@@ -42,7 +42,29 @@ parser.add_argument("--amp_disc_lr", type=float, default=3e-4, help="Learning ra
|
|||||||
parser.add_argument("--amp_disc_weight_decay", type=float, default=1e-6, help="Weight decay for AMP discriminator.")
|
parser.add_argument("--amp_disc_weight_decay", type=float, default=1e-6, help="Weight decay for AMP discriminator.")
|
||||||
parser.add_argument("--amp_disc_update_interval", type=int, default=4, help="Train discriminator every N reward calls.")
|
parser.add_argument("--amp_disc_update_interval", type=int, default=4, help="Train discriminator every N reward calls.")
|
||||||
parser.add_argument("--amp_disc_batch_size", type=int, default=1024, help="Discriminator train batch size.")
|
parser.add_argument("--amp_disc_batch_size", type=int, default=1024, help="Discriminator train batch size.")
|
||||||
|
parser.add_argument("--amp_disc_history_steps", type=int, default=4, help="Temporal history steps for AMP discriminator.")
|
||||||
parser.add_argument("--amp_logit_scale", type=float, default=1.0, help="Scale before sigmoid(logits) for AMP score.")
|
parser.add_argument("--amp_logit_scale", type=float, default=1.0, help="Scale before sigmoid(logits) for AMP score.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--amp_from_keyframes",
|
||||||
|
action="store_true",
|
||||||
|
help="Generate AMP expert features from get-up keyframe yaml files and enable online discriminator training.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--amp_keyframe_front",
|
||||||
|
type=str,
|
||||||
|
default=os.path.join(PROJECT_ROOT, "behaviors", "custom", "keyframe", "get_up", "get_up_front.yaml"),
|
||||||
|
help="Front get-up keyframe yaml path for AMP expert generation.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--amp_keyframe_back",
|
||||||
|
type=str,
|
||||||
|
default=os.path.join(PROJECT_ROOT, "behaviors", "custom", "keyframe", "get_up", "get_up_back.yaml"),
|
||||||
|
help="Back get-up keyframe yaml path for AMP expert generation.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--amp_keyframe_dt", type=float, default=0.04, help="Resampling dt for keyframe AMP expert features.")
|
||||||
|
parser.add_argument("--amp_keyframe_repeat", type=int, default=16, help="Repeat count for each keyframe sequence.")
|
||||||
|
parser.add_argument("--keyframe_prior_weight", type=float, default=1.0, help="Weight for keyframe motion prior reward.")
|
||||||
|
parser.add_argument("--disable_keyframe_prior", action="store_true", help="Disable keyframe motion prior reward.")
|
||||||
AppLauncher.add_app_launcher_args(parser)
|
AppLauncher.add_app_launcher_args(parser)
|
||||||
args_cli = parser.parse_args()
|
args_cli = parser.parse_args()
|
||||||
|
|
||||||
@@ -56,7 +78,8 @@ from rl_games.common.algo_observer import DefaultAlgoObserver
|
|||||||
from rl_games.torch_runner import Runner
|
from rl_games.torch_runner import Runner
|
||||||
from rl_games.common import env_configurations, vecenv
|
from rl_games.common import env_configurations, vecenv
|
||||||
|
|
||||||
from rl_game.get_up.config.t1_env_cfg import T1EnvCfg
|
from rl_game.get_up.amp.amp_motion import build_amp_expert_features_from_getup_keyframes
|
||||||
|
from rl_game.get_up.config.t1_env_cfg import T1EnvCfg, T1_JOINT_NAMES
|
||||||
|
|
||||||
|
|
||||||
class T1MetricObserver(DefaultAlgoObserver):
|
class T1MetricObserver(DefaultAlgoObserver):
|
||||||
@@ -77,6 +100,8 @@ class T1MetricObserver(DefaultAlgoObserver):
|
|||||||
"amp_disc_loss_mean",
|
"amp_disc_loss_mean",
|
||||||
"amp_disc_acc_policy_mean",
|
"amp_disc_acc_policy_mean",
|
||||||
"amp_disc_acc_expert_mean",
|
"amp_disc_acc_expert_mean",
|
||||||
|
"keyframe_prior_mean",
|
||||||
|
"keyframe_front_ratio",
|
||||||
)
|
)
|
||||||
self._metric_sums: dict[str, float] = {}
|
self._metric_sums: dict[str, float] = {}
|
||||||
self._metric_counts: dict[str, int] = {}
|
self._metric_counts: dict[str, int] = {}
|
||||||
@@ -202,10 +227,33 @@ def main():
|
|||||||
task_id = "Isaac-T1-GetUp-v0"
|
task_id = "Isaac-T1-GetUp-v0"
|
||||||
env_cfg = T1EnvCfg()
|
env_cfg = T1EnvCfg()
|
||||||
|
|
||||||
|
if args_cli.disable_keyframe_prior:
|
||||||
|
env_cfg.rewards.keyframe_motion_prior.weight = 0.0
|
||||||
|
print("[INFO]: keyframe motion prior disabled")
|
||||||
|
else:
|
||||||
|
env_cfg.rewards.keyframe_motion_prior.weight = float(args_cli.keyframe_prior_weight)
|
||||||
|
print(f"[INFO]: keyframe motion prior weight={env_cfg.rewards.keyframe_motion_prior.weight:.3f}")
|
||||||
|
|
||||||
|
if args_cli.amp_from_keyframes:
|
||||||
|
auto_feature_path = os.path.join(os.path.dirname(__file__), "logs", "amp", "expert_features_from_keyframes.pt")
|
||||||
|
generated_path, feature_shape = build_amp_expert_features_from_getup_keyframes(
|
||||||
|
front_yaml_path=args_cli.amp_keyframe_front,
|
||||||
|
back_yaml_path=args_cli.amp_keyframe_back,
|
||||||
|
joint_names=T1_JOINT_NAMES,
|
||||||
|
output_path=auto_feature_path,
|
||||||
|
sample_dt=float(args_cli.amp_keyframe_dt),
|
||||||
|
repeat_count=int(args_cli.amp_keyframe_repeat),
|
||||||
|
)
|
||||||
|
args_cli.amp_expert_features = generated_path
|
||||||
|
args_cli.amp_train_discriminator = True
|
||||||
|
print(f"[INFO]: AMP expert features generated at {generated_path}, shape={feature_shape}")
|
||||||
|
|
||||||
amp_cfg = env_cfg.rewards.amp_style_prior
|
amp_cfg = env_cfg.rewards.amp_style_prior
|
||||||
amp_cfg.params["logit_scale"] = float(args_cli.amp_logit_scale)
|
amp_cfg.params["logit_scale"] = float(args_cli.amp_logit_scale)
|
||||||
if args_cli.amp_train_discriminator:
|
if args_cli.amp_train_discriminator:
|
||||||
expert_path = os.path.abspath(os.path.expanduser(args_cli.amp_expert_features)) if args_cli.amp_expert_features else ""
|
expert_path = os.path.abspath(os.path.expanduser(args_cli.amp_expert_features)) if args_cli.amp_expert_features else ""
|
||||||
|
if not expert_path:
|
||||||
|
raise ValueError("--amp_train_discriminator requires --amp_expert_features or --amp_from_keyframes.")
|
||||||
amp_cfg.weight = float(args_cli.amp_reward_weight)
|
amp_cfg.weight = float(args_cli.amp_reward_weight)
|
||||||
amp_cfg.params["amp_train_enabled"] = True
|
amp_cfg.params["amp_train_enabled"] = True
|
||||||
amp_cfg.params["amp_enabled"] = False
|
amp_cfg.params["amp_enabled"] = False
|
||||||
@@ -216,8 +264,10 @@ def main():
|
|||||||
amp_cfg.params["disc_weight_decay"] = float(args_cli.amp_disc_weight_decay)
|
amp_cfg.params["disc_weight_decay"] = float(args_cli.amp_disc_weight_decay)
|
||||||
amp_cfg.params["disc_update_interval"] = int(args_cli.amp_disc_update_interval)
|
amp_cfg.params["disc_update_interval"] = int(args_cli.amp_disc_update_interval)
|
||||||
amp_cfg.params["disc_batch_size"] = int(args_cli.amp_disc_batch_size)
|
amp_cfg.params["disc_batch_size"] = int(args_cli.amp_disc_batch_size)
|
||||||
print(f"[INFO]: AMP online discriminator enabled, expert_features={expert_path or '<missing>'}")
|
amp_cfg.params["disc_history_steps"] = int(args_cli.amp_disc_history_steps)
|
||||||
|
print(f"[INFO]: AMP online discriminator enabled, expert_features={expert_path}")
|
||||||
print(f"[INFO]: AMP reward weight={amp_cfg.weight:.3f}")
|
print(f"[INFO]: AMP reward weight={amp_cfg.weight:.3f}")
|
||||||
|
print(f"[INFO]: AMP discriminator history_steps={amp_cfg.params['disc_history_steps']}")
|
||||||
elif args_cli.amp_model:
|
elif args_cli.amp_model:
|
||||||
amp_model_path = os.path.abspath(os.path.expanduser(args_cli.amp_model))
|
amp_model_path = os.path.abspath(os.path.expanduser(args_cli.amp_model))
|
||||||
amp_cfg.weight = float(args_cli.amp_reward_weight)
|
amp_cfg.weight = float(args_cli.amp_reward_weight)
|
||||||
|
|||||||
Reference in New Issue
Block a user