forked from tangger/lerobot
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
AdilZouitine
parent
761a2dbcb3
commit
8e6d5f504c
@@ -171,7 +171,9 @@ class ACTConfig(PreTrainedConfig):
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features and not self.env_state_feature:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
raise ValueError(
|
||||
"You must provide at least one image or the environment state among the inputs."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
|
||||
@@ -63,7 +63,9 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
@@ -120,8 +122,12 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [batch[key] for key in self.config.image_features]
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [
|
||||
batch[key] for key in self.config.image_features
|
||||
]
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
@@ -148,8 +154,12 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [batch[key] for key in self.config.image_features]
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [
|
||||
batch[key] for key in self.config.image_features
|
||||
]
|
||||
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
@@ -406,14 +416,18 @@ class ACT(nn.Module):
|
||||
n_1d_tokens += 1
|
||||
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
||||
if self.config.image_features:
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
|
||||
config.dim_model // 2
|
||||
)
|
||||
|
||||
# Transformer decoder.
|
||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
||||
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
|
||||
|
||||
# Final action regression head on the output of the transformer's decoder.
|
||||
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
|
||||
self.action_head = nn.Linear(
|
||||
config.dim_model, self.config.action_feature.shape[0]
|
||||
)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
@@ -461,14 +475,20 @@ class ACT(nn.Module):
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
) # (B, 1, D)
|
||||
if self.config.robot_state_feature:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(
|
||||
batch["observation.state"]
|
||||
)
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(
|
||||
batch["action"]
|
||||
) # (B, S, D)
|
||||
|
||||
if self.config.robot_state_feature:
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
vae_encoder_input = [
|
||||
cls_embed,
|
||||
robot_state_embed,
|
||||
action_embed,
|
||||
] # (B, S+2, D)
|
||||
else:
|
||||
vae_encoder_input = [cls_embed, action_embed]
|
||||
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
|
||||
@@ -517,7 +537,9 @@ class ACT(nn.Module):
|
||||
)
|
||||
# Robot state token.
|
||||
if self.config.robot_state_feature:
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_robot_state_input_proj(batch["observation.state"])
|
||||
)
|
||||
# Environment state token.
|
||||
if self.config.env_state_feature:
|
||||
encoder_in_tokens.append(
|
||||
@@ -534,7 +556,9 @@ class ACT(nn.Module):
|
||||
# For a list of images, the H and W may vary but H*W is constant.
|
||||
for img in batch["observation.images"]:
|
||||
cam_features = self.backbone(img)["feature_map"]
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
|
||||
dtype=cam_features.dtype
|
||||
)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features)
|
||||
|
||||
# Rearrange features to (sequence, batch, dim).
|
||||
|
||||
@@ -205,11 +205,16 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if len(self.image_features) == 0 and self.env_state_feature is None:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
raise ValueError(
|
||||
"You must provide at least one image or the environment state among the inputs."
|
||||
)
|
||||
|
||||
if self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
if (
|
||||
self.crop_shape[0] > image_ft.shape[1]
|
||||
or self.crop_shape[1] > image_ft.shape[2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
|
||||
@@ -70,7 +70,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
@@ -97,7 +99,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
if self.config.image_features:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
self._queues["observation.environment_state"] = deque(
|
||||
maxlen=self.config.n_obs_steps
|
||||
)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -123,7 +127,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
@@ -151,7 +157,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
@@ -515,11 +523,15 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
)
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.pool = SpatialSoftmax(
|
||||
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
|
||||
)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
@@ -719,7 +731,9 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
|
||||
DiffusionConv1dBlock(
|
||||
config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size
|
||||
),
|
||||
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
|
||||
)
|
||||
|
||||
|
||||
@@ -104,7 +104,9 @@ def make_policy(
|
||||
PreTrainedPolicy: _description_
|
||||
"""
|
||||
if bool(ds_meta) == bool(env_cfg):
|
||||
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
|
||||
raise ValueError(
|
||||
"Either one of a dataset metadata or a sim env must be provided."
|
||||
)
|
||||
|
||||
# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
|
||||
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
|
||||
@@ -134,8 +136,12 @@ def make_policy(
|
||||
)
|
||||
features = env_to_policy_features(env_cfg)
|
||||
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
cfg.output_features = {
|
||||
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
|
||||
}
|
||||
cfg.input_features = {
|
||||
key: ft for key, ft in features.items() if key not in cfg.output_features
|
||||
}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
if cfg.pretrained_path:
|
||||
|
||||
@@ -82,25 +82,43 @@ def create_stats_buffers(
|
||||
if stats:
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
buffer["mean"].data = (
|
||||
stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
)
|
||||
buffer["std"].data = (
|
||||
stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
buffer["min"].data = (
|
||||
stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
)
|
||||
buffer["max"].data = (
|
||||
stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
raise ValueError(
|
||||
f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead."
|
||||
)
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
||||
@@ -44,7 +44,9 @@ def main():
|
||||
else:
|
||||
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
|
||||
|
||||
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
||||
ckpt_torch_dir = (
|
||||
Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
||||
)
|
||||
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
|
||||
save_dir = Path(f"../openpi/data/{model_name}/save")
|
||||
|
||||
@@ -70,7 +72,9 @@ def main():
|
||||
# Create LeRobot batch from Jax
|
||||
batch = {}
|
||||
for cam_key, uint_chw_array in example["images"].items():
|
||||
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
|
||||
batch[f"observation.images.{cam_key}"] = (
|
||||
torch.from_numpy(uint_chw_array) / 255.0
|
||||
)
|
||||
batch["observation.state"] = torch.from_numpy(example["state"])
|
||||
batch["action"] = torch.from_numpy(outputs["actions"])
|
||||
batch["task"] = example["prompt"]
|
||||
|
||||
@@ -54,7 +54,9 @@ def get_paligemma_config(precision: str):
|
||||
"projector_hidden_act": "gelu_fast",
|
||||
"vision_use_head": False,
|
||||
}
|
||||
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
|
||||
final_config = PaliGemmaConfig(
|
||||
text_config=text_config, vision_config=vision_config, **config
|
||||
)
|
||||
return final_config
|
||||
|
||||
|
||||
|
||||
@@ -61,7 +61,11 @@ from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
|
||||
)
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
||||
PRECISIONS = {
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
"float16": torch.float16,
|
||||
}
|
||||
|
||||
|
||||
def slice_paligemma_state_dict(state_dict, config):
|
||||
@@ -318,7 +322,9 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
|
||||
return {f"{prefix}{key}": value for key, value in d.items()}
|
||||
|
||||
|
||||
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
|
||||
def convert_pi0_checkpoint(
|
||||
checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str
|
||||
):
|
||||
# Break down orbax ckpts - they are in OCDBT
|
||||
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
||||
# process projection params
|
||||
@@ -378,7 +384,9 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: st
|
||||
# gemma_config=gemma_config, paligemma_config=paligemma_config)
|
||||
pi0_model = PI0Policy(pi0_config)
|
||||
|
||||
paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
|
||||
paligemma_params = update_keys_with_prefix(
|
||||
paligemma_params, "model.paligemma_with_expert."
|
||||
)
|
||||
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
|
||||
projection_params = update_keys_with_prefix(projection_params, "model.")
|
||||
|
||||
|
||||
@@ -48,18 +48,32 @@ def flex_attention_forward(
|
||||
|
||||
key_states = key_states[:, :, :, None, :]
|
||||
key_states = key_states.expand(
|
||||
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
key_states.shape[1],
|
||||
num_key_value_heads,
|
||||
num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
key_states.shape[1],
|
||||
num_key_value_heads * num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :]
|
||||
value_states = value_states.expand(
|
||||
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
value_states.shape[1],
|
||||
num_key_value_heads,
|
||||
num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
value_states.shape[1],
|
||||
num_key_value_heads * num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
|
||||
@@ -69,7 +69,11 @@ from lerobot.common.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
time: torch.tensor,
|
||||
dimension: int,
|
||||
min_period: float,
|
||||
max_period: float,
|
||||
device="cpu",
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
@@ -189,7 +193,9 @@ def aloha_gripper_to_angular(value):
|
||||
|
||||
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||
def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
|
||||
2 * horn_radius * linear_position
|
||||
)
|
||||
return safe_arcsin(value)
|
||||
|
||||
# The constants are taken from the Interbotix code.
|
||||
@@ -240,7 +246,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
@@ -248,7 +256,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.language_tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google/paligemma-3b-pt-224"
|
||||
)
|
||||
self.model = PI0FlowMatching(config)
|
||||
|
||||
self.reset()
|
||||
@@ -261,7 +271,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
return self.parameters()
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
def select_action(
|
||||
self, batch: dict[str, Tensor], noise: Tensor | None = None
|
||||
) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
@@ -300,7 +312,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
def forward(
|
||||
self, batch: dict[str, Tensor], noise=None, time=None
|
||||
) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
@@ -316,7 +330,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
actions_is_pad = batch.get("action_is_pad")
|
||||
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
losses = self.model.forward(
|
||||
images, img_masks, lang_tokens, lang_masks, state, actions, noise, time
|
||||
)
|
||||
loss_dict["losses_after_forward"] = losses.clone()
|
||||
|
||||
if actions_is_pad is not None:
|
||||
@@ -343,7 +359,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
img_masks = []
|
||||
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
missing_img_keys = [
|
||||
key for key in self.config.image_features if key not in batch
|
||||
]
|
||||
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
@@ -355,7 +373,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
img = batch[key]
|
||||
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
|
||||
img = resize_with_pad(
|
||||
img, *self.config.resize_imgs_with_padding, pad_value=0
|
||||
)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
@@ -394,7 +414,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(
|
||||
device=device, dtype=torch.bool
|
||||
)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
@@ -413,7 +435,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(
|
||||
actions[:, :, motor_idx]
|
||||
)
|
||||
return actions
|
||||
|
||||
def _pi_aloha_encode_actions_inv(self, actions):
|
||||
@@ -422,7 +446,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
|
||||
actions[:, :, motor_idx]
|
||||
)
|
||||
return actions
|
||||
|
||||
def prepare_state(self, batch):
|
||||
@@ -472,15 +498,25 @@ class PI0FlowMatching(nn.Module):
|
||||
train_expert_only=self.config.train_expert_only,
|
||||
attention_implementation=self.config.attention_implementation,
|
||||
)
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||
paligemma_with_export_config
|
||||
)
|
||||
|
||||
# Projections are float32
|
||||
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
|
||||
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
|
||||
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
|
||||
self.action_in_proj = nn.Linear(
|
||||
self.config.max_action_dim, self.config.proj_width
|
||||
)
|
||||
self.action_out_proj = nn.Linear(
|
||||
self.config.proj_width, self.config.max_action_dim
|
||||
)
|
||||
|
||||
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
|
||||
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
|
||||
self.action_time_mlp_in = nn.Linear(
|
||||
self.config.proj_width * 2, self.config.proj_width
|
||||
)
|
||||
self.action_time_mlp_out = nn.Linear(
|
||||
self.config.proj_width, self.config.proj_width
|
||||
)
|
||||
|
||||
self.set_requires_grad()
|
||||
|
||||
@@ -524,7 +560,9 @@ class PI0FlowMatching(nn.Module):
|
||||
|
||||
# Normalize image embeddings
|
||||
img_emb_dim = img_emb.shape[-1]
|
||||
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
|
||||
img_emb = img_emb * torch.tensor(
|
||||
img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
|
||||
)
|
||||
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
||||
@@ -577,7 +615,11 @@ class PI0FlowMatching(nn.Module):
|
||||
|
||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = create_sinusoidal_pos_embedding(
|
||||
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
|
||||
timestep,
|
||||
self.config.proj_width,
|
||||
min_period=4e-3,
|
||||
max_period=4.0,
|
||||
device=device,
|
||||
)
|
||||
time_emb = time_emb.type(dtype=dtype)
|
||||
|
||||
@@ -595,7 +637,9 @@ class PI0FlowMatching(nn.Module):
|
||||
embs.append(action_time_emb)
|
||||
|
||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
|
||||
action_time_mask = torch.ones(
|
||||
bsize, action_time_dim, dtype=torch.bool, device=device
|
||||
)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
@@ -609,7 +653,15 @@ class PI0FlowMatching(nn.Module):
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def forward(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
lang_tokens,
|
||||
lang_masks,
|
||||
state,
|
||||
actions,
|
||||
noise=None,
|
||||
time=None,
|
||||
) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
if noise is None:
|
||||
@@ -625,7 +677,9 @@ class PI0FlowMatching(nn.Module):
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
|
||||
state, x_t, time
|
||||
)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
@@ -649,13 +703,19 @@ class PI0FlowMatching(nn.Module):
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
return losses
|
||||
|
||||
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
||||
def sample_actions(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, noise=None
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
|
||||
actions_shape = (
|
||||
bsize,
|
||||
self.config.n_action_steps,
|
||||
self.config.max_action_dim,
|
||||
)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
@@ -703,12 +763,16 @@ class PI0FlowMatching(nn.Module):
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
|
||||
state, x_t, timestep
|
||||
)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
|
||||
batch_size, suffix_len, prefix_len
|
||||
)
|
||||
|
||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
|
||||
@@ -39,9 +39,13 @@ def apply_rope(x, positions, max_wavelength=10_000):
|
||||
dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
|
||||
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
|
||||
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(
|
||||
d_half, dtype=torch.float32, device=device
|
||||
)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
|
||||
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(
|
||||
torch.float32
|
||||
)
|
||||
|
||||
radians = radians[..., None, :]
|
||||
|
||||
@@ -174,7 +178,9 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
def __init__(self, config: PaliGemmaWithExpertConfig):
|
||||
super().__init__(config=config)
|
||||
self.config = config
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(
|
||||
config=config.paligemma_config
|
||||
)
|
||||
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
|
||||
# Remove unused embed_tokens
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
@@ -291,14 +297,22 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
key_states = torch.cat(
|
||||
[past_key_values[layer_idx]["key_states"], key_states], dim=1
|
||||
)
|
||||
value_states = torch.cat(
|
||||
[past_key_values[layer_idx]["value_states"], value_states], dim=1
|
||||
[past_key_values[layer_idx]["value_states"], value_states],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
attention_interface = self.get_attention_interface()
|
||||
att_output = attention_interface(
|
||||
attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
)
|
||||
att_output = att_output.to(dtype=torch.bfloat16)
|
||||
|
||||
@@ -358,15 +372,29 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
return attention_interface
|
||||
|
||||
def flash_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
self,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
):
|
||||
raise NotImplementedError("FA2 is not implemented (yet)")
|
||||
|
||||
def eager_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
self,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
):
|
||||
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
|
||||
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
|
||||
num_key_value_heads = (
|
||||
self.config.paligemma_config.text_config.num_key_value_heads
|
||||
)
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
# query_states: batch_size, sequence_length, num_att_head, head_dim
|
||||
@@ -375,17 +403,31 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
sequence_length = key_states.shape[1]
|
||||
|
||||
key_states = key_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_key_value_heads,
|
||||
num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_key_value_heads * num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_key_value_heads,
|
||||
num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_key_value_heads * num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
# Attention here is upcasted to float32 to match the original eager implementation.
|
||||
@@ -400,7 +442,9 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
att_weights *= head_dim**-0.5
|
||||
big_neg = -2.3819763e38 # See gemma/modules.py
|
||||
|
||||
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
||||
masked_att_weights = torch.where(
|
||||
attention_mask[:, None, :, :], att_weights, big_neg
|
||||
)
|
||||
|
||||
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
||||
probs = probs.to(dtype=value_states.dtype)
|
||||
@@ -412,6 +456,8 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
|
||||
att_output = att_output.permute(0, 2, 1, 3)
|
||||
# we use -1 because sequence length can change
|
||||
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
||||
att_output = att_output.reshape(
|
||||
batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
|
||||
)
|
||||
|
||||
return att_output
|
||||
|
||||
@@ -71,7 +71,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
self.config._save_pretrained(save_directory)
|
||||
model_to_save = self.module if hasattr(self, "module") else self
|
||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
save_model_as_safetensor(
|
||||
model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
@@ -110,7 +112,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
if os.path.isdir(model_id):
|
||||
print("Loading weights from local directory")
|
||||
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
||||
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
|
||||
policy = cls._load_as_safetensor(
|
||||
instance, model_file, config.device, strict
|
||||
)
|
||||
else:
|
||||
try:
|
||||
model_file = hf_hub_download(
|
||||
@@ -124,7 +128,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
|
||||
policy = cls._load_as_safetensor(
|
||||
instance, model_file, config.device, strict
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
||||
@@ -135,8 +141,12 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
return policy
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
|
||||
def _load_as_safetensor(
|
||||
cls, model: T, model_file: str, map_location: str, strict: bool
|
||||
) -> T:
|
||||
if packaging.version.parse(safetensors.__version__) < packaging.version.parse(
|
||||
"0.4.3"
|
||||
):
|
||||
load_model_as_safetensor(model, model_file, strict=strict)
|
||||
if map_location != "cpu":
|
||||
logging.warning(
|
||||
@@ -147,7 +157,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
)
|
||||
model.to(map_location)
|
||||
else:
|
||||
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
|
||||
safetensors.torch.load_model(
|
||||
model, model_file, strict=strict, device=map_location
|
||||
)
|
||||
return model
|
||||
|
||||
# def generate_model_card(self, *args, **kwargs) -> ModelCard:
|
||||
|
||||
@@ -639,9 +639,9 @@ class Policy(nn.Module):
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(
|
||||
log_std
|
||||
).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
assert not torch.isnan(log_std).any(), (
|
||||
"[ERROR] log_std became NaN after std_layer!"
|
||||
)
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
|
||||
@@ -187,7 +187,9 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
"If `n_action_steps > 1`, `use_mpc` must be set to `True`."
|
||||
)
|
||||
if self.n_action_steps > self.horizon:
|
||||
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
|
||||
raise ValueError(
|
||||
"`n_action_steps` must be less than or equal to `horizon`."
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(lr=self.optimizer_lr)
|
||||
@@ -207,7 +209,9 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
if image_ft.shape[-2] != image_ft.shape[-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")
|
||||
raise ValueError(
|
||||
f"Only square images are handled now. Got image shape {image_ft.shape}."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
|
||||
@@ -39,7 +39,11 @@ from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_output_shape,
|
||||
populate_queues,
|
||||
)
|
||||
|
||||
|
||||
class TDMPCPolicy(PreTrainedPolicy):
|
||||
@@ -63,7 +67,11 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
config_class = TDMPCConfig
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: TDMPCConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
@@ -75,7 +83,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
@@ -117,7 +127,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
"""Select a single action given environment observations."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[next(iter(self.config.image_features))]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
@@ -201,7 +213,10 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
# algorithm.
|
||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||
mean = torch.zeros(
|
||||
self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
|
||||
self.config.horizon,
|
||||
batch_size,
|
||||
self.config.action_feature.shape[0],
|
||||
device=device,
|
||||
)
|
||||
# Maybe warm start CEM with the mean from the previous step.
|
||||
if self._prev_mean is not None:
|
||||
@@ -339,7 +354,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[next(iter(self.config.image_features))]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
@@ -371,7 +388,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
current_observation[k] = observations[k][0]
|
||||
next_observations[k] = observations[k][1:]
|
||||
horizon, batch_size = next_observations[
|
||||
"observation.image" if self.config.image_features else "observation.environment_state"
|
||||
"observation.image"
|
||||
if self.config.image_features
|
||||
else "observation.environment_state"
|
||||
].shape[:2]
|
||||
|
||||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
@@ -569,7 +588,9 @@ class TDMPCTOLD(nn.Module):
|
||||
self.config = config
|
||||
self._encoder = TDMPCObservationEncoder(config)
|
||||
self._dynamics = nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.Linear(
|
||||
config.latent_dim + config.action_feature.shape[0], config.mlp_dim
|
||||
),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -580,7 +601,9 @@ class TDMPCTOLD(nn.Module):
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
self._reward = nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.Linear(
|
||||
config.latent_dim + config.action_feature.shape[0], config.mlp_dim
|
||||
),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -600,7 +623,10 @@ class TDMPCTOLD(nn.Module):
|
||||
self._Qs = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.Linear(
|
||||
config.latent_dim + config.action_feature.shape[0],
|
||||
config.mlp_dim,
|
||||
),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -786,7 +812,9 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
|
||||
if config.robot_state_feature:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
|
||||
nn.Linear(
|
||||
config.robot_state_feature.shape[0], config.state_encoder_hidden_dim
|
||||
),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
@@ -795,7 +823,9 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
|
||||
if config.env_state_feature:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
|
||||
nn.Linear(
|
||||
config.env_state_feature.shape[0], config.state_encoder_hidden_dim
|
||||
),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
@@ -813,7 +843,8 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
if self.config.image_features:
|
||||
feat.append(
|
||||
flatten_forward_unflatten(
|
||||
self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
|
||||
self.image_enc_layers,
|
||||
obs_dict[next(iter(self.config.image_features))],
|
||||
)
|
||||
)
|
||||
if self.config.env_state_feature:
|
||||
|
||||
@@ -172,7 +172,10 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
|
||||
if self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
if (
|
||||
self.crop_shape[0] > image_ft.shape[1]
|
||||
or self.crop_shape[1] > image_ft.shape[2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
@@ -193,7 +196,12 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
|
||||
return list(
|
||||
range(
|
||||
1 - self.n_obs_steps,
|
||||
self.n_action_pred_token + self.action_chunk_size - 1,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
|
||||
@@ -29,7 +29,11 @@ from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_output_shape,
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
|
||||
|
||||
@@ -60,7 +64,9 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
@@ -91,11 +97,17 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
if self.config.sequentially_select:
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
|
||||
+ list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
|
||||
+ list(
|
||||
self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()
|
||||
)
|
||||
+ list(
|
||||
self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()
|
||||
)
|
||||
)
|
||||
else:
|
||||
decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())
|
||||
decay_params = decay_params + list(
|
||||
self.vqbet.action_head.map_to_cbet_preds_bin.parameters()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
@@ -133,8 +145,12 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -165,8 +181,12 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
@@ -334,7 +354,8 @@ class VQBeTModel(nn.Module):
|
||||
|
||||
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
|
||||
self.state_projector = MLP(
|
||||
config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
|
||||
config.robot_state_feature.shape[0],
|
||||
hidden_channels=[self.config.gpt_input_dim],
|
||||
)
|
||||
self.rgb_feature_projector = MLP(
|
||||
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
|
||||
@@ -406,9 +427,9 @@ class VQBeTModel(nn.Module):
|
||||
features = self.policy(input_tokens)
|
||||
# len(self.config.input_features) is the number of different observation modes.
|
||||
# this line gets the index of action prompt tokens.
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
|
||||
self.config.input_features
|
||||
)
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (
|
||||
len(self.config.input_features) + 1
|
||||
) + len(self.config.input_features)
|
||||
|
||||
# only extract the output tokens at the position of action query:
|
||||
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
|
||||
@@ -771,11 +792,15 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
# height and width from `config.image_features`.
|
||||
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
)
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.pool = SpatialSoftmax(
|
||||
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
|
||||
)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
@@ -871,7 +896,8 @@ class VqVae(nn.Module):
|
||||
)
|
||||
|
||||
self.encoder = MLP(
|
||||
in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
|
||||
in_channels=self.config.action_feature.shape[0]
|
||||
* self.config.action_chunk_size,
|
||||
hidden_channels=[
|
||||
config.vqvae_enc_hidden_dim,
|
||||
config.vqvae_enc_hidden_dim,
|
||||
@@ -899,9 +925,13 @@ class VqVae(nn.Module):
|
||||
# given latent vector, this function outputs the decoded action.
|
||||
output = self.decoder(latent)
|
||||
if self.config.action_chunk_size == 1:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
|
||||
return einops.rearrange(
|
||||
output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]
|
||||
)
|
||||
else:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
|
||||
return einops.rearrange(
|
||||
output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]
|
||||
)
|
||||
|
||||
def get_code(self, state):
|
||||
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
|
||||
|
||||
@@ -290,10 +290,10 @@ class GPT(nn.Module):
|
||||
param_dict = dict(self.named_parameters())
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert (
|
||||
len(inter_params) == 0
|
||||
), "parameters {} made it into both decay/no_decay sets!".format(
|
||||
str(inter_params)
|
||||
assert len(inter_params) == 0, (
|
||||
"parameters {} made it into both decay/no_decay sets!".format(
|
||||
str(inter_params)
|
||||
)
|
||||
)
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
"parameters {} were not separated into either decay/no_decay set!".format(
|
||||
@@ -664,14 +664,14 @@ class VectorQuantize(nn.Module):
|
||||
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
|
||||
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
|
||||
|
||||
assert not (
|
||||
ema_update and learnable_codebook
|
||||
), "learnable codebook not compatible with EMA update"
|
||||
assert not (ema_update and learnable_codebook), (
|
||||
"learnable codebook not compatible with EMA update"
|
||||
)
|
||||
|
||||
assert 0 <= sync_update_v <= 1.0
|
||||
assert not (
|
||||
sync_update_v > 0.0 and not learnable_codebook
|
||||
), "learnable codebook must be turned on"
|
||||
assert not (sync_update_v > 0.0 and not learnable_codebook), (
|
||||
"learnable codebook must be turned on"
|
||||
)
|
||||
|
||||
self.sync_update_v = sync_update_v
|
||||
|
||||
|
||||
Reference in New Issue
Block a user