[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
Author: pre-commit-ci[bot]
Date: 2025-03-24 13:16:38 +00:00
Committed by: AdilZouitine
Parent: 761a2dbcb3
Commit: 8e6d5f504c
97 changed files with 1596 additions and 492 deletions

View File

@@ -171,7 +171,9 @@ class ACTConfig(PreTrainedConfig):
     def validate_features(self) -> None:
         if not self.image_features and not self.env_state_feature:
-            raise ValueError("You must provide at least one image or the environment state among the inputs.")
+            raise ValueError(
+                "You must provide at least one image or the environment state among the inputs."
+            )

     @property
     def observation_delta_indices(self) -> None:

View File

@@ -63,7 +63,9 @@ class ACTPolicy(PreTrainedPolicy):
         config.validate_features()
         self.config = config

-        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
+        self.normalize_inputs = Normalize(
+            config.input_features, config.normalization_mapping, dataset_stats
+        )
         self.normalize_targets = Normalize(
             config.output_features, config.normalization_mapping, dataset_stats
         )
@@ -120,8 +122,12 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
-            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = [batch[key] for key in self.config.image_features]
+            batch = dict(
+                batch
+            )  # shallow copy so that adding a key doesn't modify the original
+            batch["observation.images"] = [
+                batch[key] for key in self.config.image_features
+            ]

         # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
         # we are ensembling over.
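For context on the trailing comment: ACT's temporal ensembling averages the actions that several overlapping chunks predict for the current timestep, weighting prediction i by exp(-m*i) with i = 0 the oldest, as in the ACT paper. A minimal sketch of that update, not the repo's own ensembler class; the list layout and coefficient are illustrative:

import torch

def temporal_ensemble(chunks: list[torch.Tensor], coeff: float = 0.01) -> torch.Tensor:
    # chunks: (chunk_size, action_dim) tensors, oldest first; the chunk
    # predicted k steps ago holds its action for "now" at offset k.
    n = len(chunks)
    preds = torch.stack([chunks[j][n - 1 - j] for j in range(n)])  # oldest first
    weights = torch.exp(-coeff * torch.arange(n, dtype=preds.dtype))  # w_0 = oldest
    return (weights[:, None] * preds).sum(dim=0) / weights.sum()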
@@ -148,8 +154,12 @@ class ACTPolicy(PreTrainedPolicy):
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
-            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = [batch[key] for key in self.config.image_features]
+            batch = dict(
+                batch
+            )  # shallow copy so that adding a key doesn't modify the original
+            batch["observation.images"] = [
+                batch[key] for key in self.config.image_features
+            ]

         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@@ -406,14 +416,18 @@ class ACT(nn.Module):
             n_1d_tokens += 1
         self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
         if self.config.image_features:
-            self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
+            self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
+                config.dim_model // 2
+            )

         # Transformer decoder.
         # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
         self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)

         # Final action regression head on the output of the transformer's decoder.
-        self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
+        self.action_head = nn.Linear(
+            config.dim_model, self.config.action_feature.shape[0]
+        )

         self._reset_parameters()
@@ -461,14 +475,20 @@ class ACT(nn.Module):
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
             if self.config.robot_state_feature:
-                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
+                robot_state_embed = self.vae_encoder_robot_state_input_proj(
+                    batch["observation.state"]
+                )
                 robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(
                 batch["action"]
             )  # (B, S, D)

             if self.config.robot_state_feature:
-                vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
+                vae_encoder_input = [
+                    cls_embed,
+                    robot_state_embed,
+                    action_embed,
+                ]  # (B, S+2, D)
             else:
                 vae_encoder_input = [cls_embed, action_embed]
             vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
@@ -517,7 +537,9 @@ class ACT(nn.Module):
         )
         # Robot state token.
         if self.config.robot_state_feature:
-            encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
+            encoder_in_tokens.append(
+                self.encoder_robot_state_input_proj(batch["observation.state"])
+            )
         # Environment state token.
         if self.config.env_state_feature:
             encoder_in_tokens.append(
@@ -534,7 +556,9 @@ class ACT(nn.Module):
             # For a list of images, the H and W may vary but H*W is constant.
             for img in batch["observation.images"]:
                 cam_features = self.backbone(img)["feature_map"]
-                cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
+                cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
+                    dtype=cam_features.dtype
+                )
                 cam_features = self.encoder_img_feat_input_proj(cam_features)
                 # Rearrange features to (sequence, batch, dim).
View File

@@ -205,11 +205,16 @@ class DiffusionConfig(PreTrainedConfig):
     def validate_features(self) -> None:
         if len(self.image_features) == 0 and self.env_state_feature is None:
-            raise ValueError("You must provide at least one image or the environment state among the inputs.")
+            raise ValueError(
+                "You must provide at least one image or the environment state among the inputs."
+            )

         if self.crop_shape is not None:
             for key, image_ft in self.image_features.items():
-                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
+                if (
+                    self.crop_shape[0] > image_ft.shape[1]
+                    or self.crop_shape[1] > image_ft.shape[2]
+                ):
                     raise ValueError(
                         f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                         f"for `crop_shape` and {image_ft.shape} for "

View File

@@ -70,7 +70,9 @@ class DiffusionPolicy(PreTrainedPolicy):
         config.validate_features()
         self.config = config

-        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
+        self.normalize_inputs = Normalize(
+            config.input_features, config.normalization_mapping, dataset_stats
+        )
         self.normalize_targets = Normalize(
             config.output_features, config.normalization_mapping, dataset_stats
         )
@@ -97,7 +99,9 @@ class DiffusionPolicy(PreTrainedPolicy):
         if self.config.image_features:
             self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
         if self.config.env_state_feature:
-            self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
+            self._queues["observation.environment_state"] = deque(
+                maxlen=self.config.n_obs_steps
+            )

     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -123,7 +127,9 @@ class DiffusionPolicy(PreTrainedPolicy):
         """
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
-            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
+            batch = dict(
+                batch
+            )  # shallow copy so that adding a key doesn't modify the original
             batch["observation.images"] = torch.stack(
                 [batch[key] for key in self.config.image_features], dim=-4
             )
@@ -151,7 +157,9 @@ class DiffusionPolicy(PreTrainedPolicy):
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
-            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
+            batch = dict(
+                batch
+            )  # shallow copy so that adding a key doesn't modify the original
             batch["observation.images"] = torch.stack(
                 [batch[key] for key in self.config.image_features], dim=-4
             )
@@ -515,11 +523,15 @@ class DiffusionRgbEncoder(nn.Module):
         # Note: we have a check in the config class to make sure all images have the same shape.
         images_shape = next(iter(config.image_features.values())).shape
-        dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
+        dummy_shape_h_w = (
+            config.crop_shape if config.crop_shape is not None else images_shape[1:]
+        )
         dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
         feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]

-        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
+        self.pool = SpatialSoftmax(
+            feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
+        )
         self.feature_dim = config.spatial_softmax_num_keypoints * 2
         self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
         self.relu = nn.ReLU()
@@ -719,7 +731,9 @@ class DiffusionConditionalUnet1d(nn.Module):
         )

         self.final_conv = nn.Sequential(
-            DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
+            DiffusionConv1dBlock(
+                config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size
+            ),
             nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
         )

View File

@@ -104,7 +104,9 @@ def make_policy(
         PreTrainedPolicy: _description_
     """
     if bool(ds_meta) == bool(env_cfg):
-        raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
+        raise ValueError(
+            "Either one of a dataset metadata or a sim env must be provided."
+        )

     # NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
     # TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
@@ -134,8 +136,12 @@ def make_policy(
         )
         features = env_to_policy_features(env_cfg)

-    cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
-    cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
+    cfg.output_features = {
+        key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
+    }
+    cfg.input_features = {
+        key: ft for key, ft in features.items() if key not in cfg.output_features
+    }
     kwargs["config"] = cfg

     if cfg.pretrained_path:

View File

@@ -82,25 +82,43 @@ def create_stats_buffers(
         if stats:
             if isinstance(stats[key]["mean"], np.ndarray):
                 if norm_mode is NormalizationMode.MEAN_STD:
-                    buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
-                    buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
+                    buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(
+                        dtype=torch.float32
+                    )
+                    buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(
+                        dtype=torch.float32
+                    )
                 elif norm_mode is NormalizationMode.MIN_MAX:
-                    buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
-                    buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
+                    buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(
+                        dtype=torch.float32
+                    )
+                    buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(
+                        dtype=torch.float32
+                    )
             elif isinstance(stats[key]["mean"], torch.Tensor):
                 # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
                 # tensors anywhere (for example, when we use the same stats for normalization and
                 # unnormalization). See the logic here
                 # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
                 if norm_mode is NormalizationMode.MEAN_STD:
-                    buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
-                    buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
+                    buffer["mean"].data = (
+                        stats[key]["mean"].clone().to(dtype=torch.float32)
+                    )
+                    buffer["std"].data = (
+                        stats[key]["std"].clone().to(dtype=torch.float32)
+                    )
                 elif norm_mode is NormalizationMode.MIN_MAX:
-                    buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
-                    buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
+                    buffer["min"].data = (
+                        stats[key]["min"].clone().to(dtype=torch.float32)
+                    )
+                    buffer["max"].data = (
+                        stats[key]["max"].clone().to(dtype=torch.float32)
+                    )
             else:
                 type_ = type(stats[key]["mean"])
-                raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
+                raise ValueError(
+                    f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead."
+                )

         stats_buffers[key] = buffer
     return stats_buffers
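The `.clone()` in the note above exists because safetensors rejects state dicts whose entries share storage. A quick illustration of the aliasing it avoids (standalone, not the library's actual check):

import torch

stats_mean = torch.zeros(3, dtype=torch.float32)
aliased = stats_mean.to(dtype=torch.float32)  # same dtype: returns the same tensor
cloned = stats_mean.clone().to(dtype=torch.float32)  # independent storage
print(aliased.data_ptr() == stats_mean.data_ptr())  # True: shared storage
print(cloned.data_ptr() == stats_mean.data_ptr())  # False: clone breaks the alias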

View File

@@ -44,7 +44,9 @@ def main():
     else:
         dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"

-    ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
+    ckpt_torch_dir = (
+        Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
+    )
     ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
     save_dir = Path(f"../openpi/data/{model_name}/save")
@@ -70,7 +72,9 @@ def main():
     # Create LeRobot batch from Jax
     batch = {}
     for cam_key, uint_chw_array in example["images"].items():
-        batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
+        batch[f"observation.images.{cam_key}"] = (
+            torch.from_numpy(uint_chw_array) / 255.0
+        )
     batch["observation.state"] = torch.from_numpy(example["state"])
     batch["action"] = torch.from_numpy(outputs["actions"])
     batch["task"] = example["prompt"]

View File

@@ -54,7 +54,9 @@ def get_paligemma_config(precision: str):
         "projector_hidden_act": "gelu_fast",
         "vision_use_head": False,
     }

-    final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
+    final_config = PaliGemmaConfig(
+        text_config=text_config, vision_config=vision_config, **config
+    )
     return final_config

View File

@@ -61,7 +61,11 @@ from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
 )
 from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy

-PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
+PRECISIONS = {
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
+    "float16": torch.float16,
+}


 def slice_paligemma_state_dict(state_dict, config):
@@ -318,7 +322,9 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
     return {f"{prefix}{key}": value for key, value in d.items()}


-def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
+def convert_pi0_checkpoint(
+    checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str
+):
     # Break down orbax ckpts - they are in OCDBT
     initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
     # process projection params
@@ -378,7 +384,9 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: st
     # gemma_config=gemma_config, paligemma_config=paligemma_config)
     pi0_model = PI0Policy(pi0_config)

-    paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
+    paligemma_params = update_keys_with_prefix(
+        paligemma_params, "model.paligemma_with_expert."
+    )
     gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
     projection_params = update_keys_with_prefix(projection_params, "model.")

View File

@@ -48,18 +48,32 @@ def flex_attention_forward(
     key_states = key_states[:, :, :, None, :]
     key_states = key_states.expand(
-        batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
+        batch_size,
+        key_states.shape[1],
+        num_key_value_heads,
+        num_key_value_groups,
+        head_dim,
     )
     key_states = key_states.reshape(
-        batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
+        batch_size,
+        key_states.shape[1],
+        num_key_value_heads * num_key_value_groups,
+        head_dim,
     )

     value_states = value_states[:, :, :, None, :]
     value_states = value_states.expand(
-        batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
+        batch_size,
+        value_states.shape[1],
+        num_key_value_heads,
+        num_key_value_groups,
+        head_dim,
     )
     value_states = value_states.reshape(
-        batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
+        batch_size,
+        value_states.shape[1],
+        num_key_value_heads * num_key_value_groups,
+        head_dim,
     )

     query_states = query_states.transpose(1, 2)
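The expand/reshape pairs above are the standard grouped-query-attention repetition: each of the `num_key_value_heads` K/V heads is tiled `num_key_value_groups` times so K/V match the query head count. The same shape manipulation in isolation (sizes are illustrative):

import torch

batch_size, seq_len, num_kv_heads, head_dim = 2, 5, 4, 8
num_kv_groups = 3  # num_attention_heads // num_kv_heads

kv = torch.randn(batch_size, seq_len, num_kv_heads, head_dim)
kv = kv[:, :, :, None, :].expand(
    batch_size, seq_len, num_kv_heads, num_kv_groups, head_dim
)
kv = kv.reshape(batch_size, seq_len, num_kv_heads * num_kv_groups, head_dim)
print(kv.shape)  # torch.Size([2, 5, 12, 8])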

View File

@@ -69,7 +69,11 @@ from lerobot.common.utils.utils import get_safe_dtype


 def create_sinusoidal_pos_embedding(
-    time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
+    time: torch.tensor,
+    dimension: int,
+    min_period: float,
+    max_period: float,
+    device="cpu",
 ) -> Tensor:
     """Computes sine-cosine positional embedding vectors for scalar positions."""
     if dimension % 2 != 0:
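For reference, a minimal implementation consistent with this signature (log-spaced periods between `min_period` and `max_period`; treat it as a sketch, not necessarily the repo's exact code):

import torch

def sinusoidal_pos_embedding(time, dimension, min_period, max_period, device="cpu"):
    # One period per sin/cos pair, log-spaced between min_period and max_period.
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")
    fraction = torch.linspace(0.0, 1.0, dimension // 2, device=device)
    period = min_period * (max_period / min_period) ** fraction
    angle = 2 * torch.pi / period[None, :] * time[:, None].to(device)
    return torch.cat([torch.sin(angle), torch.cos(angle)], dim=1)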
@@ -189,7 +193,9 @@ def aloha_gripper_to_angular(value):
     # This is the inverse of the angular to linear transformation inside the Interbotix code.
     def linear_to_radian(linear_position, arm_length, horn_radius):
-        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
+        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
+            2 * horn_radius * linear_position
+        )
         return safe_arcsin(value)

     # The constants are taken from the Interbotix code.
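`safe_arcsin` presumably guards the `arcsin` domain, since rounding error in the linkage ratio above can land slightly outside [-1, 1]. A sketch of that guard (assumed behavior, not the repo's exact code):

import torch

def safe_arcsin(value: torch.Tensor) -> torch.Tensor:
    # arcsin is only defined on [-1, 1]; clamp so numerical noise in the
    # law-of-cosines ratio cannot produce NaNs.
    return torch.arcsin(torch.clamp(value, -1.0, 1.0))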
@@ -240,7 +246,9 @@ class PI0Policy(PreTrainedPolicy):
         super().__init__(config)
         config.validate_features()
         self.config = config
-        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
+        self.normalize_inputs = Normalize(
+            config.input_features, config.normalization_mapping, dataset_stats
+        )
         self.normalize_targets = Normalize(
             config.output_features, config.normalization_mapping, dataset_stats
         )
@@ -248,7 +256,9 @@ class PI0Policy(PreTrainedPolicy):
             config.output_features, config.normalization_mapping, dataset_stats
         )

-        self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
+        self.language_tokenizer = AutoTokenizer.from_pretrained(
+            "google/paligemma-3b-pt-224"
+        )
         self.model = PI0FlowMatching(config)

         self.reset()
@@ -261,7 +271,9 @@ class PI0Policy(PreTrainedPolicy):
         return self.parameters()

     @torch.no_grad
-    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+    def select_action(
+        self, batch: dict[str, Tensor], noise: Tensor | None = None
+    ) -> Tensor:
         """Select a single action given environment observations.

         This method wraps `select_actions` in order to return one action at a time for execution in the
@@ -300,7 +312,9 @@ class PI0Policy(PreTrainedPolicy):
             self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()

-    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
+    def forward(
+        self, batch: dict[str, Tensor], noise=None, time=None
+    ) -> tuple[Tensor, dict[str, Tensor]]:
         """Do a full training forward pass to compute the loss"""
         if self.config.adapt_to_pi_aloha:
             batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
@@ -316,7 +330,9 @@ class PI0Policy(PreTrainedPolicy):
         actions_is_pad = batch.get("action_is_pad")

         loss_dict = {}
-        losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
+        losses = self.model.forward(
+            images, img_masks, lang_tokens, lang_masks, state, actions, noise, time
+        )
         loss_dict["losses_after_forward"] = losses.clone()

         if actions_is_pad is not None:
@@ -343,7 +359,9 @@ class PI0Policy(PreTrainedPolicy):
         img_masks = []

         present_img_keys = [key for key in self.config.image_features if key in batch]
-        missing_img_keys = [key for key in self.config.image_features if key not in batch]
+        missing_img_keys = [
+            key for key in self.config.image_features if key not in batch
+        ]

         if len(present_img_keys) == 0:
             raise ValueError(
@@ -355,7 +373,9 @@ class PI0Policy(PreTrainedPolicy):
             img = batch[key]

             if self.config.resize_imgs_with_padding is not None:
-                img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
+                img = resize_with_pad(
+                    img, *self.config.resize_imgs_with_padding, pad_value=0
+                )

             # Normalize from range [0,1] to [-1,1] as expected by siglip
             img = img * 2.0 - 1.0
@@ -394,7 +414,9 @@ class PI0Policy(PreTrainedPolicy):
             return_tensors="pt",
         )

         lang_tokens = tokenized_prompt["input_ids"].to(device=device)
-        lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
+        lang_masks = tokenized_prompt["attention_mask"].to(
+            device=device, dtype=torch.bool
+        )

         return lang_tokens, lang_masks
@@ -413,7 +435,9 @@ class PI0Policy(PreTrainedPolicy):
             actions[:, :, motor_idx] *= -1
         # Reverse the gripper transformation that is being applied by the Aloha runtime.
         for motor_idx in [6, 13]:
-            actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
+            actions[:, :, motor_idx] = aloha_gripper_from_angular(
+                actions[:, :, motor_idx]
+            )
         return actions

     def _pi_aloha_encode_actions_inv(self, actions):
@@ -422,7 +446,9 @@ class PI0Policy(PreTrainedPolicy):
             actions[:, :, motor_idx] *= -1
         # Reverse the gripper transformation that is being applied by the Aloha runtime.
         for motor_idx in [6, 13]:
-            actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
+            actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
+                actions[:, :, motor_idx]
+            )
         return actions

     def prepare_state(self, batch):
@@ -472,15 +498,25 @@ class PI0FlowMatching(nn.Module):
             train_expert_only=self.config.train_expert_only,
             attention_implementation=self.config.attention_implementation,
         )
-        self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
+        self.paligemma_with_expert = PaliGemmaWithExpertModel(
+            paligemma_with_export_config
+        )

         # Projections are float32
         self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
-        self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
-        self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
+        self.action_in_proj = nn.Linear(
+            self.config.max_action_dim, self.config.proj_width
+        )
+        self.action_out_proj = nn.Linear(
+            self.config.proj_width, self.config.max_action_dim
+        )

-        self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
-        self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
+        self.action_time_mlp_in = nn.Linear(
+            self.config.proj_width * 2, self.config.proj_width
+        )
+        self.action_time_mlp_out = nn.Linear(
+            self.config.proj_width, self.config.proj_width
+        )

         self.set_requires_grad()
@@ -524,7 +560,9 @@ class PI0FlowMatching(nn.Module):
             # Normalize image embeddings
             img_emb_dim = img_emb.shape[-1]
-            img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
+            img_emb = img_emb * torch.tensor(
+                img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
+            )

             bsize, num_img_embs = img_emb.shape[:2]
             img_mask = img_mask[:, None].expand(bsize, num_img_embs)
@@ -577,7 +615,11 @@ class PI0FlowMatching(nn.Module):
         # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
         time_emb = create_sinusoidal_pos_embedding(
-            timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
+            timestep,
+            self.config.proj_width,
+            min_period=4e-3,
+            max_period=4.0,
+            device=device,
         )
         time_emb = time_emb.type(dtype=dtype)
@@ -595,7 +637,9 @@ class PI0FlowMatching(nn.Module):
         embs.append(action_time_emb)

         bsize, action_time_dim = action_time_emb.shape[:2]
-        action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
+        action_time_mask = torch.ones(
+            bsize, action_time_dim, dtype=torch.bool, device=device
+        )
         pad_masks.append(action_time_mask)

         # Set attention masks so that image, language and state inputs do not attend to action tokens
@@ -609,7 +653,15 @@ class PI0FlowMatching(nn.Module):
         return embs, pad_masks, att_masks

     def forward(
-        self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
+        self,
+        images,
+        img_masks,
+        lang_tokens,
+        lang_masks,
+        state,
+        actions,
+        noise=None,
+        time=None,
     ) -> Tensor:
         """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
         if noise is None:
@@ -625,7 +677,9 @@ class PI0FlowMatching(nn.Module):
         prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
             images, img_masks, lang_tokens, lang_masks
         )
-        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
+        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
+            state, x_t, time
+        )

         pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
         att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
@@ -649,13 +703,19 @@ class PI0FlowMatching(nn.Module):
         losses = F.mse_loss(u_t, v_t, reduction="none")
         return losses

-    def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
+    def sample_actions(
+        self, images, img_masks, lang_tokens, lang_masks, state, noise=None
+    ) -> Tensor:
         """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
         bsize = state.shape[0]
         device = state.device

         if noise is None:
-            actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
+            actions_shape = (
+                bsize,
+                self.config.n_action_steps,
+                self.config.max_action_dim,
+            )
             noise = self.sample_noise(actions_shape, device)

         prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
@@ -703,12 +763,16 @@ class PI0FlowMatching(nn.Module):
         timestep,
     ):
         """Apply one denoising step of the noise `x_t` at a given timestep."""
-        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
+        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
+            state, x_t, timestep
+        )

         suffix_len = suffix_pad_masks.shape[1]
         batch_size = prefix_pad_masks.shape[0]
         prefix_len = prefix_pad_masks.shape[1]
-        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
+        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
+            batch_size, suffix_len, prefix_len
+        )

         suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
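`denoise_step` returns the predicted velocity field v_t; `sample_actions` integrates it from t = 1 (pure noise) toward t = 0 (actions). Schematically, the outer loop is a plain Euler integrator; the sketch below assumes a `denoise_step`-like callable and a `num_steps` count:

import torch

def euler_sample(denoise_step, noise: torch.Tensor, num_steps: int) -> torch.Tensor:
    # Integrate the learned velocity field from t=1 (noise) down to t=0 (actions).
    dt = -1.0 / num_steps
    x_t, time = noise, 1.0
    for _ in range(num_steps):
        v_t = denoise_step(x_t, torch.tensor(time))  # predicted velocity at time t
        x_t = x_t + dt * v_t  # Euler step toward t = 0
        time += dt
    return x_t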

View File

@@ -39,9 +39,13 @@ def apply_rope(x, positions, max_wavelength=10_000):
     dtype = x.dtype
     x = x.to(torch.float32)

-    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
+    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(
+        d_half, dtype=torch.float32, device=device
+    )
     timescale = max_wavelength**freq_exponents
-    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
+    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(
+        torch.float32
+    )

     radians = radians[..., None, :]
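`apply_rope` builds per-position rotation angles from log-spaced timescales; the remainder of the function rotates the feature pairs by those angles. The core rotation in isolation (split-half pairing assumed, with `radians` broadcasting against the halved last dimension):

import torch

def rope_rotate(x: torch.Tensor, radians: torch.Tensor) -> torch.Tensor:
    # x: (..., d); rotate the two halves of the feature dim as (cos, sin) pairs.
    x1, x2 = x.chunk(2, dim=-1)
    cos, sin = torch.cos(radians), torch.sin(radians)
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)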
@@ -174,7 +178,9 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
     def __init__(self, config: PaliGemmaWithExpertConfig):
         super().__init__(config=config)
         self.config = config
-        self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
+        self.paligemma = PaliGemmaForConditionalGeneration(
+            config=config.paligemma_config
+        )
         self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
         # Remove unused embed_tokens
         self.gemma_expert.model.embed_tokens = None
@@ -291,14 +297,22 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
                 # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
                 # the max len, then we (for instance) double the cache size. This implementation already exists
                 # in `transformers`. (molbap)
-                key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
+                key_states = torch.cat(
+                    [past_key_values[layer_idx]["key_states"], key_states], dim=1
+                )
                 value_states = torch.cat(
-                    [past_key_values[layer_idx]["value_states"], value_states], dim=1
+                    [past_key_values[layer_idx]["value_states"], value_states],
+                    dim=1,
                 )

             attention_interface = self.get_attention_interface()
             att_output = attention_interface(
-                attention_mask, batch_size, head_dim, query_states, key_states, value_states
+                attention_mask,
+                batch_size,
+                head_dim,
+                query_states,
+                key_states,
+                value_states,
             )
             att_output = att_output.to(dtype=torch.bfloat16)
@@ -358,15 +372,29 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
         return attention_interface

     def flash_attention_forward(
-        self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
+        self,
+        attention_mask,
+        batch_size,
+        head_dim,
+        query_states,
+        key_states,
+        value_states,
     ):
         raise NotImplementedError("FA2 is not implemented (yet)")

     def eager_attention_forward(
-        self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
+        self,
+        attention_mask,
+        batch_size,
+        head_dim,
+        query_states,
+        key_states,
+        value_states,
     ):
         num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
-        num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
+        num_key_value_heads = (
+            self.config.paligemma_config.text_config.num_key_value_heads
+        )
         num_key_value_groups = num_att_heads // num_key_value_heads

         # query_states: batch_size, sequence_length, num_att_head, head_dim
@@ -375,17 +403,31 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
         sequence_length = key_states.shape[1]

         key_states = key_states[:, :, :, None, :].expand(
-            batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
+            batch_size,
+            sequence_length,
+            num_key_value_heads,
+            num_key_value_groups,
+            head_dim,
         )
         key_states = key_states.reshape(
-            batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
+            batch_size,
+            sequence_length,
+            num_key_value_heads * num_key_value_groups,
+            head_dim,
        )

         value_states = value_states[:, :, :, None, :].expand(
-            batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
+            batch_size,
+            sequence_length,
+            num_key_value_heads,
+            num_key_value_groups,
+            head_dim,
         )
         value_states = value_states.reshape(
-            batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
+            batch_size,
+            sequence_length,
+            num_key_value_heads * num_key_value_groups,
+            head_dim,
         )

         # Attention here is upcasted to float32 to match the original eager implementation.
@@ -400,7 +442,9 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
         att_weights *= head_dim**-0.5
         big_neg = -2.3819763e38  # See gemma/modules.py

-        masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
+        masked_att_weights = torch.where(
+            attention_mask[:, None, :, :], att_weights, big_neg
+        )

         probs = nn.functional.softmax(masked_att_weights, dim=-1)
         probs = probs.to(dtype=value_states.dtype)
@@ -412,6 +456,8 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
         att_output = att_output.permute(0, 2, 1, 3)
         # we use -1 because sequence length can change
-        att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
+        att_output = att_output.reshape(
+            batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
+        )

         return att_output
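Taken together, `eager_attention_forward` is a standard masked scaled-dot-product attention once the K/V heads are expanded. Distilled, assuming a (batch, heads, seq, head_dim) layout after transposition:

import torch

def masked_attention(q, k, v, mask, head_dim):
    # q, k, v: (batch, heads, seq, head_dim); mask: (batch, 1, q_len, k_len) bool.
    att = torch.einsum("bhqd,bhkd->bhqk", q, k) * head_dim**-0.5
    att = att.masked_fill(~mask, -2.3819763e38)  # big_neg, as in gemma/modules.py
    probs = torch.softmax(att.float(), dim=-1).to(v.dtype)
    return torch.einsum("bhqk,bhkd->bhqd", probs, v)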

View File

@@ -71,7 +71,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
     def _save_pretrained(self, save_directory: Path) -> None:
         self.config._save_pretrained(save_directory)
         model_to_save = self.module if hasattr(self, "module") else self
-        save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
+        save_model_as_safetensor(
+            model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)
+        )

     @classmethod
     def from_pretrained(
@@ -110,7 +112,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
         if os.path.isdir(model_id):
             print("Loading weights from local directory")
             model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
-            policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
+            policy = cls._load_as_safetensor(
+                instance, model_file, config.device, strict
+            )
         else:
             try:
                 model_file = hf_hub_download(
@@ -124,7 +128,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
                     token=token,
                     local_files_only=local_files_only,
                 )
-                policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
+                policy = cls._load_as_safetensor(
+                    instance, model_file, config.device, strict
+                )
             except HfHubHTTPError as e:
                 raise FileNotFoundError(
                     f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
@@ -135,8 +141,12 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
         return policy

     @classmethod
-    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
-        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
+    def _load_as_safetensor(
+        cls, model: T, model_file: str, map_location: str, strict: bool
+    ) -> T:
+        if packaging.version.parse(safetensors.__version__) < packaging.version.parse(
+            "0.4.3"
+        ):
             load_model_as_safetensor(model, model_file, strict=strict)
             if map_location != "cpu":
                 logging.warning(
@@ -147,7 +157,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
                 )
                 model.to(map_location)
         else:
-            safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
+            safetensors.torch.load_model(
+                model, model_file, strict=strict, device=map_location
+            )
         return model

     # def generate_model_card(self, *args, **kwargs) -> ModelCard:

View File

@@ -639,9 +639,9 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
-            assert not torch.isnan(
-                log_std
-            ).any(), "[ERROR] log_std became NaN after std_layer!"
+            assert not torch.isnan(log_std).any(), (
+                "[ERROR] log_std became NaN after std_layer!"
+            )

             if self.use_tanh_squash:
                 log_std = torch.tanh(log_std)
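After `torch.tanh`, `log_std` lies in [-1, 1]; SAC implementations typically rescale it affinely into a fixed `[log_std_min, log_std_max]` band before exponentiating. A sketch with illustrative bounds (not necessarily this repo's values):

import torch

def squash_log_std(log_std, log_std_min=-5.0, log_std_max=2.0):
    # Map tanh output from [-1, 1] onto [log_std_min, log_std_max].
    log_std = torch.tanh(log_std)
    return log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1.0)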

View File

@@ -187,7 +187,9 @@ class TDMPCConfig(PreTrainedConfig):
                 "If `n_action_steps > 1`, `use_mpc` must be set to `True`."
             )
         if self.n_action_steps > self.horizon:
-            raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
+            raise ValueError(
+                "`n_action_steps` must be less than or equal to `horizon`."
+            )

     def get_optimizer_preset(self) -> AdamConfig:
         return AdamConfig(lr=self.optimizer_lr)
@@ -207,7 +209,9 @@ class TDMPCConfig(PreTrainedConfig):
             if image_ft.shape[-2] != image_ft.shape[-1]:
                 # TODO(alexander-soare): This limitation is solely because of code in the random shift
                 # augmentation. It should be able to be removed.
-                raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")
+                raise ValueError(
+                    f"Only square images are handled now. Got image shape {image_ft.shape}."
+                )

     @property
     def observation_delta_indices(self) -> list:

View File

@@ -39,7 +39,11 @@ from lerobot.common.constants import OBS_ENV, OBS_ROBOT
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
-from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
+from lerobot.common.policies.utils import (
+    get_device_from_parameters,
+    get_output_shape,
+    populate_queues,
+)


 class TDMPCPolicy(PreTrainedPolicy):
@@ -63,7 +67,11 @@ class TDMPCPolicy(PreTrainedPolicy):
     config_class = TDMPCConfig
     name = "tdmpc"

-    def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
+    def __init__(
+        self,
+        config: TDMPCConfig,
+        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
+    ):
         """
         Args:
             config: Policy configuration class instance or None, in which case the default instantiation of
@@ -75,7 +83,9 @@ class TDMPCPolicy(PreTrainedPolicy):
         config.validate_features()
         self.config = config

-        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
+        self.normalize_inputs = Normalize(
+            config.input_features, config.normalization_mapping, dataset_stats
+        )
         self.normalize_targets = Normalize(
             config.output_features, config.normalization_mapping, dataset_stats
         )
@@ -117,7 +127,9 @@ class TDMPCPolicy(PreTrainedPolicy):
         """Select a single action given environment observations."""
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
-            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
+            batch = dict(
+                batch
+            )  # shallow copy so that adding a key doesn't modify the original
             batch["observation.image"] = batch[next(iter(self.config.image_features))]

         self._queues = populate_queues(self._queues, batch)
@@ -201,7 +213,10 @@ class TDMPCPolicy(PreTrainedPolicy):
         # algorithm.
         # The initial mean and standard deviation for the cross-entropy method (CEM).
         mean = torch.zeros(
-            self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
+            self.config.horizon,
+            batch_size,
+            self.config.action_feature.shape[0],
+            device=device,
         )
         # Maybe warm start CEM with the mean from the previous step.
         if self._prev_mean is not None:
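For context on the `mean` buffer: the cross-entropy method repeatedly samples candidate action sequences around `mean`/`std`, scores them, and refits the Gaussian to the top scorers. A bare-bones version (population sizes are placeholders, not this repo's settings):

import torch

def cem(score_fn, mean, std, n_iters=6, n_samples=256, n_elites=32):
    # score_fn: maps (n_samples, *mean.shape) candidates to (n_samples,) scores.
    for _ in range(n_iters):
        samples = mean + std * torch.randn(n_samples, *mean.shape)
        elites = samples[score_fn(samples).topk(n_elites).indices]
        mean, std = elites.mean(dim=0), elites.std(dim=0)
    return mean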
@@ -339,7 +354,9 @@ class TDMPCPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
-            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
+            batch = dict(
+                batch
+            )  # shallow copy so that adding a key doesn't modify the original
             batch["observation.image"] = batch[next(iter(self.config.image_features))]
         batch = self.normalize_targets(batch)
@@ -371,7 +388,9 @@ class TDMPCPolicy(PreTrainedPolicy):
             current_observation[k] = observations[k][0]
             next_observations[k] = observations[k][1:]
         horizon, batch_size = next_observations[
-            "observation.image" if self.config.image_features else "observation.environment_state"
+            "observation.image"
+            if self.config.image_features
+            else "observation.environment_state"
         ].shape[:2]

         # Run latent rollout using the latent dynamics model and policy model.
@@ -569,7 +588,9 @@ class TDMPCTOLD(nn.Module):
         self.config = config
         self._encoder = TDMPCObservationEncoder(config)
         self._dynamics = nn.Sequential(
-            nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
+            nn.Linear(
+                config.latent_dim + config.action_feature.shape[0], config.mlp_dim
+            ),
             nn.LayerNorm(config.mlp_dim),
             nn.Mish(),
             nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -580,7 +601,9 @@ class TDMPCTOLD(nn.Module):
             nn.Sigmoid(),
         )
         self._reward = nn.Sequential(
-            nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
+            nn.Linear(
+                config.latent_dim + config.action_feature.shape[0], config.mlp_dim
+            ),
             nn.LayerNorm(config.mlp_dim),
             nn.Mish(),
             nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -600,7 +623,10 @@ class TDMPCTOLD(nn.Module):
         self._Qs = nn.ModuleList(
             [
                 nn.Sequential(
-                    nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
+                    nn.Linear(
+                        config.latent_dim + config.action_feature.shape[0],
+                        config.mlp_dim,
+                    ),
                     nn.LayerNorm(config.mlp_dim),
                     nn.Tanh(),
                     nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -786,7 +812,9 @@ class TDMPCObservationEncoder(nn.Module):
         if config.robot_state_feature:
             self.state_enc_layers = nn.Sequential(
-                nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
+                nn.Linear(
+                    config.robot_state_feature.shape[0], config.state_encoder_hidden_dim
+                ),
                 nn.ELU(),
                 nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
                 nn.LayerNorm(config.latent_dim),
@@ -795,7 +823,9 @@ class TDMPCObservationEncoder(nn.Module):
         if config.env_state_feature:
             self.env_state_enc_layers = nn.Sequential(
-                nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
+                nn.Linear(
+                    config.env_state_feature.shape[0], config.state_encoder_hidden_dim
+                ),
                 nn.ELU(),
                 nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
                 nn.LayerNorm(config.latent_dim),
@@ -813,7 +843,8 @@ class TDMPCObservationEncoder(nn.Module):
         if self.config.image_features:
             feat.append(
                 flatten_forward_unflatten(
-                    self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
+                    self.image_enc_layers,
+                    obs_dict[next(iter(self.config.image_features))],
                 )
             )
         if self.config.env_state_feature:

View File

@@ -172,7 +172,10 @@ class VQBeTConfig(PreTrainedConfig):
         if self.crop_shape is not None:
             for key, image_ft in self.image_features.items():
-                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
+                if (
+                    self.crop_shape[0] > image_ft.shape[1]
+                    or self.crop_shape[1] > image_ft.shape[2]
+                ):
                     raise ValueError(
                         f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                         f"for `crop_shape` and {image_ft.shape} for "
@@ -193,7 +196,12 @@ class VQBeTConfig(PreTrainedConfig):
     @property
     def action_delta_indices(self) -> list:
-        return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
+        return list(
+            range(
+                1 - self.n_obs_steps,
+                self.n_action_pred_token + self.action_chunk_size - 1,
+            )
+        )

     @property
     def reward_delta_indices(self) -> None:
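A worked example of the reformatted `action_delta_indices`: with `n_obs_steps=2`, `n_action_pred_token=3`, and `action_chunk_size=5` (illustrative values), the property yields one past offset plus the supervised future offsets:

n_obs_steps, n_action_pred_token, action_chunk_size = 2, 3, 5
indices = list(range(1 - n_obs_steps, n_action_pred_token + action_chunk_size - 1))
print(indices)  # [-1, 0, 1, 2, 3, 4, 5, 6]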

View File

@@ -29,7 +29,11 @@ from torch import Tensor, nn

 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
-from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
+from lerobot.common.policies.utils import (
+    get_device_from_parameters,
+    get_output_shape,
+    populate_queues,
+)
 from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
 from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
@@ -60,7 +64,9 @@ class VQBeTPolicy(PreTrainedPolicy):
         config.validate_features()
         self.config = config

-        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
+        self.normalize_inputs = Normalize(
+            config.input_features, config.normalization_mapping, dataset_stats
+        )
         self.normalize_targets = Normalize(
             config.output_features, config.normalization_mapping, dataset_stats
         )
@@ -91,11 +97,17 @@ class VQBeTPolicy(PreTrainedPolicy):
         if self.config.sequentially_select:
             decay_params = (
                 decay_params
-                + list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
-                + list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
+                + list(
+                    self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()
+                )
+                + list(
+                    self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()
+                )
             )
         else:
-            decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())
+            decay_params = decay_params + list(
+                self.vqbet.action_head.map_to_cbet_preds_bin.parameters()
+            )

         return [
             {
@@ -133,8 +145,12 @@ class VQBeTPolicy(PreTrainedPolicy):
         """
         batch = self.normalize_inputs(batch)
-        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-        batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
+        batch = dict(
+            batch
+        )  # shallow copy so that adding a key doesn't modify the original
+        batch["observation.images"] = torch.stack(
+            [batch[key] for key in self.config.image_features], dim=-4
+        )
         # Note: It's important that this happens after stacking the images into a single key.
         self._queues = populate_queues(self._queues, batch)
@@ -165,8 +181,12 @@ class VQBeTPolicy(PreTrainedPolicy):
     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
-        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-        batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
+        batch = dict(
+            batch
+        )  # shallow copy so that adding a key doesn't modify the original
+        batch["observation.images"] = torch.stack(
+            [batch[key] for key in self.config.image_features], dim=-4
+        )
         batch = self.normalize_targets(batch)
         # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
         if not self.vqbet.action_head.vqvae_model.discretized.item():
@@ -334,7 +354,8 @@ class VQBeTModel(nn.Module):
         # To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
         self.state_projector = MLP(
-            config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
+            config.robot_state_feature.shape[0],
+            hidden_channels=[self.config.gpt_input_dim],
         )
         self.rgb_feature_projector = MLP(
             self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
@@ -406,9 +427,9 @@ class VQBeTModel(nn.Module):
         features = self.policy(input_tokens)
         # len(self.config.input_features) is the number of different observation modes.
         # this line gets the index of action prompt tokens.
-        historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
-            self.config.input_features
-        )
+        historical_act_pred_index = np.arange(0, n_obs_steps) * (
+            len(self.config.input_features) + 1
+        ) + len(self.config.input_features)

         # only extract the output tokens at the position of action query:
         # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
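The index arithmetic above assumes each observation step contributes one token per input feature followed by one action-query token, so the query slots sit at i * (F + 1) + F. A worked example with illustrative sizes:

import numpy as np

n_obs_steps, num_features = 3, 2  # F = 2 input features per step
idx = np.arange(0, n_obs_steps) * (num_features + 1) + num_features
print(idx)  # [2 5 8] -> the action-query slot after each step's feature tokens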
@@ -771,11 +792,15 @@ class VQBeTRgbEncoder(nn.Module):
         # height and width from `config.image_features`.
         images_shape = next(iter(config.image_features.values())).shape
-        dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
+        dummy_shape_h_w = (
+            config.crop_shape if config.crop_shape is not None else images_shape[1:]
+        )
         dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
         feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]

-        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
+        self.pool = SpatialSoftmax(
+            feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
+        )
         self.feature_dim = config.spatial_softmax_num_keypoints * 2
         self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
         self.relu = nn.ReLU()
@@ -871,7 +896,8 @@ class VqVae(nn.Module):
         )

         self.encoder = MLP(
-            in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
+            in_channels=self.config.action_feature.shape[0]
+            * self.config.action_chunk_size,
             hidden_channels=[
                 config.vqvae_enc_hidden_dim,
                 config.vqvae_enc_hidden_dim,
@@ -899,9 +925,13 @@ class VqVae(nn.Module):
         # given latent vector, this function outputs the decoded action.
         output = self.decoder(latent)
         if self.config.action_chunk_size == 1:
-            return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
+            return einops.rearrange(
+                output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]
+            )
         else:
-            return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
+            return einops.rearrange(
+                output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]
+            )

     def get_code(self, state):
         # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)

View File

@@ -290,10 +290,10 @@ class GPT(nn.Module):
         param_dict = dict(self.named_parameters())

         inter_params = decay & no_decay
         union_params = decay | no_decay
-        assert (
-            len(inter_params) == 0
-        ), "parameters {} made it into both decay/no_decay sets!".format(
-            str(inter_params)
+        assert len(inter_params) == 0, (
+            "parameters {} made it into both decay/no_decay sets!".format(
+                str(inter_params)
+            )
         )
         assert len(param_dict.keys() - union_params) == 0, (
             "parameters {} were not separated into either decay/no_decay set!".format(
@@ -664,14 +664,14 @@ class VectorQuantize(nn.Module):
         self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
         self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

-        assert not (
-            ema_update and learnable_codebook
-        ), "learnable codebook not compatible with EMA update"
+        assert not (ema_update and learnable_codebook), (
+            "learnable codebook not compatible with EMA update"
+        )

         assert 0 <= sync_update_v <= 1.0
-        assert not (
-            sync_update_v > 0.0 and not learnable_codebook
-        ), "learnable codebook must be turned on"
+        assert not (sync_update_v > 0.0 and not learnable_codebook), (
+            "learnable codebook must be turned on"
+        )
         self.sync_update_v = sync_update_v