187 lines
6.5 KiB
Python
187 lines
6.5 KiB
Python
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import torch
|
|
import yaml
|
|
|
|
|
|
# Side-less joint names (no ``Left_``/``Right_`` prefix) whose motor value is
# negated when it is mapped onto a ``Right_``-prefixed joint.
RIGHT_JOINT_SIGN_FLIP = {
    "Shoulder_Roll",
    "Elbow_Yaw",
    "Hip_Roll",
    "Hip_Yaw",
    "Ankle_Roll",
}
|
|
|
|
# Joint name -> name used when looking the joint up among the keyframe motors.
# Joints that are not listed here are looked up under their own name.
JOINT_NAME_ALIAS = {
    "AAHead_yaw": "Head_yaw",
}
|
|
|
|
|
|
def _safe_load_yaml(path: Path) -> dict[str, Any]:
    """Parse the YAML file at *path* and return its top-level mapping.

    A document that parses to a falsy value (for example an empty file) is
    treated as an empty mapping.

    Raises:
        ValueError: If the parsed document is truthy but not a ``dict``.
    """
    with path.open("r", encoding="utf-8") as stream:
        document = yaml.safe_load(stream)

    # Empty / falsy documents collapse to an empty mapping.
    if not document:
        document = {}

    if isinstance(document, dict):
        return document
    raise ValueError(f"Invalid keyframe yaml: {path}")
|
|
|
|
|
|
def _build_interpolated_motor_table(
    keyframes: list[dict[str, Any]],
    sample_dt: float,
) -> tuple[torch.Tensor, list[str], torch.Tensor]:
    """Resample keyframed motor positions onto a uniform time grid.

    Each keyframe's ``delta`` (default 0.1, floored at 1e-4) is accumulated to
    give that keyframe's timestamp. Its ``motor_positions`` values are read as
    degrees and converted to radians. A motor missing from a keyframe keeps
    the value of the previous keyframe, or 0.0 if it has not appeared yet.
    The table is then linearly interpolated at multiples of ``sample_dt``;
    samples that fall before the first keyframe timestamp are pinned to it,
    so they reproduce the first keyframe.

    Args:
        keyframes: Keyframe mappings with optional ``delta`` and
            ``motor_positions`` entries. A ``motor_positions`` value that is
            not a dict is ignored.
        sample_dt: Spacing of the sampling grid; floored at 1e-3.

    Returns:
        A tuple ``(interpolated, names, sample_times)`` where ``interpolated``
        is a float32 tensor of shape ``(num_samples, num_motors)`` in radians,
        ``names`` is the sorted list of motor names (column order), and
        ``sample_times`` holds the sample times after clamping to the
        first/last keyframe timestamps.

    Raises:
        ValueError: If ``keyframes`` is empty or no keyframe carries any
            motor position.
    """
    if not keyframes:
        raise ValueError("No keyframes in yaml")

    # Collect every motor that appears in any keyframe; sorted order fixes the
    # column layout of the table.
    motor_names: set[str] = set()
    for frame in keyframes:
        motors = frame.get("motor_positions", {})
        if isinstance(motors, dict):
            motor_names.update(motors)
    names = sorted(motor_names)
    if not names:
        raise ValueError("No motor_positions found in keyframes")

    frame_count = len(keyframes)
    name_to_idx = {name: idx for idx, name in enumerate(names)}
    times = torch.zeros(frame_count, dtype=torch.float32)
    # NaN marks "not specified in this keyframe" until the forward fill below.
    values = torch.full((frame_count, len(names)), float("nan"), dtype=torch.float32)

    elapsed = 0.0
    for i, frame in enumerate(keyframes):
        # Flooring the step keeps the timestamps strictly increasing.
        elapsed += max(float(frame.get("delta", 0.1)), 1e-4)
        times[i] = elapsed
        motors = frame.get("motor_positions", {})
        if not isinstance(motors, dict):
            continue
        for motor_name, motor_deg in motors.items():
            values[i, name_to_idx[motor_name]] = float(motor_deg) * (torch.pi / 180.0)

    # Forward-fill unspecified motors; anything missing in the first keyframe
    # starts at zero.
    values[0] = torch.nan_to_num(values[0], nan=0.0)
    for i in range(1, frame_count):
        values[i] = torch.where(torch.isnan(values[i]), values[i - 1], values[i])

    sample_dt = max(float(sample_dt), 1e-3)
    sample_times = torch.arange(0.0, float(times[-1]) + 1e-6, sample_dt, dtype=torch.float32)
    sample_times = torch.clamp(sample_times, min=times[0], max=times[-1])

    # Bracket each sample between two keyframes. With a single keyframe both
    # brackets collapse onto index 0 instead of producing an invalid index -1.
    last = frame_count - 1
    upper = torch.searchsorted(times, sample_times, right=True)
    upper = torch.clamp(upper, min=min(1, last), max=last)
    lower = torch.clamp(upper - 1, min=0)

    t0 = times.index_select(0, lower)
    t1 = times.index_select(0, upper)
    v0 = values.index_select(0, lower)
    v1 = values.index_select(0, upper)
    # The epsilon guards the division when both brackets are the same keyframe.
    alpha = ((sample_times - t0) / (t1 - t0 + 1e-6)).unsqueeze(-1)
    interpolated = v0 + alpha * (v1 - v0)

    return interpolated, names, sample_times
|
|
|
|
|
|
def _map_motors_to_joint_targets(
    motor_table: torch.Tensor,
    motor_names: list[str],
    joint_names: list[str],
) -> torch.Tensor:
    """Rearrange motor-table columns into the order given by *joint_names*.

    Each joint is first passed through ``JOINT_NAME_ALIAS``. The resulting
    name is looked up among *motor_names*; if it is absent, the name with
    every ``Left_`` / ``Right_`` substring removed is tried instead. Joints
    with no matching motor stay at zero. When the aliased name starts with
    ``Right_`` and its side-less name is in ``RIGHT_JOINT_SIGN_FLIP``, the
    copied column is negated.

    Args:
        motor_table: Tensor indexed as ``[sample, motor]``.
        motor_names: Name of each column of *motor_table*.
        joint_names: Desired output column order.

    Returns:
        A float32 tensor with one row per row of *motor_table* and one column
        per entry of *joint_names*.
    """
    column_of = dict(zip(motor_names, range(len(motor_names))))
    targets = torch.zeros((motor_table.shape[0], len(joint_names)), dtype=torch.float32)

    for out_col, joint in enumerate(joint_names):
        lookup = JOINT_NAME_ALIAS.get(joint, joint)
        sideless = lookup.replace("Left_", "").replace("Right_", "")

        # Prefer the exact (aliased) name, then fall back to the side-less one.
        src_col = column_of.get(lookup)
        if src_col is None:
            src_col = column_of.get(sideless)
        if src_col is None:
            continue

        mirrored = lookup.startswith("Right_") and sideless in RIGHT_JOINT_SIGN_FLIP
        factor = -1.0 if mirrored else 1.0
        targets[:, out_col] = factor * motor_table[:, src_col]

    return targets
|
|
|
|
|
|
def _features_from_keyframes(
    keyframe_yaml_path: str,
    joint_names: list[str],
    sample_dt: float,
    repeat_count: int,
) -> torch.Tensor:
    """Build a per-sample feature matrix from one keyframe YAML file.

    Every row is the concatenation of: joint positions, joint velocities
    (backward finite differences of the positions, with the first row copying
    the second), three zeros (``root_lin``), three zeros (``root_ang``) and a
    constant ``(0, 0, -1)`` projected-gravity vector.

    Args:
        keyframe_yaml_path: Path to a YAML file with a ``keyframes`` list.
        joint_names: Joint order for the position/velocity columns.
        sample_dt: Resampling step; the velocity divisor is floored at 1e-3.
        repeat_count: Number of times the clip is tiled along the row axis;
            values below 1 are treated as 1.

    Returns:
        A float32 tensor with ``2 * len(joint_names) + 9`` columns.

    Raises:
        ValueError: If the ``keyframes`` entry is not a list.
    """
    yaml_file = Path(keyframe_yaml_path).expanduser().resolve()
    frames = _safe_load_yaml(yaml_file).get("keyframes", [])
    if not isinstance(frames, list):
        raise ValueError(f"Invalid keyframe list in {keyframe_yaml_path}")

    table, table_names, _ = _build_interpolated_motor_table(frames, sample_dt=sample_dt)
    positions = _map_motors_to_joint_targets(table, table_names, joint_names)
    num_samples = positions.shape[0]

    velocities = torch.zeros_like(positions)
    if num_samples > 1:
        step = max(sample_dt, 1e-3)
        velocities[1:] = (positions[1:] - positions[:-1]) / step
        # No predecessor for the first sample, so reuse the second one's value.
        velocities[0] = velocities[1]

    # Root terms are fixed placeholders: zero linear/angular parts and gravity
    # pointing along -z.
    root_linear = torch.zeros((num_samples, 3), dtype=torch.float32)
    root_angular = torch.zeros((num_samples, 3), dtype=torch.float32)
    gravity = torch.zeros((num_samples, 3), dtype=torch.float32)
    gravity[:, 2] = -1.0

    columns = [positions, velocities, root_linear, root_angular, gravity]
    features = torch.cat(columns, dim=-1)

    copies = max(int(repeat_count), 1)
    if copies <= 1:
        return features
    return features.repeat(copies, 1)
|
|
|
|
|
|
def build_amp_expert_features_from_getup_keyframes(
    *,
    front_yaml_path: str,
    back_yaml_path: str,
    joint_names: list[str],
    output_path: str,
    sample_dt: float = 0.04,
    repeat_count: int = 16,
) -> tuple[str, tuple[int, int]]:
    """Generate AMP expert feature tensor from front/back get-up keyframe yaml files.

    The front and back clips are converted to feature matrices, stacked
    row-wise (front first) and written with ``torch.save`` as a dict holding
    ``expert_features`` and a ``meta`` mapping. Missing parent directories of
    *output_path* are created.

    Args:
        front_yaml_path: Keyframe YAML for the front get-up clip.
        back_yaml_path: Keyframe YAML for the back get-up clip.
        joint_names: Joint order used for the feature columns.
        output_path: Destination file for the saved payload.
        sample_dt: Resampling step passed to the per-clip feature builder.
        repeat_count: Tiling count passed to the per-clip feature builder.

    Returns:
        The resolved output path as a string and the ``(rows, columns)`` shape
        of the saved feature tensor.
    """
    clips = [
        _features_from_keyframes(
            keyframe_yaml_path=yaml_path,
            joint_names=joint_names,
            sample_dt=sample_dt,
            repeat_count=repeat_count,
        )
        for yaml_path in (front_yaml_path, back_yaml_path)
    ]
    expert_features = torch.cat(clips, dim=0).contiguous()

    destination = Path(output_path).expanduser().resolve()
    destination.parent.mkdir(parents=True, exist_ok=True)

    meta = {
        "front_yaml_path": str(Path(front_yaml_path).expanduser().resolve()),
        "back_yaml_path": str(Path(back_yaml_path).expanduser().resolve()),
        "sample_dt": float(sample_dt),
        "repeat_count": int(repeat_count),
        "feature_dim": int(expert_features.shape[-1]),
    }
    torch.save({"expert_features": expert_features, "meta": meta}, str(destination))

    shape = (int(expert_features.shape[0]), int(expert_features.shape[1]))
    return str(destination), shape
|