chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code

* chore(tests): replace hard-coded OBS values with constants throughout all the test code
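The change is mechanical throughout: feature keys that were spelled out as string literals are now referenced through the shared constants in `lerobot.utils.constants`. A minimal before/after sketch of the pattern (runnable with lerobot installed; the batch here is illustrative):

```python
import torch

from lerobot.utils.constants import OBS_STATE  # OBS_STATE == "observation.state"

batch = {"observation.state": torch.zeros(8, 6)}  # illustrative batch

# Before: batch_size = batch["observation.state"].shape[0]
# After: the same key, spelled once in lerobot/utils/constants.py.
batch_size = batch[OBS_STATE].shape[0]
assert batch_size == 8
```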
@@ -27,7 +27,7 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
 # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
 from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig  # noqa: F401
 from lerobot.robots.robot import Robot
-from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
 from lerobot.utils.utils import init_logging

 Action = torch.Tensor
@@ -66,7 +66,7 @@ def validate_robot_cameras_for_policy(


 def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
-    return hw_to_dataset_features(robot.observation_features, "observation", use_video=False)
+    return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False)


 def is_image_key(k: str) -> bool:
@@ -141,7 +141,7 @@ def make_lerobot_observation(
     lerobot_features: dict[str, dict],
 ) -> LeRobotObservation:
     """Make a lerobot observation from a raw observation."""
-    return build_dataset_frame(lerobot_features, robot_obs, prefix="observation")
+    return build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR)


 def prepare_raw_observation(
@@ -27,6 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
 )
 from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
 from lerobot.datasets.transforms import ImageTransforms
+from lerobot.utils.constants import OBS_PREFIX

 IMAGENET_STATS = {
     "mean": [[[0.485]], [[0.456]], [[0.406]]],  # (c,1,1)
@@ -58,7 +59,7 @@ def resolve_delta_timestamps(
             delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
         if key == "action" and cfg.action_delta_indices is not None:
             delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
-        if key.startswith("observation.") and cfg.observation_delta_indices is not None:
+        if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
             delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]

     if len(delta_timestamps) == 0:
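For context, the arithmetic this hunk touches is simple: keys matched by the `OBS_PREFIX` ("observation.") prefix get their configured delta indices converted to timestamps by dividing by the dataset fps. A self-contained sketch (the fps and indices are illustrative, not taken from the diff):

```python
OBS_PREFIX = "observation."  # mirrors lerobot.utils.constants.OBS_PREFIX

fps = 30  # illustrative dataset rate
observation_delta_indices = [-1, 0]  # illustrative config values
features = ["observation.state", "action"]

# Same conversion as resolve_delta_timestamps above: frame index -> seconds.
delta_timestamps = {
    key: [i / fps for i in observation_delta_indices]
    for key in features
    if key.startswith(OBS_PREFIX)
}
assert delta_timestamps == {"observation.state": [-1 / 30, 0.0]}
```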
@@ -19,7 +19,7 @@ from typing import Any
 from lerobot.configs.types import PipelineFeatureType
 from lerobot.datasets.utils import hw_to_dataset_features
 from lerobot.processor import DataProcessorPipeline
-from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
+from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR


 def create_initial_features(
@@ -92,8 +92,8 @@ def aggregate_pipeline_dataset_features(

     # Intermediate storage for categorized and filtered features.
     processed_features: dict[str, dict[str, Any]] = {
-        "action": {},
-        "observation": {},
+        ACTION: {},
+        OBS_STR: {},
     }
     images_token = OBS_IMAGES.split(".")[-1]
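One detail worth noting in the hunk above: the images group token is derived from the constant (`OBS_IMAGES.split(".")[-1]`) rather than written out, so it tracks the constant if the key ever changes. A tiny sketch, inlining the value `OBS_IMAGES` is given later in this diff:

```python
OBS_IMAGES = "observation.images"  # value defined in lerobot/utils/constants.py below

images_token = OBS_IMAGES.split(".")[-1]
assert images_token == "images"
```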
@@ -125,17 +125,15 @@ def aggregate_pipeline_dataset_features(
         # 3. Add the feature to the appropriate group with a clean name.
         name = strip_prefix(key, PREFIXES_TO_STRIP)
         if is_action:
-            processed_features["action"][name] = value
+            processed_features[ACTION][name] = value
         else:
-            processed_features["observation"][name] = value
+            processed_features[OBS_STR][name] = value

     # Convert the processed features into the final dataset format.
     dataset_features = {}
-    if processed_features["action"]:
+    if processed_features[ACTION]:
         dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos))
-    if processed_features["observation"]:
-        dataset_features.update(
-            hw_to_dataset_features(processed_features["observation"], "observation", use_videos)
-        )
+    if processed_features[OBS_STR]:
+        dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos))

     return dataset_features
@@ -43,6 +43,7 @@ from lerobot.datasets.backward_compatibility import (
     BackwardCompatibilityError,
     ForwardCompatibilityError,
 )
+from lerobot.utils.constants import OBS_ENV_STATE, OBS_STR
 from lerobot.utils.utils import is_valid_numpy_dtype_string

 DEFAULT_CHUNK_SIZE = 1000  # Max number of files per chunk
@@ -652,7 +653,7 @@ def hw_to_dataset_features(
            "names": list(joint_fts),
        }

-    if joint_fts and prefix == "observation":
+    if joint_fts and prefix == OBS_STR:
        features[f"{prefix}.state"] = {
            "dtype": "float32",
            "shape": (len(joint_fts),),
@@ -728,9 +729,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
            # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
            if names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
                shape = (shape[2], shape[0], shape[1])
-        elif key == "observation.environment_state":
+        elif key == OBS_ENV_STATE:
            type = FeatureType.ENV
-        elif key.startswith("observation"):
+        elif key.startswith(OBS_STR):
            type = FeatureType.STATE
        elif key.startswith("action"):
            type = FeatureType.ACTION
@@ -26,6 +26,7 @@ from torch import Tensor

 from lerobot.configs.types import FeatureType, PolicyFeature
 from lerobot.envs.configs import EnvConfig
+from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
 from lerobot.utils.utils import get_channel_first_image_shape


@@ -41,9 +42,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
     return_observations = {}
     if "pixels" in observations:
         if isinstance(observations["pixels"], dict):
-            imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
+            imgs = {f"{OBS_IMAGES}.{key}": img for key, img in observations["pixels"].items()}
         else:
-            imgs = {"observation.image": observations["pixels"]}
+            imgs = {OBS_IMAGE: observations["pixels"]}

         for imgkey, img in imgs.items():
             # TODO(aliberts, rcadene): use transforms.ToTensor()?
@@ -72,13 +73,13 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
         if env_state.dim() == 1:
             env_state = env_state.unsqueeze(0)

-        return_observations["observation.environment_state"] = env_state
+        return_observations[OBS_ENV_STATE] = env_state

     # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
     agent_pos = torch.from_numpy(observations["agent_pos"]).float()
     if agent_pos.dim() == 1:
         agent_pos = agent_pos.unsqueeze(0)
-    return_observations["observation.state"] = agent_pos
+    return_observations[OBS_STATE] = agent_pos

     return return_observations
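A usage sketch of how the renamed keys surface to callers of `preprocess_observation`. The import path is taken from the lerobot source tree and the shapes are illustrative assumptions (images are passed batched, channel-last):

```python
import numpy as np

from lerobot.envs.utils import preprocess_observation  # assumed module path

observations = {
    "pixels": {"top": np.zeros((1, 96, 96, 3), dtype=np.uint8)},  # (b, h, w, c)
    "agent_pos": np.zeros(6, dtype=np.float32),
}
out = preprocess_observation(observations)

# The output keys now come from the constants rather than inline literals.
assert "observation.images.top" in out  # f"{OBS_IMAGES}.{key}"
assert "observation.state" in out       # OBS_STATE
```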
@@ -35,7 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d

 from lerobot.policies.act.configuration_act import ACTConfig
 from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.utils.constants import ACTION, OBS_IMAGES
+from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE


 class ACTPolicy(PreTrainedPolicy):
@@ -398,10 +398,10 @@ class ACT(nn.Module):
                 "actions must be provided when using the variational objective in training mode."
             )

-        if "observation.images" in batch:
-            batch_size = batch["observation.images"][0].shape[0]
+        if OBS_IMAGES in batch:
+            batch_size = batch[OBS_IMAGES][0].shape[0]
         else:
-            batch_size = batch["observation.environment_state"].shape[0]
+            batch_size = batch[OBS_ENV_STATE].shape[0]

         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch and self.training:
@@ -410,7 +410,7 @@ class ACT(nn.Module):
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
             if self.config.robot_state_feature:
-                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
+                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[OBS_STATE])
                 robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
@@ -430,7 +430,7 @@ class ACT(nn.Module):
             cls_joint_is_pad = torch.full(
                 (batch_size, 2 if self.config.robot_state_feature else 1),
                 False,
-                device=batch["observation.state"].device,
+                device=batch[OBS_STATE].device,
             )
             key_padding_mask = torch.cat(
                 [cls_joint_is_pad, batch["action_is_pad"]], axis=1
@@ -454,7 +454,7 @@ class ACT(nn.Module):
             mu = log_sigma_x2 = None
             # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
             latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
-                batch["observation.state"].device
+                batch[OBS_STATE].device
             )

         # Prepare transformer encoder inputs.
@@ -462,18 +462,16 @@ class ACT(nn.Module):
         encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
         # Robot state token.
         if self.config.robot_state_feature:
-            encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
+            encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch[OBS_STATE]))
         # Environment state token.
         if self.config.env_state_feature:
-            encoder_in_tokens.append(
-                self.encoder_env_state_input_proj(batch["observation.environment_state"])
-            )
+            encoder_in_tokens.append(self.encoder_env_state_input_proj(batch[OBS_ENV_STATE]))

         if self.config.image_features:
             # For a list of images, the H and W may vary but H*W is constant.
             # NOTE: If modifying this section, verify on MPS devices that
             # gradients remain stable (no explosions or NaNs).
-            for img in batch["observation.images"]:
+            for img in batch[OBS_IMAGES]:
                 cam_features = self.backbone(img)["feature_map"]
                 cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
                 cam_features = self.encoder_img_feat_input_proj(cam_features)
@@ -81,13 +81,13 @@ class DiffusionPolicy(PreTrainedPolicy):
     def reset(self):
         """Clear observation and action queues. Should be called on `env.reset()`"""
         self._queues = {
-            "observation.state": deque(maxlen=self.config.n_obs_steps),
+            OBS_STATE: deque(maxlen=self.config.n_obs_steps),
             "action": deque(maxlen=self.config.n_action_steps),
         }
         if self.config.image_features:
-            self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
+            self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
         if self.config.env_state_feature:
-            self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
+            self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps)

     @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
@@ -234,7 +234,7 @@ class DiffusionModel(nn.Module):
         if self.config.image_features:
             if self.config.use_separate_rgb_encoder_per_camera:
                 # Combine batch and sequence dims while rearranging to make the camera index dimension first.
-                images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
+                images_per_camera = einops.rearrange(batch[OBS_IMAGES], "b s n ... -> n (b s) ...")
                 img_features_list = torch.cat(
                     [
                         encoder(images)
@@ -249,7 +249,7 @@ class DiffusionModel(nn.Module):
             else:
                 # Combine batch, sequence, and "which camera" dims before passing to shared encoder.
                 img_features = self.rgb_encoder(
-                    einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
+                    einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...")
                 )
                 # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
                 # feature dim (effectively concatenating the camera features).
@@ -275,7 +275,7 @@ class DiffusionModel(nn.Module):
            "observation.environment_state": (B, n_obs_steps, environment_dim)
         }
         """
-        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+        batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
         assert n_obs_steps == self.config.n_obs_steps

         # Encode image features and concatenate them all together along with the state vector.
@@ -306,9 +306,9 @@ class DiffusionModel(nn.Module):
         }
         """
         # Input validation.
-        assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
-        assert "observation.images" in batch or "observation.environment_state" in batch
-        n_obs_steps = batch["observation.state"].shape[1]
+        assert set(batch).issuperset({OBS_STATE, "action", "action_is_pad"})
+        assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
+        n_obs_steps = batch[OBS_STATE].shape[1]
         horizon = batch["action"].shape[1]
         assert horizon == self.config.horizon
         assert n_obs_steps == self.config.n_obs_steps
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
 from lerobot.optim.schedulers import (
     CosineDecayWithWarmupSchedulerConfig,
 )
+from lerobot.utils.constants import OBS_IMAGES


 @PreTrainedConfig.register_subclass("pi0")
@@ -113,7 +114,7 @@ class PI0Config(PreTrainedConfig):
         # raise ValueError("You must provide at least one image or the environment state among the inputs.")

         for i in range(self.empty_cameras):
-            key = f"observation.images.empty_camera_{i}"
+            key = f"{OBS_IMAGES}.empty_camera_{i}"
             empty_camera = PolicyFeature(
                 type=FeatureType.VISUAL,
                 shape=(3, 480, 640),
@@ -21,6 +21,7 @@ import torch
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
 from lerobot.policies.factory import make_policy
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE


 def display(tensor: torch.Tensor):
@@ -60,26 +61,26 @@ def main():

     # Override stats
     dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
-    dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
+    dataset_meta.stats[OBS_STATE]["mean"] = torch.tensor(
         norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
     )
-    dataset_meta.stats["observation.state"]["std"] = torch.tensor(
+    dataset_meta.stats[OBS_STATE]["std"] = torch.tensor(
         norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
     )

     # Create LeRobot batch from Jax
     batch = {}
     for cam_key, uint_chw_array in example["images"].items():
-        batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
-    batch["observation.state"] = torch.from_numpy(example["state"])
+        batch[f"{OBS_IMAGES}.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
+    batch[OBS_STATE] = torch.from_numpy(example["state"])
     batch["action"] = torch.from_numpy(outputs["actions"])
     batch["task"] = example["prompt"]

     if model_name == "pi0_aloha_towel":
-        del batch["observation.images.cam_low"]
+        del batch[f"{OBS_IMAGES}.cam_low"]
     elif model_name == "pi0_aloha_sim":
-        batch["observation.images.top"] = batch["observation.images.cam_high"]
-        del batch["observation.images.cam_high"]
+        batch[f"{OBS_IMAGES}.top"] = batch[f"{OBS_IMAGES}.cam_high"]
+        del batch[f"{OBS_IMAGES}.cam_high"]

     # Batchify
     for key in batch:
@@ -6,6 +6,7 @@ from lerobot.optim.optimizers import AdamWConfig
 from lerobot.optim.schedulers import (
     CosineDecayWithWarmupSchedulerConfig,
 )
+from lerobot.utils.constants import OBS_IMAGES


 @PreTrainedConfig.register_subclass("pi0fast")
@@ -99,7 +100,7 @@ class PI0FASTConfig(PreTrainedConfig):

     def validate_features(self) -> None:
         for i in range(self.empty_cameras):
-            key = f"observation.images.empty_camera_{i}"
+            key = f"{OBS_IMAGES}.empty_camera_{i}"
             empty_camera = PolicyFeature(
                 type=FeatureType.VISUAL,
                 shape=(3, 480, 640),
@@ -31,6 +31,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
 from lerobot.policies.utils import get_device_from_parameters
+from lerobot.utils.constants import OBS_ENV_STATE, OBS_STATE

 DISCRETE_DIMENSION_INDEX = -1  # Gripper is always the last dimension
@@ -513,17 +514,17 @@ class SACObservationEncoder(nn.Module):
         )

     def _init_state_layers(self) -> None:
-        self.has_env = "observation.environment_state" in self.config.input_features
-        self.has_state = "observation.state" in self.config.input_features
+        self.has_env = OBS_ENV_STATE in self.config.input_features
+        self.has_state = OBS_STATE in self.config.input_features
         if self.has_env:
-            dim = self.config.input_features["observation.environment_state"].shape[0]
+            dim = self.config.input_features[OBS_ENV_STATE].shape[0]
             self.env_encoder = nn.Sequential(
                 nn.Linear(dim, self.config.latent_dim),
                 nn.LayerNorm(self.config.latent_dim),
                 nn.Tanh(),
             )
         if self.has_state:
-            dim = self.config.input_features["observation.state"].shape[0]
+            dim = self.config.input_features[OBS_STATE].shape[0]
             self.state_encoder = nn.Sequential(
                 nn.Linear(dim, self.config.latent_dim),
                 nn.LayerNorm(self.config.latent_dim),
@@ -549,9 +550,9 @@ class SACObservationEncoder(nn.Module):
             cache = self.get_cached_image_features(obs)
             parts.append(self._encode_images(cache, detach))
         if self.has_env:
-            parts.append(self.env_encoder(obs["observation.environment_state"]))
+            parts.append(self.env_encoder(obs[OBS_ENV_STATE]))
         if self.has_state:
-            parts.append(self.state_encoder(obs["observation.state"]))
+            parts.append(self.state_encoder(obs[OBS_STATE]))
         if parts:
             return torch.cat(parts, dim=-1)
@@ -19,6 +19,7 @@ from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import NormalizationMode
 from lerobot.optim.optimizers import AdamWConfig, OptimizerConfig
 from lerobot.optim.schedulers import LRSchedulerConfig
+from lerobot.utils.constants import OBS_IMAGE


 @PreTrainedConfig.register_subclass(name="reward_classifier")
@@ -69,7 +70,7 @@ class RewardClassifierConfig(PreTrainedConfig):

     def validate_features(self) -> None:
         """Validate feature configurations."""
-        has_image = any(key.startswith("observation.image") for key in self.input_features)
+        has_image = any(key.startswith(OBS_IMAGE) for key in self.input_features)
         if not has_image:
             raise ValueError(
                 "You must provide an image observation (key starting with 'observation.image') in the input features"
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
 from lerobot.optim.schedulers import (
     CosineDecayWithWarmupSchedulerConfig,
 )
+from lerobot.utils.constants import OBS_IMAGES


 @PreTrainedConfig.register_subclass("smolvla")
@@ -117,7 +118,7 @@ class SmolVLAConfig(PreTrainedConfig):

     def validate_features(self) -> None:
         for i in range(self.empty_cameras):
-            key = f"observation.images.empty_camera_{i}"
+            key = f"{OBS_IMAGES}.empty_camera_{i}"
             empty_camera = PolicyFeature(
                 type=FeatureType.VISUAL,
                 shape=(3, 480, 640),
@@ -38,7 +38,7 @@ from torch import Tensor
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
-from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
+from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD


 class TDMPCPolicy(PreTrainedPolicy):
@@ -91,13 +91,13 @@ class TDMPCPolicy(PreTrainedPolicy):
         called on `env.reset()`
         """
         self._queues = {
-            "observation.state": deque(maxlen=1),
+            OBS_STATE: deque(maxlen=1),
             "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
         }
         if self.config.image_features:
-            self._queues["observation.image"] = deque(maxlen=1)
+            self._queues[OBS_IMAGE] = deque(maxlen=1)
         if self.config.env_state_feature:
-            self._queues["observation.environment_state"] = deque(maxlen=1)
+            self._queues[OBS_ENV_STATE] = deque(maxlen=1)
         # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
         # CEM for the next step.
         self._prev_mean: torch.Tensor | None = None
@@ -325,7 +325,7 @@ class TDMPCPolicy(PreTrainedPolicy):

         action = batch[ACTION]  # (t, b, action_dim)
         reward = batch[REWARD]  # (t, b)
-        observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
+        observations = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)}

         # Apply random image augmentations.
         if self.config.image_features and self.config.max_random_shift_ratio > 0:
@@ -387,10 +387,10 @@ class TDMPCPolicy(PreTrainedPolicy):
                 temporal_loss_coeffs
                 * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
                 # `z_preds` depends on the current observation and the actions.
-                * ~batch["observation.state_is_pad"][0]
+                * ~batch[f"{OBS_STR}.state_is_pad"][0]
                 * ~batch["action_is_pad"]
                 # `z_targets` depends on the next observation.
-                * ~batch["observation.state_is_pad"][1:]
+                * ~batch[f"{OBS_STR}.state_is_pad"][1:]
             )
             .sum(0)
             .mean()
@@ -403,7 +403,7 @@ class TDMPCPolicy(PreTrainedPolicy):
                 * F.mse_loss(reward_preds, reward, reduction="none")
                 * ~batch["next.reward_is_pad"]
                 # `reward_preds` depends on the current observation and the actions.
-                * ~batch["observation.state_is_pad"][0]
+                * ~batch[f"{OBS_STR}.state_is_pad"][0]
                 * ~batch["action_is_pad"]
             )
             .sum(0)
@@ -419,11 +419,11 @@ class TDMPCPolicy(PreTrainedPolicy):
                     reduction="none",
                 ).sum(0)  # sum over ensemble
                 # `q_preds_ensemble` depends on the first observation and the actions.
-                * ~batch["observation.state_is_pad"][0]
+                * ~batch[f"{OBS_STR}.state_is_pad"][0]
                 * ~batch["action_is_pad"]
                 # q_targets depends on the reward and the next observations.
                 * ~batch["next.reward_is_pad"]
-                * ~batch["observation.state_is_pad"][1:]
+                * ~batch[f"{OBS_STR}.state_is_pad"][1:]
             )
             .sum(0)
             .mean()
@@ -441,7 +441,7 @@ class TDMPCPolicy(PreTrainedPolicy):
                 temporal_loss_coeffs
                 * raw_v_value_loss
                 # `v_targets` depends on the first observation and the actions, as does `v_preds`.
-                * ~batch["observation.state_is_pad"][0]
+                * ~batch[f"{OBS_STR}.state_is_pad"][0]
                 * ~batch["action_is_pad"]
             )
             .sum(0)
@@ -477,7 +477,7 @@ class TDMPCPolicy(PreTrainedPolicy):
                 * mse
                 * temporal_loss_coeffs
                 # `action_preds` depends on the first observation and the actions.
-                * ~batch["observation.state_is_pad"][0]
+                * ~batch[f"{OBS_STR}.state_is_pad"][0]
                 * ~batch["action_is_pad"]
             ).mean()


@@ -133,7 +133,7 @@ class VQBeTPolicy(PreTrainedPolicy):
             batch.pop(ACTION)
         batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
         # NOTE: It's important that this happens after stacking the images into a single key.
-        batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
+        batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
         # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
         if ACTION in batch:
             batch.pop(ACTION)
@@ -340,14 +340,12 @@ class VQBeTModel(nn.Module):

     def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
         # Input validation.
-        assert set(batch).issuperset({"observation.state", "observation.images"})
-        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+        assert set(batch).issuperset({OBS_STATE, OBS_IMAGES})
+        batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
         assert n_obs_steps == self.config.n_obs_steps

         # Extract image feature (first combine batch and sequence dims).
-        img_features = self.rgb_encoder(
-            einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
-        )
+        img_features = self.rgb_encoder(einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ..."))
         # Separate batch and sequence dims.
         img_features = einops.rearrange(
             img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
@@ -359,9 +357,7 @@ class VQBeTModel(nn.Module):
             img_features
         )  # (batch, obs_step, number of different cameras, projection dims)
         input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))]
-        input_tokens.append(
-            self.state_projector(batch["observation.state"])
-        )  # (batch, obs_step, projection dims)
+        input_tokens.append(self.state_projector(batch[OBS_STATE]))  # (batch, obs_step, projection dims)
         input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
         # Interleave tokens by stacking and rearranging.
         input_tokens = torch.stack(input_tokens, dim=2)
@@ -23,6 +23,8 @@ from typing import Any
 import numpy as np
 import torch

+from lerobot.utils.constants import OBS_PREFIX
+
 from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey


@@ -347,7 +349,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
         raise ValueError(f"Action should be a PolicyAction type got {type(action)}")

     # Extract observation and complementary data keys.
-    observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
+    observation_keys = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)}
     complementary_data = _extract_complementary_data(batch)

     return create_transition(
@@ -21,7 +21,7 @@ import torch
 from torch import Tensor

 from lerobot.configs.types import PipelineFeatureType, PolicyFeature
-from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
+from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR

 from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@@ -171,7 +171,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):

             # Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1)
             for old_prefix, new_prefix in prefix_pairs.items():
-                prefixed_old = f"observation.{old_prefix}"
+                prefixed_old = f"{OBS_STR}.{old_prefix}"
                 if key.startswith(prefixed_old):
                     suffix = key[len(prefixed_old) :]
                     new_key = f"{new_prefix}{suffix}"
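The prefix rule rewritten here splits a key into the matched prefix and its camera suffix. A standalone sketch of just that logic, with constant values inlined and an illustrative mapping (the `pixels -> OBS_IMAGES` rename comes from the comment above):

```python
OBS_STR = "observation"  # inlined value of lerobot.utils.constants.OBS_STR

prefix_pairs = {"pixels": "observation.images"}  # illustrative rule: pixels -> OBS_IMAGES
key = "observation.pixels.cam1"

new_key = None
for old_prefix, new_prefix in prefix_pairs.items():
    prefixed_old = f"{OBS_STR}.{old_prefix}"  # "observation.pixels"
    if key.startswith(prefixed_old):
        suffix = key[len(prefixed_old) :]  # ".cam1"
        new_key = f"{new_prefix}{suffix}"  # "observation.images.cam1"

assert new_key == "observation.images.cam1"
```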
@@ -191,7 +191,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):

             # Exact-name rules (pixels, environment_state, agent_pos)
             for old, new in exact_pairs.items():
-                if key == old or key == f"observation.{old}":
+                if key == old or key == f"{OBS_STR}.{old}":
                     new_key = new
                     new_features[src_ft][new_key] = feat
                     handled = True
@@ -24,6 +24,7 @@ import torch.nn.functional as F  # noqa: N812
 from tqdm import tqdm

 from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.utils.constants import OBS_IMAGE
 from lerobot.utils.transition import Transition


@@ -240,7 +241,7 @@ class ReplayBuffer:
         idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)

         # Identify image keys that need augmentation
-        image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
+        image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []

         # Create batched state and next_state
         batch_state = {}
@@ -73,6 +73,7 @@ from lerobot.teleoperators import (
 )
 from lerobot.teleoperators.teleoperator import Teleoperator
 from lerobot.teleoperators.utils import TeleopEvents
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
 from lerobot.utils.robot_utils import busy_wait
 from lerobot.utils.utils import log_say
@@ -180,7 +181,7 @@ class RobotEnv(gym.Env):

         # Define observation spaces for images and other states.
         if current_observation is not None and "pixels" in current_observation:
-            prefix = "observation.images"
+            prefix = OBS_IMAGES
             observation_spaces = {
                 f"{prefix}.{key}": gym.spaces.Box(
                     low=0, high=255, shape=current_observation["pixels"][key].shape, dtype=np.uint8
@@ -190,7 +191,7 @@ class RobotEnv(gym.Env):

         if current_observation is not None:
             agent_pos = current_observation["agent_pos"]
-            observation_spaces["observation.state"] = gym.spaces.Box(
+            observation_spaces[OBS_STATE] = gym.spaces.Box(
                 low=0,
                 high=10,
                 shape=agent_pos.shape,
@@ -612,7 +613,7 @@ def control_loop(
     }

     for key, value in transition[TransitionKey.OBSERVATION].items():
-        if key == "observation.state":
+        if key == OBS_STATE:
             features[key] = {
                 "dtype": "float32",
                 "shape": value.squeeze(0).shape,
@@ -23,6 +23,7 @@ from typing import Any
 import cv2
 import numpy as np

+from lerobot.utils.constants import OBS_STATE
 from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError

 from ..robot import Robot
@@ -203,7 +204,7 @@ class LeKiwiClient(Robot):

         state_vec = np.array([flat_state[key] for key in self._state_order], dtype=np.float32)

-        obs_dict: dict[str, Any] = {**flat_state, "observation.state": state_vec}
+        obs_dict: dict[str, Any] = {**flat_state, OBS_STATE: state_vec}

         # Decode images
         current_frames: dict[str, np.ndarray] = {}
@@ -75,6 +75,7 @@ import torch.utils.data
 import tqdm

 from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.utils.constants import OBS_STATE


 class EpisodeSampler(torch.utils.data.Sampler):
@@ -161,8 +162,8 @@ def visualize_dataset(
             rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))

         # display each dimension of observed state space (e.g. agent position in joint space)
-        if "observation.state" in batch:
-            for dim_idx, val in enumerate(batch["observation.state"][i]):
+        if OBS_STATE in batch:
+            for dim_idx, val in enumerate(batch[OBS_STATE][i]):
                 rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))

         if "next.done" in batch:
@@ -81,6 +81,7 @@ from lerobot.envs.utils import (
 from lerobot.policies.factory import make_policy, make_pre_post_processors
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.processor import PolicyAction, PolicyProcessorPipeline
+from lerobot.utils.constants import OBS_STR
 from lerobot.utils.io_utils import write_video
 from lerobot.utils.random_utils import set_seed
 from lerobot.utils.utils import (
@@ -221,7 +222,7 @@ def rollout(
         stacked_observations = {}
         for key in all_observations[0]:
             stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
-        ret["observation"] = stacked_observations
+        ret[OBS_STR] = stacked_observations

     if hasattr(policy, "use_original_modules"):
         policy.use_original_modules()
@@ -459,8 +460,8 @@ def _compile_episode_data(
         for k in ep_dict:
             ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]])

-        for key in rollout_data["observation"]:
-            ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames]
+        for key in rollout_data[OBS_STR]:
+            ep_dict[key] = rollout_data[OBS_STR][key][ep_ix, :num_frames]

         ep_dicts.append(ep_dict)
@@ -109,6 +109,7 @@ from lerobot.teleoperators import (  # noqa: F401
     so101_leader,
 )
 from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
+from lerobot.utils.constants import OBS_STR
 from lerobot.utils.control_utils import (
     init_keyboard_listener,
     is_headless,
@@ -303,7 +304,7 @@ def record_loop(
         obs_processed = robot_observation_processor(obs)

         if policy is not None or dataset is not None:
-            observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix="observation")
+            observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)

         # Get action from either policy or teleop
         if policy is not None and preprocessor is not None and postprocessor is not None:
@@ -17,19 +17,21 @@ from pathlib import Path

 from huggingface_hub.constants import HF_HOME

-OBS_ENV_STATE = "observation.environment_state"
-OBS_STATE = "observation.state"
-OBS_IMAGE = "observation.image"
-OBS_IMAGES = "observation.images"
-OBS_LANGUAGE = "observation.language"
+OBS_STR = "observation"
+OBS_PREFIX = OBS_STR + "."
+OBS_ENV_STATE = OBS_STR + ".environment_state"
+OBS_STATE = OBS_STR + ".state"
+OBS_IMAGE = OBS_STR + ".image"
+OBS_IMAGES = OBS_IMAGE + "s"
+OBS_LANGUAGE = OBS_STR + ".language"
+OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
+OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"

 ACTION = "action"
 REWARD = "next.reward"
 TRUNCATED = "next.truncated"
 DONE = "next.done"

-OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
-OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
-
 ROBOTS = "robots"
 ROBOT_TYPE = "robot_type"
 TELEOPERATORS = "teleoperators"
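The derived definitions above are string-for-string identical to the literals they replace, which is what makes the rest of this diff a pure rename. A quick sanity check, runnable with lerobot installed:

```python
from lerobot.utils.constants import (
    OBS_ENV_STATE,
    OBS_IMAGE,
    OBS_IMAGES,
    OBS_PREFIX,
    OBS_STATE,
    OBS_STR,
)

# Each constant resolves to the literal it replaced throughout the codebase.
assert OBS_STR == "observation"
assert OBS_PREFIX == "observation."
assert OBS_ENV_STATE == "observation.environment_state"
assert OBS_STATE == "observation.state"
assert OBS_IMAGE == "observation.image"
assert OBS_IMAGES == "observation.images"  # built as OBS_IMAGE + "s"
```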
@@ -19,6 +19,8 @@ from typing import Any
 import numpy as np
 import rerun as rr

+from .constants import OBS_PREFIX, OBS_STR
+

 def init_rerun(session_name: str = "lerobot_control_loop") -> None:
     """Initializes the Rerun SDK for visualizing the control loop."""
@@ -63,7 +65,7 @@ def log_rerun_data(
     for k, v in observation.items():
         if v is None:
             continue
-        key = k if str(k).startswith("observation.") else f"observation.{k}"
+        key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"

         if _is_scalar(v):
             rr.log(key, rr.Scalar(float(v)))
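The key-normalization rule changed in this last hunk is easy to check in isolation (constant values inlined; sample keys illustrative):

```python
OBS_PREFIX = "observation."  # inlined values of the constants used above
OBS_STR = "observation"

samples = ["observation.state", "wrist_cam"]
keys = [k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}" for k in samples]
assert keys == ["observation.state", "observation.wrist_cam"]
```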