chore: replace hard-coded action values with constants throughout all the source code (#2055)
* chore: replace hard-coded 'action' values with constants throughout all the source code * chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
@@ -44,6 +44,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
@@ -78,16 +79,16 @@ def replay(cfg: ReplayConfig):
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
actions = dataset.hf_dataset.select_columns(ACTION)
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx]["action"]
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features["action"]["names"]):
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
key = f"{name.removeprefix('main_')}.pos"
|
||||
action[key] = action_array[i].item()
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
@@ -42,7 +42,7 @@ robot = LeKiwiClient(robot_config)
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
@@ -48,7 +48,7 @@ keyboard = KeyboardTeleop(keyboard_config)
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import time
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -34,7 +35,7 @@ robot = LeKiwiClient(robot_config)
|
||||
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns("action")
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
@@ -49,7 +50,7 @@ for idx in range(len(episode_frames)):
|
||||
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Send action to robot
|
||||
|
||||
@@ -28,6 +28,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -66,7 +67,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotOb
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns("action")
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
@@ -81,7 +82,7 @@ for idx in range(len(episode_frames)):
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get robot observation
|
||||
|
||||
@@ -29,6 +29,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -67,7 +68,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotOb
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns("action")
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
@@ -82,7 +83,7 @@ for idx in range(len(episode_frames)):
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get robot observation
|
||||
|
||||
@@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
|
||||
)
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import OBS_PREFIX
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX
|
||||
|
||||
IMAGENET_STATS = {
|
||||
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
||||
@@ -57,7 +57,7 @@ def resolve_delta_timestamps(
|
||||
for key in ds_meta.features:
|
||||
if key == "next.reward" and cfg.reward_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
|
||||
if key == "action" and cfg.action_delta_indices is not None:
|
||||
if key == ACTION and cfg.action_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||
|
||||
@@ -132,7 +132,7 @@ def aggregate_pipeline_dataset_features(
|
||||
# Convert the processed features into the final dataset format.
|
||||
dataset_features = {}
|
||||
if processed_features[ACTION]:
|
||||
dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos))
|
||||
dataset_features.update(hw_to_dataset_features(processed_features[ACTION], ACTION, use_videos))
|
||||
if processed_features[OBS_STR]:
|
||||
dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos))
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ from lerobot.datasets.backward_compatibility import (
|
||||
BackwardCompatibilityError,
|
||||
ForwardCompatibilityError,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
@@ -646,7 +646,7 @@ def hw_to_dataset_features(
|
||||
}
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
|
||||
if joint_fts and prefix == "action":
|
||||
if joint_fts and prefix == ACTION:
|
||||
features[prefix] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(joint_fts),),
|
||||
@@ -733,7 +733,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith(OBS_STR):
|
||||
type = FeatureType.STATE
|
||||
elif key.startswith("action"):
|
||||
elif key.startswith(ACTION):
|
||||
type = FeatureType.ACTION
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -53,12 +53,12 @@ class AlohaEnv(EnvConfig):
|
||||
render_mode: str = "rgb_array"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"top": f"{OBS_IMAGE}.top",
|
||||
"pixels/top": f"{OBS_IMAGES}.top",
|
||||
@@ -93,13 +93,13 @@ class PushtEnv(EnvConfig):
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"environment_state": OBS_ENV_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
@@ -135,13 +135,13 @@ class XarmEnv(EnvConfig):
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
@@ -259,12 +259,12 @@ class LiberoEnv(EnvConfig):
|
||||
camera_name_mapping: dict[str, str] | None = (None,)
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels/agentview_image": f"{OBS_IMAGES}.image",
|
||||
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
|
||||
|
||||
@@ -394,7 +394,7 @@ class ACT(nn.Module):
|
||||
latent dimension.
|
||||
"""
|
||||
if self.config.use_vae and self.training:
|
||||
assert "action" in batch, (
|
||||
assert ACTION in batch, (
|
||||
"actions must be provided when using the variational objective in training mode."
|
||||
)
|
||||
|
||||
@@ -404,7 +404,7 @@ class ACT(nn.Module):
|
||||
batch_size = batch[OBS_ENV_STATE].shape[0]
|
||||
|
||||
# Prepare the latent for input to the transformer encoder.
|
||||
if self.config.use_vae and "action" in batch and self.training:
|
||||
if self.config.use_vae and ACTION in batch and self.training:
|
||||
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
|
||||
cls_embed = einops.repeat(
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
@@ -412,7 +412,7 @@ class ACT(nn.Module):
|
||||
if self.config.robot_state_feature:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[OBS_STATE])
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch[ACTION]) # (B, S, D)
|
||||
|
||||
if self.config.robot_state_feature:
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
|
||||
@@ -82,7 +82,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
self._queues = {
|
||||
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
if self.config.image_features:
|
||||
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
|
||||
@@ -306,10 +306,10 @@ class DiffusionModel(nn.Module):
|
||||
}
|
||||
"""
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({OBS_STATE, "action", "action_is_pad"})
|
||||
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
|
||||
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
|
||||
n_obs_steps = batch[OBS_STATE].shape[1]
|
||||
horizon = batch["action"].shape[1]
|
||||
horizon = batch[ACTION].shape[1]
|
||||
assert horizon == self.config.horizon
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
@@ -317,7 +317,7 @@ class DiffusionModel(nn.Module):
|
||||
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
||||
|
||||
# Forward diffusion.
|
||||
trajectory = batch["action"]
|
||||
trajectory = batch[ACTION]
|
||||
# Sample noise to add to the trajectory.
|
||||
eps = torch.randn(trajectory.shape, device=trajectory.device)
|
||||
# Sample a random noising timestep for each item in the batch.
|
||||
@@ -338,7 +338,7 @@ class DiffusionModel(nn.Module):
|
||||
if self.config.prediction_type == "epsilon":
|
||||
target = eps
|
||||
elif self.config.prediction_type == "sample":
|
||||
target = batch["action"]
|
||||
target = batch[ACTION]
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
def display(tensor: torch.Tensor):
|
||||
@@ -73,7 +73,7 @@ def main():
|
||||
for cam_key, uint_chw_array in example["images"].items():
|
||||
batch[f"{OBS_IMAGES}.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
|
||||
batch[OBS_STATE] = torch.from_numpy(example["state"])
|
||||
batch["action"] = torch.from_numpy(outputs["actions"])
|
||||
batch[ACTION] = torch.from_numpy(outputs["actions"])
|
||||
batch["task"] = example["prompt"]
|
||||
|
||||
if model_name == "pi0_aloha_towel":
|
||||
@@ -117,7 +117,7 @@ def main():
|
||||
actions.append(action)
|
||||
|
||||
actions = torch.stack(actions, dim=1)
|
||||
pi_actions = batch["action"]
|
||||
pi_actions = batch[ACTION]
|
||||
print("actions")
|
||||
display(actions)
|
||||
print()
|
||||
|
||||
@@ -225,7 +225,7 @@ class SACConfig(PreTrainedConfig):
|
||||
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
if ACTION not in self.output_features:
|
||||
raise ValueError("You must provide 'action' in the output features")
|
||||
|
||||
@property
|
||||
|
||||
@@ -31,7 +31,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
|
||||
|
||||
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
||||
|
||||
@@ -51,7 +51,7 @@ class SACPolicy(
|
||||
self.config = config
|
||||
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features["action"].shape[0]
|
||||
continuous_action_dim = config.output_features[ACTION].shape[0]
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
@@ -158,7 +158,7 @@ class SACPolicy(
|
||||
The computed loss tensor
|
||||
"""
|
||||
# Extract common components from batch
|
||||
actions: Tensor = batch["action"]
|
||||
actions: Tensor = batch[ACTION]
|
||||
observations: dict[str, Tensor] = batch["state"]
|
||||
observation_features: Tensor = batch.get("observation_feature")
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
self._queues = {
|
||||
OBS_STATE: deque(maxlen=1),
|
||||
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
ACTION: deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
}
|
||||
if self.config.image_features:
|
||||
self._queues[OBS_IMAGE] = deque(maxlen=1)
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.utils.constants import OBS_PREFIX
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX
|
||||
|
||||
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||
|
||||
@@ -344,7 +344,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
|
||||
if not isinstance(batch, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
|
||||
|
||||
action = batch.get("action")
|
||||
action = batch.get(ACTION)
|
||||
if action is not None and not isinstance(action, PolicyAction):
|
||||
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
|
||||
|
||||
@@ -354,7 +354,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
|
||||
|
||||
return create_transition(
|
||||
observation=observation_keys if observation_keys else None,
|
||||
action=batch.get("action"),
|
||||
action=batch.get(ACTION),
|
||||
reward=batch.get("next.reward", 0.0),
|
||||
done=batch.get("next.done", False),
|
||||
truncated=batch.get("next.truncated", False),
|
||||
@@ -379,7 +379,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
|
||||
|
||||
batch = {
|
||||
"action": transition.get(TransitionKey.ACTION),
|
||||
ACTION: transition.get(TransitionKey.ACTION),
|
||||
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
|
||||
"next.done": transition.get(TransitionKey.DONE, False),
|
||||
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
|
||||
|
||||
@@ -59,6 +59,7 @@ from safetensors.torch import load_file as load_safetensors
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||
@@ -196,7 +197,7 @@ def detect_features_and_norm_modes(
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
elif ACTION in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE # Default
|
||||
@@ -215,7 +216,7 @@ def detect_features_and_norm_modes(
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key or "joint" in key or "position" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
elif ACTION in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE
|
||||
@@ -321,7 +322,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
elif ACTION in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE
|
||||
|
||||
@@ -26,6 +26,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
from .converters import from_tensor_to_numpy, to_tensor
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
@@ -272,7 +273,7 @@ class _NormalizationMixin:
|
||||
Returns:
|
||||
The transformed action tensor.
|
||||
"""
|
||||
processed_action = self._apply_transform(action, "action", FeatureType.ACTION, inverse=inverse)
|
||||
processed_action = self._apply_transform(action, ACTION, FeatureType.ACTION, inverse=inverse)
|
||||
return processed_action
|
||||
|
||||
def _apply_transform(
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -23,7 +24,7 @@ class RobotActionToPolicyActionProcessorStep(ActionProcessorStep):
|
||||
return asdict(self)
|
||||
|
||||
def transform_features(self, features):
|
||||
features[PipelineFeatureType.ACTION]["action"] = PolicyFeature(
|
||||
features[PipelineFeatureType.ACTION][ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(len(self.motor_names),)
|
||||
)
|
||||
return features
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.nn.functional as F # noqa: N812
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE
|
||||
from lerobot.utils.transition import Transition
|
||||
|
||||
|
||||
@@ -467,7 +467,7 @@ class ReplayBuffer:
|
||||
if list_transition:
|
||||
first_transition = list_transition[0]
|
||||
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
|
||||
first_action = first_transition["action"].to(device)
|
||||
first_action = first_transition[ACTION].to(device)
|
||||
|
||||
# Get complementary info if available
|
||||
first_complementary_info = None
|
||||
@@ -492,7 +492,7 @@ class ReplayBuffer:
|
||||
elif isinstance(v, torch.Tensor):
|
||||
data[k] = v.to(storage_device)
|
||||
|
||||
action = data["action"]
|
||||
action = data[ACTION]
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
@@ -530,8 +530,8 @@ class ReplayBuffer:
|
||||
|
||||
# Add "action"
|
||||
sample_action = self.actions[0]
|
||||
act_info = guess_feature_info(t=sample_action, name="action")
|
||||
features["action"] = act_info
|
||||
act_info = guess_feature_info(t=sample_action, name=ACTION)
|
||||
features[ACTION] = act_info
|
||||
|
||||
# Add "reward" and "done"
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
|
||||
@@ -577,7 +577,7 @@ class ReplayBuffer:
|
||||
frame_dict[key] = self.states[key][actual_idx].cpu()
|
||||
|
||||
# Fill action, reward, done
|
||||
frame_dict["action"] = self.actions[actual_idx].cpu()
|
||||
frame_dict[ACTION] = self.actions[actual_idx].cpu()
|
||||
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||
frame_dict["task"] = task_name
|
||||
@@ -668,7 +668,7 @@ class ReplayBuffer:
|
||||
current_state[key] = val.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 2) Action -----
|
||||
action = current_sample["action"].unsqueeze(0) # Add batch dimension
|
||||
action = current_sample[ACTION].unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 3) Reward and done -----
|
||||
reward = float(current_sample["next.reward"].item()) # ensure float
|
||||
@@ -788,8 +788,8 @@ def concatenate_batch_transitions(
|
||||
}
|
||||
|
||||
# Concatenate basic fields
|
||||
left_batch_transitions["action"] = torch.cat(
|
||||
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
|
||||
left_batch_transitions[ACTION] = torch.cat(
|
||||
[left_batch_transitions[ACTION], right_batch_transition[ACTION]], dim=0
|
||||
)
|
||||
left_batch_transitions["reward"] = torch.cat(
|
||||
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
|
||||
|
||||
@@ -73,7 +73,7 @@ from lerobot.teleoperators import (
|
||||
)
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -601,7 +601,7 @@ def control_loop(
|
||||
if cfg.mode == "record":
|
||||
action_features = teleop_device.action_features
|
||||
features = {
|
||||
"action": action_features,
|
||||
ACTION: action_features,
|
||||
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
}
|
||||
@@ -672,7 +672,7 @@ def control_loop(
|
||||
)
|
||||
frame = {
|
||||
**observations,
|
||||
"action": action_to_record.cpu(),
|
||||
ACTION: action_to_record.cpu(),
|
||||
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
||||
"next.done": np.array([terminated or truncated], dtype=bool),
|
||||
}
|
||||
@@ -733,7 +733,7 @@ def replay_trajectory(
|
||||
download_videos=False,
|
||||
)
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode)
|
||||
actions = episode_frames.select_columns("action")
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
_, info = env.reset()
|
||||
|
||||
@@ -741,7 +741,7 @@ def replay_trajectory(
|
||||
start_time = time.perf_counter()
|
||||
transition = create_transition(
|
||||
observation=env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {},
|
||||
action=action_data["action"],
|
||||
action=action_data[ACTION],
|
||||
)
|
||||
transition = action_processor(transition)
|
||||
env.step(transition[TransitionKey.ACTION])
|
||||
|
||||
@@ -80,6 +80,7 @@ from lerobot.transport.utils import (
|
||||
state_to_bytes,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
PRETRAINED_MODEL_DIR,
|
||||
@@ -402,7 +403,7 @@ def add_actor_information_and_train(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch["action"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
@@ -415,7 +416,7 @@ def add_actor_information_and_train(
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
"action": actions,
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
@@ -460,7 +461,7 @@ def add_actor_information_and_train(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch["action"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
@@ -474,7 +475,7 @@ def add_actor_information_and_train(
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
"action": actions,
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
@@ -1155,7 +1156,7 @@ def process_transitions(
|
||||
# Skip transitions with NaN values
|
||||
if check_nan_in_transition(
|
||||
observations=transition["state"],
|
||||
actions=transition["action"],
|
||||
actions=transition[ACTION],
|
||||
next_state=transition["next_state"],
|
||||
):
|
||||
logging.warning("[LEARNER] NaN detected in transition, skipping")
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import Any
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
@@ -330,7 +330,7 @@ class LeKiwiClient(Robot):
|
||||
actions = np.array([action.get(k, 0.0) for k in self._state_order], dtype=np.float32)
|
||||
|
||||
action_sent = {key: actions[i] for i, key in enumerate(self._state_order)}
|
||||
action_sent["action"] = actions
|
||||
action_sent[ACTION] = actions
|
||||
return action_sent
|
||||
|
||||
def disconnect(self):
|
||||
|
||||
@@ -75,7 +75,7 @@ import torch.utils.data
|
||||
import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
@@ -157,9 +157,9 @@ def visualize_dataset(
|
||||
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
|
||||
|
||||
# display each dimension of action space (e.g. actuators command)
|
||||
if "action" in batch:
|
||||
for dim_idx, val in enumerate(batch["action"][i]):
|
||||
rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
|
||||
if ACTION in batch:
|
||||
for dim_idx, val in enumerate(batch[ACTION][i]):
|
||||
rr.log(f"{ACTION}/{dim_idx}", rr.Scalar(val.item()))
|
||||
|
||||
# display each dimension of observed state space (e.g. agent position in joint space)
|
||||
if OBS_STATE in batch:
|
||||
|
||||
@@ -81,7 +81,7 @@ from lerobot.envs.utils import (
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.io_utils import write_video
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.utils import (
|
||||
@@ -213,7 +213,7 @@ def rollout(
|
||||
|
||||
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
|
||||
ret = {
|
||||
"action": torch.stack(all_actions, dim=1),
|
||||
ACTION: torch.stack(all_actions, dim=1),
|
||||
"reward": torch.stack(all_rewards, dim=1),
|
||||
"success": torch.stack(all_successes, dim=1),
|
||||
"done": torch.stack(all_dones, dim=1),
|
||||
@@ -440,14 +440,14 @@ def _compile_episode_data(
|
||||
"""
|
||||
ep_dicts = []
|
||||
total_frames = 0
|
||||
for ep_ix in range(rollout_data["action"].shape[0]):
|
||||
for ep_ix in range(rollout_data[ACTION].shape[0]):
|
||||
# + 2 to include the first done frame and the last observation frame.
|
||||
num_frames = done_indices[ep_ix].item() + 2
|
||||
total_frames += num_frames
|
||||
|
||||
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
|
||||
ep_dict = {
|
||||
"action": rollout_data["action"][ep_ix, : num_frames - 1],
|
||||
ACTION: rollout_data[ACTION][ep_ix, : num_frames - 1],
|
||||
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
|
||||
"frame_index": torch.arange(0, num_frames - 1, 1),
|
||||
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
|
||||
|
||||
@@ -109,7 +109,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
so101_leader,
|
||||
)
|
||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import (
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
@@ -319,7 +319,7 @@ def record_loop(
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
|
||||
action_names = dataset.features["action"]["names"]
|
||||
action_names = dataset.features[ACTION]["names"]
|
||||
act_processed_policy: RobotAction = {
|
||||
f"{name}": float(action_values[i]) for i, name in enumerate(action_names)
|
||||
}
|
||||
@@ -361,7 +361,7 @@ def record_loop(
|
||||
|
||||
# Write to dataset
|
||||
if dataset is not None:
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix="action")
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": single_task}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
@@ -99,7 +100,7 @@ def replay(cfg: ReplayConfig):
|
||||
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode)
|
||||
actions = episode_frames.select_columns("action")
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
robot.connect()
|
||||
|
||||
@@ -107,9 +108,9 @@ def replay(cfg: ReplayConfig):
|
||||
for idx in range(len(episode_frames)):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx]["action"]
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features["action"]["names"]):
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
action[name] = action_array[i]
|
||||
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
@@ -18,6 +18,8 @@ from typing import TypedDict
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
class Transition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
@@ -39,7 +41,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
|
||||
}
|
||||
|
||||
# Move action to device
|
||||
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
|
||||
transition[ACTION] = transition[ACTION].to(device, non_blocking=non_blocking)
|
||||
|
||||
# Move reward and done if they are tensors
|
||||
if isinstance(transition["reward"], torch.Tensor):
|
||||
|
||||
@@ -21,7 +21,7 @@ from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
@@ -59,14 +59,14 @@ def test_calculate_episode_data_index():
|
||||
|
||||
def test_merge_simple_vectors():
|
||||
g1 = {
|
||||
"action": {
|
||||
ACTION: {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": ["ee.x", "ee.y"],
|
||||
}
|
||||
}
|
||||
g2 = {
|
||||
"action": {
|
||||
ACTION: {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": ["ee.y", "ee.z"],
|
||||
@@ -75,23 +75,23 @@ def test_merge_simple_vectors():
|
||||
|
||||
out = combine_feature_dicts(g1, g2)
|
||||
|
||||
assert "action" in out
|
||||
assert out["action"]["dtype"] == "float32"
|
||||
assert ACTION in out
|
||||
assert out[ACTION]["dtype"] == "float32"
|
||||
# Names merged with preserved order and de-dupuplication
|
||||
assert out["action"]["names"] == ["ee.x", "ee.y", "ee.z"]
|
||||
assert out[ACTION]["names"] == ["ee.x", "ee.y", "ee.z"]
|
||||
# Shape correctly recomputed from names length
|
||||
assert out["action"]["shape"] == (3,)
|
||||
assert out[ACTION]["shape"] == (3,)
|
||||
|
||||
|
||||
def test_merge_multiple_groups_order_and_dedup():
|
||||
g1 = {"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}}
|
||||
g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}}
|
||||
g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}}
|
||||
g1 = {ACTION: {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}}
|
||||
g2 = {ACTION: {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}}
|
||||
g3 = {ACTION: {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}}
|
||||
|
||||
out = combine_feature_dicts(g1, g2, g3)
|
||||
|
||||
assert out["action"]["names"] == ["a", "b", "c", "d"]
|
||||
assert out["action"]["shape"] == (4,)
|
||||
assert out[ACTION]["names"] == ["a", "b", "c", "d"]
|
||||
assert out[ACTION]["shape"] == (4,)
|
||||
|
||||
|
||||
def test_non_vector_last_wins_for_images():
|
||||
@@ -117,8 +117,8 @@ def test_non_vector_last_wins_for_images():
|
||||
|
||||
|
||||
def test_dtype_mismatch_raises():
|
||||
g1 = {"action": {"dtype": "float32", "shape": (1,), "names": ["a"]}}
|
||||
g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}}
|
||||
g1 = {ACTION: {"dtype": "float32", "shape": (1,), "names": ["a"]}}
|
||||
g2 = {ACTION: {"dtype": "float64", "shape": (1,), "names": ["b"]}}
|
||||
|
||||
with pytest.raises(ValueError, match="dtype mismatch for 'action'"):
|
||||
_ = combine_feature_dicts(g1, g2)
|
||||
|
||||
@@ -46,7 +46,7 @@ from lerobot.datasets.utils import (
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
@@ -75,7 +75,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
"""
|
||||
# Instantiate both ways
|
||||
robot = make_robot_from_config(MockRobotConfig())
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action", True)
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION, True)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR, True)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
root_create = tmp_path / "create"
|
||||
@@ -393,7 +393,7 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
item = dataset[0]
|
||||
|
||||
keys_ndim_required = [
|
||||
("action", 1, True),
|
||||
(ACTION, 1, True),
|
||||
("episode_index", 0, True),
|
||||
("frame_index", 0, True),
|
||||
("timestamp", 0, True),
|
||||
@@ -668,7 +668,7 @@ def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"],
|
||||
},
|
||||
"action": {
|
||||
ACTION: {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"],
|
||||
@@ -775,7 +775,7 @@ def test_update_chunk_settings_video_dataset(tmp_path):
|
||||
"shape": (480, 640, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]},
|
||||
ACTION: {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]},
|
||||
}
|
||||
|
||||
# Create video dataset
|
||||
@@ -842,7 +842,7 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_fact
|
||||
"""Test episode metadata consistency across multiple episodes."""
|
||||
features = {
|
||||
"state": {"dtype": "float32", "shape": (3,), "names": ["x", "y", "z"]},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["v", "w"]},
|
||||
ACTION: {"dtype": "float32", "shape": (2,), "names": ["v", "w"]},
|
||||
}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
|
||||
|
||||
@@ -852,7 +852,7 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_fact
|
||||
|
||||
for episode_idx in range(num_episodes):
|
||||
for _ in range(frames_per_episode[episode_idx]):
|
||||
dataset.add_frame({"state": torch.randn(3), "action": torch.randn(2), "task": tasks[episode_idx]})
|
||||
dataset.add_frame({"state": torch.randn(3), ACTION: torch.randn(2), "task": tasks[episode_idx]})
|
||||
dataset.save_episode()
|
||||
|
||||
# Load and validate episode metadata
|
||||
@@ -927,7 +927,7 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory)
|
||||
"""Test that statistics are properly computed and stored for all features."""
|
||||
features = {
|
||||
"state": {"dtype": "float32", "shape": (2,), "names": ["pos", "vel"]},
|
||||
"action": {"dtype": "float32", "shape": (1,), "names": ["force"]},
|
||||
ACTION: {"dtype": "float32", "shape": (1,), "names": ["force"]},
|
||||
}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
|
||||
|
||||
@@ -941,7 +941,7 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory)
|
||||
for frame_idx in range(frames_per_episode[episode_idx]):
|
||||
state_data = torch.tensor([frame_idx * 0.1, frame_idx * 0.2], dtype=torch.float32)
|
||||
action_data = torch.tensor([frame_idx * 0.05], dtype=torch.float32)
|
||||
dataset.add_frame({"state": state_data, "action": action_data, "task": "stats_test"})
|
||||
dataset.add_frame({"state": state_data, ACTION: action_data, "task": "stats_test"})
|
||||
dataset.save_episode()
|
||||
|
||||
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
@@ -19,6 +19,7 @@ import torch
|
||||
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.utils import safe_shard
|
||||
from lerobot.utils.constants import ACTION
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
@@ -234,7 +235,7 @@ def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_
|
||||
delta_timestamps = {
|
||||
camera_key: state_deltas,
|
||||
"state": state_deltas,
|
||||
"action": action_deltas,
|
||||
ACTION: action_deltas,
|
||||
}
|
||||
|
||||
ds = lerobot_dataset_factory(
|
||||
@@ -319,7 +320,7 @@ def test_frames_with_delta_consistency_with_shards(
|
||||
delta_timestamps = {
|
||||
camera_key: state_deltas,
|
||||
"state": state_deltas,
|
||||
"action": action_deltas,
|
||||
ACTION: action_deltas,
|
||||
}
|
||||
|
||||
ds = lerobot_dataset_factory(
|
||||
|
||||
4
tests/fixtures/constants.py
vendored
4
tests/fixtures/constants.py
vendored
@@ -11,13 +11,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME
|
||||
|
||||
LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing"
|
||||
DUMMY_REPO_ID = "dummy/repo"
|
||||
DUMMY_ROBOT_TYPE = "dummy_robot"
|
||||
DUMMY_MOTOR_FEATURES = {
|
||||
"action": {
|
||||
ACTION: {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||
|
||||
@@ -59,7 +59,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
|
||||
},
|
||||
}
|
||||
motor_features = {
|
||||
"action": {
|
||||
ACTION: {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||
@@ -287,7 +287,7 @@ def test_multikey_construction(multikey: bool):
|
||||
),
|
||||
}
|
||||
output_features = {
|
||||
"action": PolicyFeature(
|
||||
ACTION: PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(5,),
|
||||
),
|
||||
@@ -304,7 +304,7 @@ def test_multikey_construction(multikey: bool):
|
||||
output_features = {}
|
||||
output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,))
|
||||
output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,))
|
||||
output_features["action"] = PolicyFeature(
|
||||
output_features[ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(5,),
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ from lerobot.policies.sac.configuration_sac import (
|
||||
PolicyConfig,
|
||||
SACConfig,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
def test_sac_config_default_initialization():
|
||||
@@ -46,7 +46,7 @@ def test_sac_config_default_initialization():
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
ACTION: {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
@@ -99,7 +99,7 @@ def test_sac_config_default_initialization():
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
ACTION: {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
@@ -193,7 +193,7 @@ def test_sac_config_custom_initialization():
|
||||
def test_validate_features():
|
||||
config = SACConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
@@ -201,7 +201,7 @@ def test_validate_features():
|
||||
def test_validate_features_missing_observation():
|
||||
config = SACConfig(
|
||||
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError, match="You must provide either 'observation.state' or an image observation"
|
||||
|
||||
@@ -23,7 +23,7 @@ from torch import Tensor, nn
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
|
||||
try:
|
||||
@@ -105,7 +105,7 @@ def create_default_train_batch(
|
||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
||||
) -> dict[str, Tensor]:
|
||||
return {
|
||||
"action": create_dummy_action(batch_size, action_dim),
|
||||
ACTION: create_dummy_action(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": create_dummy_state(batch_size, state_dim),
|
||||
"next_state": create_dummy_state(batch_size, state_dim),
|
||||
@@ -117,7 +117,7 @@ def create_train_batch_with_visual_input(
|
||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
||||
) -> dict[str, Tensor]:
|
||||
return {
|
||||
"action": create_dummy_action(batch_size, action_dim),
|
||||
ACTION: create_dummy_action(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": create_dummy_with_visual_input(batch_size, state_dim),
|
||||
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
|
||||
@@ -182,13 +182,13 @@ def create_default_config(
|
||||
|
||||
config = SACConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
|
||||
dataset_stats={
|
||||
OBS_STATE: {
|
||||
"min": [0.0] * state_dim,
|
||||
"max": [1.0] * state_dim,
|
||||
},
|
||||
"action": {
|
||||
ACTION: {
|
||||
"min": [0.0] * continuous_action_dim,
|
||||
"max": [1.0] * continuous_action_dim,
|
||||
},
|
||||
|
||||
@@ -2,7 +2,7 @@ import torch
|
||||
|
||||
from lerobot.processor import DataProcessorPipeline, TransitionKey
|
||||
from lerobot.processor.converters import batch_to_transition, transition_to_batch
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_PREFIX, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_PREFIX, OBS_STATE
|
||||
|
||||
|
||||
def _dummy_batch():
|
||||
@@ -11,7 +11,7 @@ def _dummy_batch():
|
||||
f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
|
||||
f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128),
|
||||
OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
|
||||
"action": torch.tensor([[0.5]]),
|
||||
ACTION: torch.tensor([[0.5]]),
|
||||
"next.reward": 1.0,
|
||||
"next.done": False,
|
||||
"next.truncated": False,
|
||||
@@ -37,7 +37,7 @@ def test_observation_grouping_roundtrip():
|
||||
assert torch.allclose(batch_out[OBS_STATE], batch_in[OBS_STATE])
|
||||
|
||||
# Check other fields
|
||||
assert torch.allclose(batch_out["action"], batch_in["action"])
|
||||
assert torch.allclose(batch_out[ACTION], batch_in[ACTION])
|
||||
assert batch_out["next.reward"] == batch_in["next.reward"]
|
||||
assert batch_out["next.done"] == batch_in["next.done"]
|
||||
assert batch_out["next.truncated"] == batch_in["next.truncated"]
|
||||
@@ -50,7 +50,7 @@ def test_batch_to_transition_observation_grouping():
|
||||
f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128),
|
||||
f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
|
||||
OBS_STATE: [1, 2, 3, 4],
|
||||
"action": torch.tensor([0.1, 0.2, 0.3, 0.4]),
|
||||
ACTION: torch.tensor([0.1, 0.2, 0.3, 0.4]),
|
||||
"next.reward": 1.5,
|
||||
"next.done": True,
|
||||
"next.truncated": False,
|
||||
@@ -114,7 +114,7 @@ def test_transition_to_batch_observation_flattening():
|
||||
assert batch[OBS_STATE] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields are mapped to next.* format
|
||||
assert batch["action"] == "action_data"
|
||||
assert batch[ACTION] == "action_data"
|
||||
assert batch["next.reward"] == 1.5
|
||||
assert batch["next.done"]
|
||||
assert not batch["next.truncated"]
|
||||
@@ -124,7 +124,7 @@ def test_transition_to_batch_observation_flattening():
|
||||
def test_no_observation_keys():
|
||||
"""Test behavior when there are no observation.* keys."""
|
||||
batch = {
|
||||
"action": torch.tensor([1.0, 2.0]),
|
||||
ACTION: torch.tensor([1.0, 2.0]),
|
||||
"next.reward": 2.0,
|
||||
"next.done": False,
|
||||
"next.truncated": True,
|
||||
@@ -145,7 +145,7 @@ def test_no_observation_keys():
|
||||
|
||||
# Round trip should work
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert torch.allclose(reconstructed_batch["action"], torch.tensor([1.0, 2.0]))
|
||||
assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([1.0, 2.0]))
|
||||
assert reconstructed_batch["next.reward"] == 2.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert reconstructed_batch["next.truncated"]
|
||||
@@ -154,7 +154,7 @@ def test_no_observation_keys():
|
||||
|
||||
def test_minimal_batch():
|
||||
"""Test with minimal batch containing only observation.* and action."""
|
||||
batch = {OBS_STATE: "minimal_state", "action": torch.tensor([0.5])}
|
||||
batch = {OBS_STATE: "minimal_state", ACTION: torch.tensor([0.5])}
|
||||
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
@@ -172,7 +172,7 @@ def test_minimal_batch():
|
||||
# Round trip
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch[OBS_STATE] == "minimal_state"
|
||||
assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5]))
|
||||
assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([0.5]))
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert not reconstructed_batch["next.truncated"]
|
||||
@@ -196,7 +196,7 @@ def test_empty_batch():
|
||||
|
||||
# Round trip
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["action"] is None
|
||||
assert reconstructed_batch[ACTION] is None
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert not reconstructed_batch["next.truncated"]
|
||||
@@ -209,7 +209,7 @@ def test_complex_nested_observation():
|
||||
f"{OBS_IMAGE}.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
|
||||
f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
|
||||
OBS_STATE: torch.randn(7),
|
||||
"action": torch.randn(8),
|
||||
ACTION: torch.randn(8),
|
||||
"next.reward": 3.14,
|
||||
"next.done": False,
|
||||
"next.truncated": True,
|
||||
@@ -237,7 +237,7 @@ def test_complex_nested_observation():
|
||||
)
|
||||
|
||||
# Check action tensor
|
||||
assert torch.allclose(batch["action"], reconstructed_batch["action"])
|
||||
assert torch.allclose(batch[ACTION], reconstructed_batch[ACTION])
|
||||
|
||||
# Check other fields
|
||||
assert batch["next.reward"] == reconstructed_batch["next.reward"]
|
||||
@@ -266,7 +266,7 @@ def test_custom_converter():
|
||||
|
||||
batch = {
|
||||
OBS_STATE: torch.randn(1, 4),
|
||||
"action": torch.randn(1, 2),
|
||||
ACTION: torch.randn(1, 2),
|
||||
"next.reward": 1.0,
|
||||
"next.done": False,
|
||||
}
|
||||
@@ -276,4 +276,4 @@ def test_custom_converter():
|
||||
# Check the reward was doubled by our custom converter
|
||||
assert result["next.reward"] == 2.0
|
||||
assert torch.allclose(result[OBS_STATE], batch[OBS_STATE])
|
||||
assert torch.allclose(result["action"], batch["action"])
|
||||
assert torch.allclose(result[ACTION], batch[ACTION])
|
||||
|
||||
@@ -9,7 +9,7 @@ from lerobot.processor.converters import (
|
||||
to_tensor,
|
||||
transition_to_batch,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_STATE, OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
|
||||
|
||||
|
||||
# Tests for the unified to_tensor function
|
||||
@@ -118,16 +118,16 @@ def test_to_tensor_dictionaries():
|
||||
|
||||
# Nested dictionary
|
||||
nested = {
|
||||
"action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]},
|
||||
ACTION: {"mean": [0.1, 0.2], "std": [1.0, 2.0]},
|
||||
OBS_STR: {"mean": np.array([0.5, 0.6]), "count": 10},
|
||||
}
|
||||
result = to_tensor(nested)
|
||||
assert isinstance(result, dict)
|
||||
assert isinstance(result["action"], dict)
|
||||
assert isinstance(result[ACTION], dict)
|
||||
assert isinstance(result[OBS_STR], dict)
|
||||
assert isinstance(result["action"]["mean"], torch.Tensor)
|
||||
assert isinstance(result[ACTION]["mean"], torch.Tensor)
|
||||
assert isinstance(result[OBS_STR]["mean"], torch.Tensor)
|
||||
assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2]))
|
||||
assert torch.allclose(result[ACTION]["mean"], torch.tensor([0.1, 0.2]))
|
||||
assert torch.allclose(result[OBS_STR]["mean"], torch.tensor([0.5, 0.6]))
|
||||
|
||||
|
||||
@@ -200,7 +200,7 @@ def test_batch_to_transition_with_index_fields():
|
||||
# Create batch with index and task_index fields
|
||||
batch = {
|
||||
OBS_STATE: torch.randn(1, 7),
|
||||
"action": torch.randn(1, 4),
|
||||
ACTION: torch.randn(1, 4),
|
||||
"next.reward": 1.5,
|
||||
"next.done": False,
|
||||
"task": ["pick_cube"],
|
||||
@@ -262,7 +262,7 @@ def test_batch_to_transition_without_index_fields():
|
||||
# Batch without index/task_index
|
||||
batch = {
|
||||
OBS_STATE: torch.randn(1, 7),
|
||||
"action": torch.randn(1, 4),
|
||||
ACTION: torch.randn(1, 4),
|
||||
"task": ["pick_cube"],
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
def test_basic_functionality():
|
||||
@@ -273,7 +273,7 @@ def test_features():
|
||||
|
||||
features = {
|
||||
PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
|
||||
PipelineFeatureType.ACTION: {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
|
||||
}
|
||||
|
||||
result = processor.transform_features(features)
|
||||
|
||||
@@ -25,7 +25,7 @@ from pathlib import Path
|
||||
import pytest
|
||||
|
||||
from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
|
||||
def test_is_processor_config_valid_configs():
|
||||
@@ -113,7 +113,7 @@ def test_should_suggest_migration_with_model_config_only():
|
||||
model_config = {
|
||||
"type": "act",
|
||||
"input_features": {OBS_STATE: {"shape": [7]}},
|
||||
"output_features": {"action": {"shape": [7]}},
|
||||
"output_features": {ACTION: {"shape": [7]}},
|
||||
"hidden_dim": 256,
|
||||
"n_obs_steps": 1,
|
||||
"n_action_steps": 1,
|
||||
|
||||
@@ -29,7 +29,7 @@ from lerobot.processor import (
|
||||
hotswap_stats,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, identity_transition, to_tensor
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.utils import auto_select_torch_device
|
||||
|
||||
|
||||
@@ -50,15 +50,15 @@ def test_numpy_conversion():
|
||||
|
||||
def test_tensor_conversion():
|
||||
stats = {
|
||||
"action": {
|
||||
ACTION: {
|
||||
"mean": torch.tensor([0.0, 0.0]),
|
||||
"std": torch.tensor([1.0, 1.0]),
|
||||
}
|
||||
}
|
||||
tensor_stats = to_tensor(stats)
|
||||
|
||||
assert tensor_stats["action"]["mean"].dtype == torch.float32
|
||||
assert tensor_stats["action"]["std"].dtype == torch.float32
|
||||
assert tensor_stats[ACTION]["mean"].dtype == torch.float32
|
||||
assert tensor_stats[ACTION]["std"].dtype == torch.float32
|
||||
|
||||
|
||||
def test_scalar_conversion():
|
||||
@@ -212,12 +212,12 @@ def test_from_lerobot_dataset():
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.meta.stats = {
|
||||
OBS_IMAGE: {"mean": [0.5], "std": [0.2]},
|
||||
"action": {"mean": [0.0], "std": [1.0]},
|
||||
ACTION: {"mean": [0.0], "std": [1.0]},
|
||||
}
|
||||
|
||||
features = {
|
||||
OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (1,)),
|
||||
ACTION: PolicyFeature(FeatureType.ACTION, (1,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
@@ -228,7 +228,7 @@ def test_from_lerobot_dataset():
|
||||
|
||||
# Both observation and action statistics should be present in tensor stats
|
||||
assert OBS_IMAGE in normalizer._tensor_stats
|
||||
assert "action" in normalizer._tensor_stats
|
||||
assert ACTION in normalizer._tensor_stats
|
||||
|
||||
|
||||
def test_state_dict_save_load(observation_normalizer):
|
||||
@@ -271,7 +271,7 @@ def action_stats_min_max():
|
||||
|
||||
def _create_action_features():
|
||||
return {
|
||||
"action": PolicyFeature(FeatureType.ACTION, (3,)),
|
||||
ACTION: PolicyFeature(FeatureType.ACTION, (3,)),
|
||||
}
|
||||
|
||||
|
||||
@@ -291,7 +291,7 @@ def test_mean_std_unnormalization(action_stats_mean_std):
|
||||
features = _create_action_features()
|
||||
norm_map = _create_action_norm_map_mean_std()
|
||||
unnormalizer = UnnormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||
features=features, norm_map=norm_map, stats={ACTION: action_stats_mean_std}
|
||||
)
|
||||
|
||||
normalized_action = torch.tensor([1.0, -0.5, 2.0])
|
||||
@@ -309,7 +309,7 @@ def test_min_max_unnormalization(action_stats_min_max):
|
||||
features = _create_action_features()
|
||||
norm_map = _create_action_norm_map_min_max()
|
||||
unnormalizer = UnnormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats={"action": action_stats_min_max}
|
||||
features=features, norm_map=norm_map, stats={ACTION: action_stats_min_max}
|
||||
)
|
||||
|
||||
# Actions in [-1, 1]
|
||||
@@ -335,7 +335,7 @@ def test_tensor_action_input(action_stats_mean_std):
|
||||
features = _create_action_features()
|
||||
norm_map = _create_action_norm_map_mean_std()
|
||||
unnormalizer = UnnormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||
features=features, norm_map=norm_map, stats={ACTION: action_stats_mean_std}
|
||||
)
|
||||
|
||||
normalized_action = torch.tensor([1.0, -0.5, 2.0], dtype=torch.float32)
|
||||
@@ -353,7 +353,7 @@ def test_none_action(action_stats_mean_std):
|
||||
features = _create_action_features()
|
||||
norm_map = _create_action_norm_map_mean_std()
|
||||
unnormalizer = UnnormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||
features=features, norm_map=norm_map, stats={ACTION: action_stats_mean_std}
|
||||
)
|
||||
|
||||
transition = create_transition()
|
||||
@@ -365,11 +365,11 @@ def test_none_action(action_stats_mean_std):
|
||||
|
||||
def test_action_from_lerobot_dataset():
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}}
|
||||
features = {"action": PolicyFeature(FeatureType.ACTION, (1,))}
|
||||
mock_dataset.meta.stats = {ACTION: {"mean": [0.0], "std": [1.0]}}
|
||||
features = {ACTION: PolicyFeature(FeatureType.ACTION, (1,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
unnormalizer = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||
assert "mean" in unnormalizer._tensor_stats["action"]
|
||||
assert "mean" in unnormalizer._tensor_stats[ACTION]
|
||||
|
||||
|
||||
# Fixtures for NormalizerProcessorStep tests
|
||||
@@ -384,7 +384,7 @@ def full_stats():
|
||||
"min": np.array([0.0, -1.0]),
|
||||
"max": np.array([1.0, 1.0]),
|
||||
},
|
||||
"action": {
|
||||
ACTION: {
|
||||
"mean": np.array([0.0, 0.0]),
|
||||
"std": np.array([1.0, 2.0]),
|
||||
},
|
||||
@@ -395,7 +395,7 @@ def _create_full_features():
|
||||
return {
|
||||
OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||
OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
ACTION: PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
|
||||
|
||||
@@ -461,7 +461,7 @@ def test_processor_from_lerobot_dataset(full_stats):
|
||||
|
||||
assert processor.normalize_observation_keys == {OBS_IMAGE}
|
||||
assert OBS_IMAGE in processor._tensor_stats
|
||||
assert "action" in processor._tensor_stats
|
||||
assert ACTION in processor._tensor_stats
|
||||
|
||||
|
||||
def test_get_config(full_stats):
|
||||
@@ -482,7 +482,7 @@ def test_get_config(full_stats):
|
||||
"features": {
|
||||
OBS_IMAGE: {"type": "VISUAL", "shape": (3, 96, 96)},
|
||||
OBS_STATE: {"type": "STATE", "shape": (2,)},
|
||||
"action": {"type": "ACTION", "shape": (2,)},
|
||||
ACTION: {"type": "ACTION", "shape": (2,)},
|
||||
},
|
||||
"norm_map": {
|
||||
"VISUAL": "MEAN_STD",
|
||||
@@ -568,7 +568,7 @@ def test_missing_action_stats_no_error():
|
||||
|
||||
processor = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||
# The tensor stats should not contain the 'action' key
|
||||
assert "action" not in processor._tensor_stats
|
||||
assert ACTION not in processor._tensor_stats
|
||||
|
||||
|
||||
def test_serialization_roundtrip(full_stats):
|
||||
@@ -676,9 +676,9 @@ def test_identity_normalization_observations():
|
||||
|
||||
def test_identity_normalization_actions():
|
||||
"""Test that IDENTITY mode skips normalization for actions."""
|
||||
features = {"action": PolicyFeature(FeatureType.ACTION, (2,))}
|
||||
features = {ACTION: PolicyFeature(FeatureType.ACTION, (2,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY}
|
||||
stats = {"action": {"mean": [0.0, 0.0], "std": [1.0, 2.0]}}
|
||||
stats = {ACTION: {"mean": [0.0, 0.0], "std": [1.0, 2.0]}}
|
||||
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
@@ -729,9 +729,9 @@ def test_identity_unnormalization_observations():
|
||||
|
||||
def test_identity_unnormalization_actions():
|
||||
"""Test that IDENTITY mode skips unnormalization for actions."""
|
||||
features = {"action": PolicyFeature(FeatureType.ACTION, (2,))}
|
||||
features = {ACTION: PolicyFeature(FeatureType.ACTION, (2,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY}
|
||||
stats = {"action": {"min": [-1.0, -2.0], "max": [1.0, 2.0]}}
|
||||
stats = {ACTION: {"min": [-1.0, -2.0], "max": [1.0, 2.0]}}
|
||||
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
@@ -748,7 +748,7 @@ def test_identity_with_missing_stats():
|
||||
"""Test that IDENTITY mode works even when stats are missing."""
|
||||
features = {
|
||||
OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
ACTION: PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
@@ -784,7 +784,7 @@ def test_identity_mixed_with_other_modes():
|
||||
features = {
|
||||
OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
ACTION: PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
@@ -794,7 +794,7 @@ def test_identity_mixed_with_other_modes():
|
||||
stats = {
|
||||
OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored
|
||||
OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
|
||||
"action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
|
||||
ACTION: {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
@@ -862,7 +862,7 @@ def test_identity_roundtrip():
|
||||
"""Test that IDENTITY normalization and unnormalization are true inverses."""
|
||||
features = {
|
||||
OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
ACTION: PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
@@ -870,7 +870,7 @@ def test_identity_roundtrip():
|
||||
}
|
||||
stats = {
|
||||
OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
|
||||
"action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
|
||||
ACTION: {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
@@ -893,7 +893,7 @@ def test_identity_config_serialization():
|
||||
"""Test that IDENTITY mode is properly saved and loaded in config."""
|
||||
features = {
|
||||
OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
ACTION: PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
@@ -901,7 +901,7 @@ def test_identity_config_serialization():
|
||||
}
|
||||
stats = {
|
||||
OBS_IMAGE: {"mean": [0.5], "std": [0.2]},
|
||||
"action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
|
||||
ACTION: {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
@@ -969,19 +969,19 @@ def test_hotswap_stats_basic_functionality():
|
||||
# Create initial stats
|
||||
initial_stats = {
|
||||
OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
}
|
||||
|
||||
# Create new stats for hotswapping
|
||||
new_stats = {
|
||||
OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
|
||||
"action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
|
||||
ACTION: {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
|
||||
}
|
||||
|
||||
# Create features and norm_map
|
||||
features = {
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
@@ -1177,17 +1177,17 @@ def test_hotswap_stats_multiple_normalizer_types():
|
||||
"""Test hotswap_stats with multiple normalizer and unnormalizer steps."""
|
||||
initial_stats = {
|
||||
OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])},
|
||||
"action": {"min": np.array([-1.0]), "max": np.array([1.0])},
|
||||
ACTION: {"min": np.array([-1.0]), "max": np.array([1.0])},
|
||||
}
|
||||
|
||||
new_stats = {
|
||||
OBS_IMAGE: {"mean": np.array([0.3]), "std": np.array([0.1])},
|
||||
"action": {"min": np.array([-2.0]), "max": np.array([2.0])},
|
||||
ACTION: {"min": np.array([-2.0]), "max": np.array([2.0])},
|
||||
}
|
||||
|
||||
features = {
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
@@ -1232,7 +1232,7 @@ def test_hotswap_stats_with_different_data_types():
|
||||
"min": 0, # int
|
||||
"max": 1.0, # float
|
||||
},
|
||||
"action": {
|
||||
ACTION: {
|
||||
"mean": np.array([0.1, 0.2]), # numpy array
|
||||
"std": torch.tensor([0.5, 0.6]), # torch tensor
|
||||
},
|
||||
@@ -1240,7 +1240,7 @@ def test_hotswap_stats_with_different_data_types():
|
||||
|
||||
features = {
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
@@ -1262,8 +1262,8 @@ def test_hotswap_stats_with_different_data_types():
|
||||
assert isinstance(tensor_stats[OBS_IMAGE]["std"], torch.Tensor)
|
||||
assert isinstance(tensor_stats[OBS_IMAGE]["min"], torch.Tensor)
|
||||
assert isinstance(tensor_stats[OBS_IMAGE]["max"], torch.Tensor)
|
||||
assert isinstance(tensor_stats["action"]["mean"], torch.Tensor)
|
||||
assert isinstance(tensor_stats["action"]["std"], torch.Tensor)
|
||||
assert isinstance(tensor_stats[ACTION]["mean"], torch.Tensor)
|
||||
assert isinstance(tensor_stats[ACTION]["std"], torch.Tensor)
|
||||
|
||||
# Check values
|
||||
torch.testing.assert_close(tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.3, 0.4, 0.5]))
|
||||
@@ -1284,18 +1284,18 @@ def test_hotswap_stats_functional_test():
|
||||
# Initial stats
|
||||
initial_stats = {
|
||||
OBS_IMAGE: {"mean": np.array([0.5, 0.4]), "std": np.array([0.2, 0.3])},
|
||||
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
}
|
||||
|
||||
# New stats
|
||||
new_stats = {
|
||||
OBS_IMAGE: {"mean": np.array([0.3, 0.2]), "std": np.array([0.1, 0.2])},
|
||||
"action": {"mean": np.array([0.1, -0.1]), "std": np.array([0.5, 0.5])},
|
||||
ACTION: {"mean": np.array([0.1, -0.1]), "std": np.array([0.5, 0.5])},
|
||||
}
|
||||
|
||||
features = {
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(2, 2, 2)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
@@ -1324,18 +1324,18 @@ def test_hotswap_stats_functional_test():
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
)
|
||||
assert not torch.allclose(original_result["action"], new_result["action"], rtol=1e-3, atol=1e-3)
|
||||
assert not torch.allclose(original_result[ACTION], new_result[ACTION], rtol=1e-3, atol=1e-3)
|
||||
|
||||
# Verify that the new processor is actually using the new stats by checking internal state
|
||||
assert new_processor.steps[0].stats == new_stats
|
||||
assert torch.allclose(new_processor.steps[0]._tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.3, 0.2]))
|
||||
assert torch.allclose(new_processor.steps[0]._tensor_stats[OBS_IMAGE]["std"], torch.tensor([0.1, 0.2]))
|
||||
assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["mean"], torch.tensor([0.1, -0.1]))
|
||||
assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["std"], torch.tensor([0.5, 0.5]))
|
||||
assert torch.allclose(new_processor.steps[0]._tensor_stats[ACTION]["mean"], torch.tensor([0.1, -0.1]))
|
||||
assert torch.allclose(new_processor.steps[0]._tensor_stats[ACTION]["std"], torch.tensor([0.5, 0.5]))
|
||||
|
||||
# Test that normalization actually happens (output should not equal input)
|
||||
assert not torch.allclose(new_result[OBS_STR][OBS_IMAGE], observation[OBS_IMAGE])
|
||||
assert not torch.allclose(new_result["action"], action)
|
||||
assert not torch.allclose(new_result[ACTION], action)
|
||||
|
||||
|
||||
def test_zero_std_uses_eps():
|
||||
@@ -1366,10 +1366,10 @@ def test_action_normalized_despite_normalize_observation_keys():
|
||||
"""Action normalization is independent of normalize_observation_keys filter for observations."""
|
||||
features = {
|
||||
OBS_STATE: PolicyFeature(FeatureType.STATE, (1,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
ACTION: PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}
|
||||
stats = {ACTION: {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}
|
||||
normalizer = NormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={OBS_STATE}
|
||||
)
|
||||
@@ -1426,9 +1426,9 @@ def test_unknown_observation_keys_ignored():
|
||||
|
||||
|
||||
def test_batched_action_normalization():
|
||||
features = {"action": PolicyFeature(FeatureType.ACTION, (2,))}
|
||||
features = {ACTION: PolicyFeature(FeatureType.ACTION, (2,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}
|
||||
stats = {ACTION: {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
actions = torch.tensor([[1.0, -1.0], [3.0, 3.0]]) # first equals mean → zeros; second → [1, 1]
|
||||
@@ -1453,12 +1453,12 @@ def test_complementary_data_preservation():
|
||||
def test_roundtrip_normalize_unnormalize_non_identity():
|
||||
features = {
|
||||
OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
ACTION: PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MIN_MAX}
|
||||
stats = {
|
||||
OBS_STATE: {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])},
|
||||
"action": {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])},
|
||||
ACTION: {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])},
|
||||
}
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
@@ -1530,18 +1530,18 @@ def test_stats_override_preservation_in_load_state_dict():
|
||||
# Create original stats
|
||||
original_stats = {
|
||||
OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
}
|
||||
|
||||
# Create override stats (what user wants to use)
|
||||
override_stats = {
|
||||
OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
|
||||
"action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
|
||||
ACTION: {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
|
||||
}
|
||||
|
||||
features = {
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
@@ -1601,12 +1601,12 @@ def test_stats_without_override_loads_normally():
|
||||
"""
|
||||
original_stats = {
|
||||
OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
}
|
||||
|
||||
features = {
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
@@ -1674,7 +1674,7 @@ def test_pipeline_from_pretrained_with_stats_overrides():
|
||||
# Create test data
|
||||
features = {
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
@@ -1683,12 +1683,12 @@ def test_pipeline_from_pretrained_with_stats_overrides():
|
||||
|
||||
original_stats = {
|
||||
OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
}
|
||||
|
||||
override_stats = {
|
||||
OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
|
||||
"action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
|
||||
ACTION: {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
|
||||
}
|
||||
|
||||
# Create and save a pipeline with the original stats
|
||||
@@ -1751,8 +1751,8 @@ def test_pipeline_from_pretrained_with_stats_overrides():
|
||||
# The critical part was verified above: loaded_normalizer.stats == override_stats
|
||||
# This confirms that override stats are preserved during load_state_dict.
|
||||
# Let's just verify the pipeline processes data successfully.
|
||||
assert "action" in override_result
|
||||
assert isinstance(override_result["action"], torch.Tensor)
|
||||
assert ACTION in override_result
|
||||
assert isinstance(override_result[ACTION], torch.Tensor)
|
||||
|
||||
|
||||
def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32():
|
||||
@@ -1812,7 +1812,7 @@ def test_stats_reconstruction_after_load_state_dict():
|
||||
features = {
|
||||
OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||
OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
ACTION: PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
@@ -1828,7 +1828,7 @@ def test_stats_reconstruction_after_load_state_dict():
|
||||
"min": np.array([0.0, -1.0]),
|
||||
"max": np.array([1.0, 1.0]),
|
||||
},
|
||||
"action": {
|
||||
ACTION: {
|
||||
"mean": np.array([0.0, 0.0]),
|
||||
"std": np.array([1.0, 2.0]),
|
||||
},
|
||||
@@ -1852,15 +1852,15 @@ def test_stats_reconstruction_after_load_state_dict():
|
||||
# Check that all expected keys are present
|
||||
assert OBS_IMAGE in new_normalizer.stats
|
||||
assert OBS_STATE in new_normalizer.stats
|
||||
assert "action" in new_normalizer.stats
|
||||
assert ACTION in new_normalizer.stats
|
||||
|
||||
# Check that values are correct (converted back from tensors)
|
||||
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"], [0.5, 0.5, 0.5])
|
||||
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"], [0.2, 0.2, 0.2])
|
||||
np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["min"], [0.0, -1.0])
|
||||
np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["max"], [1.0, 1.0])
|
||||
np.testing.assert_allclose(new_normalizer.stats["action"]["mean"], [0.0, 0.0])
|
||||
np.testing.assert_allclose(new_normalizer.stats["action"]["std"], [1.0, 2.0])
|
||||
np.testing.assert_allclose(new_normalizer.stats[ACTION]["mean"], [0.0, 0.0])
|
||||
np.testing.assert_allclose(new_normalizer.stats[ACTION]["std"], [1.0, 2.0])
|
||||
|
||||
# Test that methods that depend on self.stats work correctly after loading
|
||||
# This would fail before the bug fix because self.stats was empty
|
||||
@@ -1876,7 +1876,7 @@ def test_stats_reconstruction_after_load_state_dict():
|
||||
new_stats = {
|
||||
OBS_IMAGE: {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]},
|
||||
OBS_STATE: {"min": [-1.0, -2.0], "max": [2.0, 2.0]},
|
||||
"action": {"mean": [0.1, 0.1], "std": [0.5, 0.5]},
|
||||
ACTION: {"mean": [0.1, 0.1], "std": [0.5, 0.5]},
|
||||
}
|
||||
|
||||
pipeline = DataProcessorPipeline([new_normalizer])
|
||||
|
||||
@@ -35,7 +35,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
@@ -257,7 +257,7 @@ def test_step_through_with_dict():
|
||||
|
||||
batch = {
|
||||
OBS_IMAGE: None,
|
||||
"action": None,
|
||||
ACTION: None,
|
||||
"next.reward": 0.0,
|
||||
"next.done": False,
|
||||
"next.truncated": False,
|
||||
@@ -1842,7 +1842,7 @@ def test_save_load_with_custom_converter_functions():
|
||||
# Verify it uses default converters by checking with standard batch format
|
||||
batch = {
|
||||
OBS_IMAGE: torch.randn(1, 3, 32, 32),
|
||||
"action": torch.randn(1, 7),
|
||||
ACTION: torch.randn(1, 7),
|
||||
"next.reward": torch.tensor([1.0]),
|
||||
"next.done": torch.tensor([False]),
|
||||
"next.truncated": torch.tensor([False]),
|
||||
@@ -2094,11 +2094,11 @@ def test_aggregate_joint_action_only():
|
||||
patterns=["action.j1.pos", "action.j2.pos"],
|
||||
)
|
||||
|
||||
# Expect only "action" with joint names
|
||||
assert "action" in out and OBS_STATE not in out
|
||||
assert out["action"]["dtype"] == "float32"
|
||||
assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"}
|
||||
assert out["action"]["shape"] == (len(out["action"]["names"]),)
|
||||
# Expect only ACTION with joint names
|
||||
assert ACTION in out and OBS_STATE not in out
|
||||
assert out[ACTION]["dtype"] == "float32"
|
||||
assert set(out[ACTION]["names"]) == {"j1.pos", "j2.pos"}
|
||||
assert out[ACTION]["shape"] == (len(out[ACTION]["names"]),)
|
||||
|
||||
|
||||
def test_aggregate_ee_action_and_observation_with_videos():
|
||||
@@ -2113,9 +2113,9 @@ def test_aggregate_ee_action_and_observation_with_videos():
|
||||
)
|
||||
|
||||
# Action should pack only EE names
|
||||
assert "action" in out
|
||||
assert set(out["action"]["names"]) == {"ee.x", "ee.y"}
|
||||
assert out["action"]["dtype"] == "float32"
|
||||
assert ACTION in out
|
||||
assert set(out[ACTION]["names"]) == {"ee.x", "ee.y"}
|
||||
assert out[ACTION]["dtype"] == "float32"
|
||||
|
||||
# Observation state should pack both ee.x and j1.pos as a vector
|
||||
assert OBS_STATE in out
|
||||
@@ -2140,10 +2140,10 @@ def test_aggregate_both_action_types():
|
||||
patterns=["action.ee", "action.j1", "action.j2.pos"],
|
||||
)
|
||||
|
||||
assert "action" in out
|
||||
assert ACTION in out
|
||||
expected = {"ee.x", "ee.y", "j1.pos", "j2.pos"}
|
||||
assert set(out["action"]["names"]) == expected
|
||||
assert out["action"]["shape"] == (len(expected),)
|
||||
assert set(out[ACTION]["names"]) == expected
|
||||
assert out[ACTION]["shape"] == (len(expected),)
|
||||
|
||||
|
||||
def test_aggregate_images_when_use_videos_false():
|
||||
|
||||
@@ -28,6 +28,7 @@ from lerobot.processor import (
|
||||
RobotActionToPolicyActionProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import identity_transition
|
||||
from lerobot.utils.constants import ACTION
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
@@ -134,8 +135,8 @@ def test_robot_to_policy_transform_features():
|
||||
|
||||
transformed = processor.transform_features(features)
|
||||
|
||||
assert "action" in transformed[PipelineFeatureType.ACTION]
|
||||
action_feature = transformed[PipelineFeatureType.ACTION]["action"]
|
||||
assert ACTION in transformed[PipelineFeatureType.ACTION]
|
||||
action_feature = transformed[PipelineFeatureType.ACTION][ACTION]
|
||||
assert action_feature.type == FeatureType.ACTION
|
||||
assert action_feature.shape == (3,)
|
||||
|
||||
@@ -251,7 +252,7 @@ def test_policy_to_robot_transform_features():
|
||||
|
||||
features = {
|
||||
PipelineFeatureType.ACTION: {
|
||||
"action": {"type": FeatureType.ACTION, "shape": (2,)},
|
||||
ACTION: {"type": FeatureType.ACTION, "shape": (2,)},
|
||||
"other_data": {"type": FeatureType.ENV, "shape": (1,)},
|
||||
}
|
||||
}
|
||||
@@ -266,7 +267,7 @@ def test_policy_to_robot_transform_features():
|
||||
assert motor_feature.type == FeatureType.ACTION
|
||||
assert motor_feature.shape == (1,)
|
||||
|
||||
assert "action" in transformed[PipelineFeatureType.ACTION]
|
||||
assert ACTION in transformed[PipelineFeatureType.ACTION]
|
||||
|
||||
assert "other_data" in transformed[PipelineFeatureType.ACTION]
|
||||
|
||||
@@ -447,8 +448,8 @@ def test_robot_to_policy_features_contract(policy_feature_factory):
|
||||
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
assert "action" in out[PipelineFeatureType.ACTION]
|
||||
action_feature = out[PipelineFeatureType.ACTION]["action"]
|
||||
assert ACTION in out[PipelineFeatureType.ACTION]
|
||||
action_feature = out[PipelineFeatureType.ACTION][ACTION]
|
||||
assert action_feature.type == FeatureType.ACTION
|
||||
assert action_feature.shape == (2,)
|
||||
|
||||
@@ -458,7 +459,7 @@ def test_policy_to_robot_features_contract(policy_feature_factory):
|
||||
processor = PolicyActionToRobotActionProcessorStep(motor_names=["m1", "m2", "m3"])
|
||||
features = {
|
||||
PipelineFeatureType.ACTION: {
|
||||
"action": policy_feature_factory(FeatureType.ACTION, (3,)),
|
||||
ACTION: policy_feature_factory(FeatureType.ACTION, (3,)),
|
||||
"other": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ from lerobot.processor import (
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
@@ -488,7 +488,7 @@ def test_features_chained_processors(policy_feature_factory):
|
||||
def test_rename_stats_basic():
|
||||
orig = {
|
||||
OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])},
|
||||
"action": {"mean": np.array([0.0])},
|
||||
ACTION: {"mean": np.array([0.0])},
|
||||
}
|
||||
mapping = {OBS_STATE: "observation.robot_state"}
|
||||
renamed = rename_stats(orig, mapping)
|
||||
|
||||
@@ -11,7 +11,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_LANGUAGE, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_LANGUAGE, OBS_STATE
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
@@ -504,14 +504,14 @@ def test_features_basic():
|
||||
|
||||
input_features = {
|
||||
PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
|
||||
PipelineFeatureType.ACTION: {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
|
||||
}
|
||||
|
||||
output_features = processor.transform_features(input_features)
|
||||
|
||||
# Check that original features are preserved
|
||||
assert OBS_STATE in output_features[PipelineFeatureType.OBSERVATION]
|
||||
assert "action" in output_features[PipelineFeatureType.ACTION]
|
||||
assert ACTION in output_features[PipelineFeatureType.ACTION]
|
||||
|
||||
# Check that tokenized features are added
|
||||
assert f"{OBS_LANGUAGE}.tokens" in output_features[PipelineFeatureType.OBSERVATION]
|
||||
|
||||
@@ -21,6 +21,7 @@ from pickle import UnpicklingError
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.transition import Transition
|
||||
from tests.utils import require_cuda, require_package
|
||||
|
||||
@@ -512,7 +513,7 @@ def test_transitions_to_bytes_single_transition():
|
||||
def assert_transitions_equal(t1: Transition, t2: Transition):
|
||||
"""Helper to assert two transitions are equal."""
|
||||
assert_observation_equal(t1["state"], t2["state"])
|
||||
assert torch.allclose(t1["action"], t2["action"])
|
||||
assert torch.allclose(t1[ACTION], t2[ACTION])
|
||||
assert torch.allclose(t1["reward"], t2["reward"])
|
||||
assert torch.equal(t1["done"], t2["done"])
|
||||
assert_observation_equal(t1["next_state"], t2["next_state"])
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ def create_random_image() -> torch.Tensor:
|
||||
def create_dummy_transition() -> dict:
|
||||
return {
|
||||
OBS_IMAGE: create_random_image(),
|
||||
"action": torch.randn(4),
|
||||
ACTION: torch.randn(4),
|
||||
"reward": torch.tensor(1.0),
|
||||
OBS_STATE: torch.randn(
|
||||
10,
|
||||
@@ -341,7 +341,7 @@ def test_sample_batch(replay_buffer):
|
||||
f"{k} should be equal to one of the dummy states."
|
||||
)
|
||||
|
||||
for got_action_item in got_batch_transition["action"]:
|
||||
for got_action_item in got_batch_transition[ACTION]:
|
||||
assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), (
|
||||
"Actions should be equal to the dummy actions."
|
||||
)
|
||||
@@ -378,7 +378,7 @@ def test_to_lerobot_dataset(tmp_path):
|
||||
|
||||
for i in range(len(ds)):
|
||||
for feature, value in ds[i].items():
|
||||
if feature == "action":
|
||||
if feature == ACTION:
|
||||
assert torch.equal(value, buffer.actions[i])
|
||||
elif feature == "next.reward":
|
||||
assert torch.equal(value, buffer.rewards[i])
|
||||
@@ -495,7 +495,7 @@ def test_buffer_sample_alignment():
|
||||
|
||||
for i in range(50):
|
||||
state_sig = batch["state"]["state_value"][i].item()
|
||||
action_val = batch["action"][i].item()
|
||||
action_val = batch[ACTION][i].item()
|
||||
reward_val = batch["reward"][i].item()
|
||||
next_state_sig = batch["next_state"]["state_value"][i].item()
|
||||
is_done = batch["done"][i].item() > 0.5
|
||||
|
||||
Reference in New Issue
Block a user