Add AMP get-up pipeline with sequence discriminator and git-sourced expert data
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
import isaaclab.envs.mdp as mdp
|
||||
from isaaclab.envs import ManagerBasedRLEnvCfg, ManagerBasedRLEnv
|
||||
from isaaclab.managers import ObservationGroupCfg as ObsGroup
|
||||
@@ -11,8 +11,13 @@ from isaaclab.managers import EventTermCfg as EventTerm
|
||||
from isaaclab.envs.mdp import JointPositionActionCfg
|
||||
from isaaclab.managers import SceneEntityCfg
|
||||
from isaaclab.utils import configclass
|
||||
from rl_game.get_up.amp.amp_rewards import amp_style_prior_reward
|
||||
from rl_game.get_up.env.t1_env import T1SceneCfg
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||
_DEFAULT_FRONT_KEYFRAME_YAML = str(_PROJECT_ROOT / "behaviors" / "custom" / "keyframe" / "get_up" / "get_up_front.yaml")
|
||||
_DEFAULT_BACK_KEYFRAME_YAML = str(_PROJECT_ROOT / "behaviors" / "custom" / "keyframe" / "get_up" / "get_up_back.yaml")
|
||||
|
||||
def _contact_force_z(env: ManagerBasedRLEnv, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
|
||||
"""Sum positive vertical contact force on selected bodies."""
|
||||
sensor = env.scene.sensors.get(sensor_cfg.name)
|
||||
@@ -30,272 +35,216 @@ def _safe_tensor(x: torch.Tensor, nan: float = 0.0, pos: float = 1e3, neg: float
|
||||
return torch.nan_to_num(x, nan=nan, posinf=pos, neginf=neg)
|
||||
|
||||
|
||||
class AMPDiscriminator(nn.Module):
|
||||
"""Lightweight discriminator used by online AMP updates."""
|
||||
|
||||
def __init__(self, input_dim: int, hidden_dims: tuple[int, ...]):
|
||||
super().__init__()
|
||||
layers: list[nn.Module] = []
|
||||
in_dim = input_dim
|
||||
for h_dim in hidden_dims:
|
||||
layers.append(nn.Linear(in_dim, h_dim))
|
||||
layers.append(nn.LayerNorm(h_dim))
|
||||
layers.append(nn.SiLU())
|
||||
in_dim = h_dim
|
||||
layers.append(nn.Linear(in_dim, 1))
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x)
|
||||
def _resolve_path(path_like: str) -> Path:
|
||||
path = Path(path_like).expanduser()
|
||||
if path.is_absolute():
|
||||
return path
|
||||
return (_PROJECT_ROOT / path).resolve()
|
||||
|
||||
|
||||
def _extract_tensor_from_amp_payload(payload) -> torch.Tensor | None:
|
||||
if isinstance(payload, torch.Tensor):
|
||||
return payload
|
||||
if isinstance(payload, dict):
|
||||
for key in ("expert_features", "features", "obs"):
|
||||
value = payload.get(key, None)
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value
|
||||
return None
|
||||
def _interpolate_keyframes(
|
||||
keyframes: list[dict],
|
||||
sample_dt: float,
|
||||
) -> tuple[torch.Tensor, list[str]]:
|
||||
"""Interpolate sparse keyframes into dense [T, M] motor angle table in radians."""
|
||||
if len(keyframes) == 0:
|
||||
raise ValueError("No keyframes in motion yaml")
|
||||
|
||||
motor_names: set[str] = set()
|
||||
for frame in keyframes:
|
||||
motors = frame.get("motor_positions", {})
|
||||
if isinstance(motors, dict):
|
||||
motor_names.update(motors.keys())
|
||||
names = sorted(motor_names)
|
||||
if len(names) == 0:
|
||||
raise ValueError("No motor_positions found in keyframes")
|
||||
|
||||
frame_count = len(keyframes)
|
||||
times = torch.zeros(frame_count, dtype=torch.float32)
|
||||
values = torch.full((frame_count, len(names)), float("nan"), dtype=torch.float32)
|
||||
name_to_idx = {name: idx for idx, name in enumerate(names)}
|
||||
|
||||
elapsed = 0.0
|
||||
for i, frame in enumerate(keyframes):
|
||||
elapsed += max(float(frame.get("delta", 0.1)), 1e-4)
|
||||
times[i] = elapsed
|
||||
motors = frame.get("motor_positions", {})
|
||||
if not isinstance(motors, dict):
|
||||
continue
|
||||
for motor_name, motor_deg in motors.items():
|
||||
if motor_name not in name_to_idx:
|
||||
continue
|
||||
values[i, name_to_idx[motor_name]] = float(motor_deg) * (torch.pi / 180.0)
|
||||
|
||||
values[0] = torch.nan_to_num(values[0], nan=0.0)
|
||||
for i in range(1, frame_count):
|
||||
values[i] = torch.where(torch.isnan(values[i]), values[i - 1], values[i])
|
||||
|
||||
sample_dt = max(float(sample_dt), 1e-3)
|
||||
sample_times = torch.arange(0.0, float(times[-1]) + 1e-6, sample_dt, dtype=torch.float32)
|
||||
sample_times[0] = max(sample_times[0], 1e-6)
|
||||
sample_times = torch.clamp(sample_times, min=times[0], max=times[-1])
|
||||
|
||||
upper = torch.searchsorted(times, sample_times, right=True)
|
||||
upper = torch.clamp(upper, min=1, max=frame_count - 1)
|
||||
lower = upper - 1
|
||||
|
||||
t0 = times.index_select(0, lower)
|
||||
t1 = times.index_select(0, upper)
|
||||
v0 = values.index_select(0, lower)
|
||||
v1 = values.index_select(0, upper)
|
||||
alpha = ((sample_times - t0) / (t1 - t0 + 1e-6)).unsqueeze(-1)
|
||||
interp = v0 + alpha * (v1 - v0)
|
||||
return interp, names
|
||||
|
||||
|
||||
def _load_amp_expert_features(
|
||||
expert_features_path: str,
|
||||
device: str,
|
||||
feature_dim: int,
|
||||
fallback_samples: int,
|
||||
) -> torch.Tensor | None:
|
||||
"""Load expert AMP features. Returns None when file is unavailable."""
|
||||
if not expert_features_path:
|
||||
return None
|
||||
p = Path(expert_features_path).expanduser()
|
||||
if not p.is_file():
|
||||
return None
|
||||
try:
|
||||
payload = torch.load(str(p), map_location="cpu")
|
||||
except Exception:
|
||||
return None
|
||||
expert = _extract_tensor_from_amp_payload(payload)
|
||||
if expert is None:
|
||||
return None
|
||||
expert = expert.float()
|
||||
if expert.ndim == 1:
|
||||
expert = expert.unsqueeze(0)
|
||||
if expert.ndim != 2:
|
||||
return None
|
||||
if expert.shape[1] != feature_dim:
|
||||
return None
|
||||
if expert.shape[0] < 2:
|
||||
return None
|
||||
if expert.shape[0] < fallback_samples:
|
||||
reps = int((fallback_samples + expert.shape[0] - 1) // expert.shape[0])
|
||||
expert = expert.repeat(reps, 1)
|
||||
return expert.to(device=device)
|
||||
def _motion_table_from_yaml(
|
||||
yaml_path: str,
|
||||
joint_names: list[str],
|
||||
sample_dt: float,
|
||||
) -> tuple[torch.Tensor, float]:
|
||||
"""
|
||||
Build [T, J] target joint motion from get-up keyframes.
|
||||
Unknown joints default to 0.0 (neutral).
|
||||
"""
|
||||
path = _resolve_path(yaml_path)
|
||||
if not path.is_file():
|
||||
raise FileNotFoundError(f"Motion yaml not found: {path}")
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
payload = yaml.safe_load(f) or {}
|
||||
keyframes = payload.get("keyframes", [])
|
||||
if not isinstance(keyframes, list):
|
||||
raise ValueError(f"Invalid keyframes in yaml: {path}")
|
||||
|
||||
motor_table, motor_names = _interpolate_keyframes(keyframes, sample_dt=sample_dt)
|
||||
motor_to_idx = {name: idx for idx, name in enumerate(motor_names)}
|
||||
joint_table = torch.zeros((motor_table.shape[0], len(joint_names)), dtype=torch.float32)
|
||||
|
||||
sign_flip_bases = {"Shoulder_Roll", "Elbow_Yaw", "Hip_Roll", "Hip_Yaw", "Ankle_Roll"}
|
||||
for j, joint_name in enumerate(joint_names):
|
||||
alias = "Head_yaw" if joint_name == "AAHead_yaw" else joint_name
|
||||
base = alias.replace("Left_", "").replace("Right_", "")
|
||||
src = None
|
||||
if alias in motor_to_idx:
|
||||
src = alias
|
||||
elif base in motor_to_idx:
|
||||
src = base
|
||||
if src is None:
|
||||
continue
|
||||
sign = -1.0 if alias.startswith("Right_") and base in sign_flip_bases else 1.0
|
||||
joint_table[:, j] = sign * motor_table[:, motor_to_idx[src]]
|
||||
|
||||
duration_s = max(float(motor_table.shape[0] - 1) * sample_dt, sample_dt)
|
||||
return joint_table, duration_s
|
||||
|
||||
|
||||
def _get_amp_state(
|
||||
def _get_keyframe_motion_cache(
|
||||
env: ManagerBasedRLEnv,
|
||||
amp_enabled: bool,
|
||||
amp_model_path: str,
|
||||
amp_train_enabled: bool,
|
||||
amp_expert_features_path: str,
|
||||
feature_dim: int,
|
||||
disc_hidden_dim: int,
|
||||
disc_hidden_layers: int,
|
||||
disc_lr: float,
|
||||
disc_weight_decay: float,
|
||||
disc_min_expert_samples: int,
|
||||
front_motion_path: str,
|
||||
back_motion_path: str,
|
||||
sample_dt: float,
|
||||
):
|
||||
"""Get cached AMP state (frozen jit or trainable discriminator)."""
|
||||
cache_key = "amp_state_cache"
|
||||
hidden_layers = max(int(disc_hidden_layers), 1)
|
||||
hidden_dim = max(int(disc_hidden_dim), 16)
|
||||
state_sig = (
|
||||
bool(amp_enabled),
|
||||
str(amp_model_path),
|
||||
bool(amp_train_enabled),
|
||||
str(amp_expert_features_path),
|
||||
int(feature_dim),
|
||||
hidden_dim,
|
||||
hidden_layers,
|
||||
float(disc_lr),
|
||||
float(disc_weight_decay),
|
||||
)
|
||||
"""Cache interpolated front/back get-up motion priors on env device."""
|
||||
cache_key = "getup_keyframe_motion_cache"
|
||||
sig = (str(front_motion_path), str(back_motion_path), float(sample_dt))
|
||||
cached = env.extras.get(cache_key, None)
|
||||
if isinstance(cached, dict) and cached.get("sig") == state_sig:
|
||||
if isinstance(cached, dict) and cached.get("sig") == sig:
|
||||
return cached
|
||||
|
||||
state = {
|
||||
"sig": state_sig,
|
||||
"mode": "disabled",
|
||||
"model": None,
|
||||
"optimizer": None,
|
||||
"expert_features": None,
|
||||
"step": 0,
|
||||
"last_loss": 0.0,
|
||||
"last_acc_policy": 0.0,
|
||||
"last_acc_expert": 0.0,
|
||||
front_table, front_duration = _motion_table_from_yaml(front_motion_path, T1_JOINT_NAMES, sample_dt)
|
||||
back_table, back_duration = _motion_table_from_yaml(back_motion_path, T1_JOINT_NAMES, sample_dt)
|
||||
|
||||
cache = {
|
||||
"sig": sig,
|
||||
"sample_dt": float(sample_dt),
|
||||
"front_motion": front_table.to(device=env.device),
|
||||
"back_motion": back_table.to(device=env.device),
|
||||
"front_duration": float(front_duration),
|
||||
"back_duration": float(back_duration),
|
||||
}
|
||||
|
||||
if amp_train_enabled:
|
||||
expert_features = _load_amp_expert_features(
|
||||
amp_expert_features_path,
|
||||
device=env.device,
|
||||
feature_dim=feature_dim,
|
||||
fallback_samples=max(disc_min_expert_samples, 512),
|
||||
)
|
||||
if expert_features is not None:
|
||||
model = AMPDiscriminator(input_dim=feature_dim, hidden_dims=tuple([hidden_dim] * hidden_layers)).to(env.device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=float(disc_lr), weight_decay=float(disc_weight_decay))
|
||||
state["mode"] = "trainable"
|
||||
state["model"] = model
|
||||
state["optimizer"] = optimizer
|
||||
state["expert_features"] = expert_features
|
||||
elif amp_enabled and amp_model_path:
|
||||
model_path = Path(amp_model_path).expanduser()
|
||||
if model_path.is_file():
|
||||
try:
|
||||
model = torch.jit.load(str(model_path), map_location=env.device)
|
||||
model.eval()
|
||||
state["mode"] = "jit"
|
||||
state["model"] = model
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
env.extras[cache_key] = state
|
||||
return state
|
||||
env.extras[cache_key] = cache
|
||||
return cache
|
||||
|
||||
|
||||
def _build_amp_features(env: ManagerBasedRLEnv, feature_clip: float = 8.0) -> torch.Tensor:
|
||||
"""Build AMP-style discriminator features from robot kinematics."""
|
||||
robot_data = env.scene["robot"].data
|
||||
joint_pos_rel = robot_data.joint_pos - robot_data.default_joint_pos
|
||||
joint_vel = robot_data.joint_vel
|
||||
root_lin_vel = robot_data.root_lin_vel_w
|
||||
root_ang_vel = robot_data.root_ang_vel_w
|
||||
projected_gravity = robot_data.projected_gravity_b
|
||||
amp_features = torch.cat([joint_pos_rel, joint_vel, root_lin_vel, root_ang_vel, projected_gravity], dim=-1)
|
||||
amp_features = _safe_tensor(amp_features, nan=0.0, pos=feature_clip, neg=-feature_clip)
|
||||
return torch.clamp(amp_features, min=-feature_clip, max=feature_clip)
|
||||
|
||||
|
||||
def amp_style_prior_reward(
|
||||
def keyframe_motion_prior_reward(
|
||||
env: ManagerBasedRLEnv,
|
||||
amp_enabled: bool = False,
|
||||
amp_model_path: str = "",
|
||||
amp_train_enabled: bool = False,
|
||||
amp_expert_features_path: str = "",
|
||||
disc_hidden_dim: int = 256,
|
||||
disc_hidden_layers: int = 2,
|
||||
disc_lr: float = 3e-4,
|
||||
disc_weight_decay: float = 1e-6,
|
||||
disc_update_interval: int = 4,
|
||||
disc_batch_size: int = 1024,
|
||||
disc_min_expert_samples: int = 2048,
|
||||
feature_clip: float = 8.0,
|
||||
logit_scale: float = 1.0,
|
||||
amp_reward_gain: float = 1.0,
|
||||
front_motion_path: str = _DEFAULT_FRONT_KEYFRAME_YAML,
|
||||
back_motion_path: str = _DEFAULT_BACK_KEYFRAME_YAML,
|
||||
sample_dt: float = 0.04,
|
||||
pose_sigma: float = 0.42,
|
||||
vel_sigma: float = 1.6,
|
||||
joint_subset: str = "all",
|
||||
internal_reward_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""AMP style prior reward with optional online discriminator training."""
|
||||
zeros = torch.zeros(env.num_envs, device=env.device)
|
||||
amp_score = zeros
|
||||
model_loaded = 0.0
|
||||
amp_train_active = 0.0
|
||||
disc_loss = 0.0
|
||||
disc_acc_policy = 0.0
|
||||
disc_acc_expert = 0.0
|
||||
|
||||
amp_features = _build_amp_features(env, feature_clip=feature_clip)
|
||||
amp_state = _get_amp_state(
|
||||
"""
|
||||
DeepMimic-style dense reward from keyframe get-up motions.
|
||||
- mode=1 uses front sequence
|
||||
- mode=0 uses back sequence
|
||||
"""
|
||||
motion_cache = _get_keyframe_motion_cache(
|
||||
env=env,
|
||||
amp_enabled=amp_enabled,
|
||||
amp_model_path=amp_model_path,
|
||||
amp_train_enabled=amp_train_enabled,
|
||||
amp_expert_features_path=amp_expert_features_path,
|
||||
feature_dim=int(amp_features.shape[-1]),
|
||||
disc_hidden_dim=disc_hidden_dim,
|
||||
disc_hidden_layers=disc_hidden_layers,
|
||||
disc_lr=disc_lr,
|
||||
disc_weight_decay=disc_weight_decay,
|
||||
disc_min_expert_samples=disc_min_expert_samples,
|
||||
front_motion_path=front_motion_path,
|
||||
back_motion_path=back_motion_path,
|
||||
sample_dt=sample_dt,
|
||||
)
|
||||
discriminator = amp_state.get("model", None)
|
||||
if discriminator is not None:
|
||||
model_loaded = 1.0
|
||||
|
||||
if amp_state.get("mode") == "trainable" and discriminator is not None:
|
||||
amp_train_active = 1.0
|
||||
optimizer = amp_state.get("optimizer", None)
|
||||
expert_features = amp_state.get("expert_features", None)
|
||||
amp_state["step"] = int(amp_state.get("step", 0)) + 1
|
||||
update_interval = max(int(disc_update_interval), 1)
|
||||
batch_size = max(int(disc_batch_size), 32)
|
||||
# env.extras["getup_mode"]: 1=front, 0=back
|
||||
getup_mode = env.extras.get("getup_mode", None)
|
||||
if not isinstance(getup_mode, torch.Tensor) or getup_mode.shape[0] != env.num_envs:
|
||||
getup_mode = torch.zeros(env.num_envs, device=env.device, dtype=torch.long)
|
||||
env.extras["getup_mode"] = getup_mode
|
||||
getup_mode = getup_mode.to(dtype=torch.long)
|
||||
use_front = getup_mode == 1
|
||||
|
||||
if optimizer is not None and isinstance(expert_features, torch.Tensor) and amp_state["step"] % update_interval == 0:
|
||||
policy_features = amp_features.detach()
|
||||
policy_count = policy_features.shape[0]
|
||||
if policy_count > batch_size:
|
||||
policy_ids = torch.randint(0, policy_count, (batch_size,), device=env.device)
|
||||
policy_batch = policy_features.index_select(0, policy_ids)
|
||||
else:
|
||||
policy_batch = policy_features
|
||||
step_dt = env.step_dt
|
||||
phase_time = torch.clamp(env.episode_length_buf * step_dt, min=0.0)
|
||||
|
||||
expert_count = expert_features.shape[0]
|
||||
expert_ids = torch.randint(0, expert_count, (policy_batch.shape[0],), device=env.device)
|
||||
expert_batch = expert_features.index_select(0, expert_ids)
|
||||
front_motion = motion_cache["front_motion"]
|
||||
back_motion = motion_cache["back_motion"]
|
||||
front_idx = torch.clamp((phase_time / sample_dt).to(torch.long), min=0, max=front_motion.shape[0] - 1)
|
||||
back_idx = torch.clamp((phase_time / sample_dt).to(torch.long), min=0, max=back_motion.shape[0] - 1)
|
||||
|
||||
discriminator.train()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
logits_expert = discriminator(expert_batch).squeeze(-1)
|
||||
logits_policy = discriminator(policy_batch).squeeze(-1)
|
||||
loss_expert = nn.functional.binary_cross_entropy_with_logits(logits_expert, torch.ones_like(logits_expert))
|
||||
loss_policy = nn.functional.binary_cross_entropy_with_logits(logits_policy, torch.zeros_like(logits_policy))
|
||||
loss = 0.5 * (loss_expert + loss_policy)
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
|
||||
optimizer.step()
|
||||
discriminator.eval()
|
||||
target_front = front_motion.index_select(0, front_idx)
|
||||
target_back = back_motion.index_select(0, back_idx)
|
||||
target_pos = torch.where(use_front.unsqueeze(-1), target_front, target_back)
|
||||
|
||||
with torch.no_grad():
|
||||
disc_loss = float(loss.detach().item())
|
||||
disc_acc_expert = float((torch.sigmoid(logits_expert) > 0.5).float().mean().item())
|
||||
disc_acc_policy = float((torch.sigmoid(logits_policy) < 0.5).float().mean().item())
|
||||
amp_state["last_loss"] = disc_loss
|
||||
amp_state["last_acc_expert"] = disc_acc_expert
|
||||
amp_state["last_acc_policy"] = disc_acc_policy
|
||||
else:
|
||||
disc_loss = float(amp_state.get("last_loss", 0.0))
|
||||
disc_acc_expert = float(amp_state.get("last_acc_expert", 0.0))
|
||||
disc_acc_policy = float(amp_state.get("last_acc_policy", 0.0))
|
||||
robot_data = env.scene["robot"].data
|
||||
current_pos = _safe_tensor(robot_data.joint_pos)
|
||||
current_vel = _safe_tensor(robot_data.joint_vel)
|
||||
|
||||
if discriminator is not None:
|
||||
discriminator.eval()
|
||||
with torch.no_grad():
|
||||
logits = discriminator(amp_features)
|
||||
if isinstance(logits, (tuple, list)):
|
||||
logits = logits[0]
|
||||
if logits.ndim > 1:
|
||||
logits = logits.squeeze(-1)
|
||||
logits = _safe_tensor(logits, nan=0.0, pos=20.0, neg=-20.0)
|
||||
amp_score = torch.sigmoid(logit_scale * logits)
|
||||
amp_score = _safe_tensor(amp_score, nan=0.0, pos=1.0, neg=0.0)
|
||||
if joint_subset == "legs":
|
||||
legs_idx, _ = env.scene["robot"].find_joints(".*(Hip|Knee|Ankle).*")
|
||||
if len(legs_idx) > 0:
|
||||
ids = torch.tensor(legs_idx, device=env.device, dtype=torch.long)
|
||||
current_pos = current_pos.index_select(1, ids)
|
||||
current_vel = current_vel.index_select(1, ids)
|
||||
target_pos = target_pos.index_select(1, ids)
|
||||
elif joint_subset == "core":
|
||||
core_idx, _ = env.scene["robot"].find_joints(".*(Waist|Hip|Knee|Ankle).*")
|
||||
if len(core_idx) > 0:
|
||||
ids = torch.tensor(core_idx, device=env.device, dtype=torch.long)
|
||||
current_pos = current_pos.index_select(1, ids)
|
||||
current_vel = current_vel.index_select(1, ids)
|
||||
target_pos = target_pos.index_select(1, ids)
|
||||
|
||||
amp_reward = _safe_tensor(amp_reward_gain * amp_score, nan=0.0, pos=10.0, neg=0.0)
|
||||
target_vel = torch.zeros_like(target_pos)
|
||||
pos_mse = torch.mean(torch.square(current_pos - target_pos), dim=-1)
|
||||
vel_mse = torch.mean(torch.square(current_vel - target_vel), dim=-1)
|
||||
|
||||
pose_sigma = max(float(pose_sigma), 1e-3)
|
||||
vel_sigma = max(float(vel_sigma), 1e-3)
|
||||
pose_reward = torch.exp(-pos_mse / pose_sigma)
|
||||
vel_reward = torch.exp(-vel_mse / vel_sigma)
|
||||
prior_reward = 0.75 * pose_reward + 0.25 * vel_reward
|
||||
prior_reward = _safe_tensor(prior_reward, nan=0.0, pos=1.0, neg=0.0)
|
||||
|
||||
log_dict = env.extras.get("log", {})
|
||||
if isinstance(log_dict, dict):
|
||||
log_dict["amp_score_mean"] = torch.mean(amp_score).detach().item()
|
||||
log_dict["amp_reward_mean"] = torch.mean(amp_reward).detach().item()
|
||||
log_dict["amp_model_loaded_mean"] = model_loaded
|
||||
log_dict["amp_train_active_mean"] = amp_train_active
|
||||
log_dict["amp_disc_loss_mean"] = disc_loss
|
||||
log_dict["amp_disc_acc_policy_mean"] = disc_acc_policy
|
||||
log_dict["amp_disc_acc_expert_mean"] = disc_acc_expert
|
||||
log_dict["keyframe_prior_mean"] = torch.mean(prior_reward).detach().item()
|
||||
log_dict["keyframe_front_ratio"] = torch.mean(use_front.float()).detach().item()
|
||||
env.extras["log"] = log_dict
|
||||
|
||||
return internal_reward_scale * amp_reward
|
||||
return internal_reward_scale * prior_reward
|
||||
|
||||
|
||||
def root_height_obs(env: ManagerBasedRLEnv) -> torch.Tensor:
|
||||
@@ -365,6 +314,10 @@ def reset_root_state_bimodal_lie_pose(
|
||||
-torch.ones(num_resets, device=env.device),
|
||||
)
|
||||
euler_angles[:, 1] = pitch_mag * pitch_sign
|
||||
# Cache get-up mode for motion priors: pitch<0 => front, pitch>=0 => back.
|
||||
if "getup_mode" not in env.extras or not isinstance(env.extras.get("getup_mode"), torch.Tensor):
|
||||
env.extras["getup_mode"] = torch.zeros(env.num_envs, device=env.device, dtype=torch.long)
|
||||
env.extras["getup_mode"][env_ids] = (pitch_sign < 0.0).to(torch.long)
|
||||
|
||||
yaw_min, yaw_max = yaw_abs_range
|
||||
yaw_mag = yaw_min + torch.rand(num_resets, device=env.device) * (yaw_max - yaw_min)
|
||||
@@ -778,6 +731,19 @@ class T1GetUpRewardCfg:
|
||||
"timer_name": "reward_stable_timer",
|
||||
},
|
||||
)
|
||||
keyframe_motion_prior = RewTerm(
|
||||
func=keyframe_motion_prior_reward,
|
||||
weight=0.0,
|
||||
params={
|
||||
"front_motion_path": _DEFAULT_FRONT_KEYFRAME_YAML,
|
||||
"back_motion_path": _DEFAULT_BACK_KEYFRAME_YAML,
|
||||
"sample_dt": 0.04,
|
||||
"pose_sigma": 0.42,
|
||||
"vel_sigma": 1.6,
|
||||
"joint_subset": "all",
|
||||
"internal_reward_scale": 1.0,
|
||||
},
|
||||
)
|
||||
# AMP reward is disabled by default until a discriminator model path is provided.
|
||||
amp_style_prior = RewTerm(
|
||||
func=amp_style_prior_reward,
|
||||
@@ -794,6 +760,7 @@ class T1GetUpRewardCfg:
|
||||
"disc_update_interval": 4,
|
||||
"disc_batch_size": 1024,
|
||||
"disc_min_expert_samples": 2048,
|
||||
"disc_history_steps": 4,
|
||||
"feature_clip": 8.0,
|
||||
"logit_scale": 1.0,
|
||||
"amp_reward_gain": 1.0,
|
||||
|
||||
Reference in New Issue
Block a user