diff --git a/lerobot/common/constants.py b/lerobot/common/constants.py
index 990f2aa1..30777239 100644
--- a/lerobot/common/constants.py
+++ b/lerobot/common/constants.py
@@ -25,6 +25,7 @@ ACTION = "action"
 REWARD = "next.reward"
 
 ROBOTS = "robots"
+ROBOT_TYPE = "robot_type"
 TELEOPERATORS = "teleoperators"
 
 # files & directories
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index e7e74bf3..12206657 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -33,6 +33,7 @@ from torch import Tensor, nn
 from torchvision.models._utils import IntermediateLayerGetter
 from torchvision.ops.misc import FrozenBatchNorm2d
 
+from lerobot.common.constants import ACTION, OBS_IMAGES
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -114,46 +115,49 @@ class ACTPolicy(PreTrainedPolicy):
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
         queue is empty.
         """
-        self.eval()
+        self.eval()  # keeping the policy in eval mode as it could be set to train mode while queue is consumed
 
-        batch = self.normalize_inputs(batch)
-        if self.config.image_features:
-            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = [batch[key] for key in self.config.image_features]
-
-        # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
-        # we are ensembling over.
         if self.config.temporal_ensemble_coeff is not None:
-            actions = self.model(batch)[0]  # (batch_size, chunk_size, action_dim)
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+            actions = self.predict_action_chunk(batch)
             action = self.temporal_ensembler.update(actions)
             return action
 
         # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
         # querying the policy.
         if len(self._action_queue) == 0:
-            actions = self.model(batch)[0][:, : self.config.n_action_steps]
-
-            # TODO(rcadene): make _forward return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+            actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
 
             # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
             # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
             self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()
 
+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        self.eval()
+
+        batch = self.normalize_inputs(batch)
+        if self.config.image_features:
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
+            batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
+
+        actions = self.model(batch)[0]
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+        return actions
+
     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = [batch[key] for key in self.config.image_features]
+            batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
 
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
         l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
+            F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
         ).mean()
 
         loss_dict = {"l1_loss": l1_loss.item()}
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 446e2cb6..038136d0 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -33,7 +33,7 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from torch import Tensor, nn
 
-from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
+from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -99,6 +99,18 @@ class DiffusionPolicy(PreTrainedPolicy):
         if self.config.env_state_feature:
             self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
 
+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        # stack n latest observations from the queue
+        batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
+        actions = self.diffusion.generate_actions(batch)
+
+        # TODO(rcadene): make above methods return output dictionary?
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+
+        return actions
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
@@ -124,23 +136,15 @@ class DiffusionPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
         # Note: It's important that this happens after stacking the images into a single key.
         self._queues = populate_queues(self._queues, batch)
 
-        if len(self._queues["action"]) == 0:
-            # stack n latest observations from the queue
-            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
-            actions = self.diffusion.generate_actions(batch)
+        if len(self._queues[ACTION]) == 0:
+            actions = self.predict_action_chunk(batch)
+            self._queues[ACTION].extend(actions.transpose(0, 1))
 
-            # TODO(rcadene): make above methods return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
-
-            self._queues["action"].extend(actions.transpose(0, 1))
-
-        action = self._queues["action"].popleft()
+        action = self._queues[ACTION].popleft()
         return action
 
     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
@@ -148,9 +152,7 @@ class DiffusionPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         # no output_dict so returning None
diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py
index 1d8a5055..97e66a27 100644
--- a/lerobot/common/policies/pi0/modeling_pi0.py
+++ b/lerobot/common/policies/pi0/modeling_pi0.py
@@ -260,6 +260,11 @@ class PI0Policy(PreTrainedPolicy):
     def get_optim_params(self) -> dict:
         return self.parameters()
 
+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        raise NotImplementedError("Currently not implemented for PI0")
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.
diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py
index 7102bdde..dbf5266b 100644
--- a/lerobot/common/policies/pi0fast/modeling_pi0fast.py
+++ b/lerobot/common/policies/pi0fast/modeling_pi0fast.py
@@ -192,6 +192,11 @@ class PI0FASTPolicy(PreTrainedPolicy):
             actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
         return actions
 
+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        raise NotImplementedError("Currently not implemented for PI0FAST")
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
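For reviewers, a rough usage sketch (not part of this diff) contrasting the queue-backed `select_action` with the new `predict_action_chunk` on an ACT policy. The checkpoint id, observation keys, and tensor shapes below are placeholders for whatever the loaded checkpoint was actually trained with; the same two entry points apply to `DiffusionPolicy`, while PI0 and PI0FAST currently raise `NotImplementedError`.

```python
import torch

from lerobot.common.policies.act.modeling_act import ACTPolicy

# Placeholder checkpoint id and observation keys; substitute the ones your policy was trained with.
policy = ACTPolicy.from_pretrained("lerobot/act_aloha_sim_transfer_cube_human")
policy.reset()

batch = {
    "observation.state": torch.zeros(1, 14),                # (batch_size, state_dim)
    "observation.images.top": torch.zeros(1, 3, 480, 640),  # one key per configured camera
}

# Step-by-step control: pops from the internal action queue, which is refilled
# via predict_action_chunk whenever it runs empty.
action = policy.select_action(batch)        # (batch_size, action_dim)

# Whole chunk at once, e.g. for open-loop execution or inspection.
chunk = policy.predict_action_chunk(batch)  # (batch_size, chunk_size, action_dim)
```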
diff --git a/lerobot/common/policies/pretrained.py b/lerobot/common/policies/pretrained.py
index 58eef9ba..bc9276d0 100644
--- a/lerobot/common/policies/pretrained.py
+++ b/lerobot/common/policies/pretrained.py
@@ -171,6 +171,15 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
         """
         raise NotImplementedError
 
+    @abc.abstractmethod
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode.
+
+        Child classes using action chunking should use this method within `select_action` to form the action chunk
+        cached for selection.
+        """
+        raise NotImplementedError
+
     @abc.abstractmethod
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Return one action to run in the environment (potentially in batch mode).
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index b588115e..1ca46935 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -76,6 +76,11 @@ class SACPolicy(
         """Reset the policy"""
         pass
 
+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
+
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
diff --git a/lerobot/common/policies/sac/reward_model/modeling_classifier.py b/lerobot/common/policies/sac/reward_model/modeling_classifier.py
index f537e3ae..7fec67f1 100644
--- a/lerobot/common/policies/sac/reward_model/modeling_classifier.py
+++ b/lerobot/common/policies/sac/reward_model/modeling_classifier.py
@@ -308,6 +308,13 @@ class Classifier(PreTrainedPolicy):
         """
         raise NotImplementedError("Reward classifiers do not select actions")
 
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """
+        This method is required by PreTrainedPolicy but not used for reward classifiers.
+        The reward classifier is not an actor and does not produce action chunks.
+        """
+        raise NotImplementedError("Reward classifiers do not predict action chunks")
+
     def reset(self):
         """
         This method is required by PreTrainedPolicy but not used for reward classifiers.
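Since `predict_action_chunk` is now abstract on `PreTrainedPolicy`, every subclass has to declare it, even if only to raise. A minimal sketch of what a new chunking policy could look like under this contract follows; the class name, `self.net`, and the normalization attributes are hypothetical stand-ins, while the queue-refill pattern mirrors the ACT / Diffusion / VQ-BeT implementations above.

```python
import torch
from torch import Tensor

from lerobot.common.policies.pretrained import PreTrainedPolicy


class MyChunkingPolicy(PreTrainedPolicy):  # hypothetical example, not part of this PR
    # ... config_class / name, __init__ (building self.net, normalize_inputs, unnormalize_outputs,
    # and self._action_queue as a collections.deque), forward, get_optim_params, reset ...

    @torch.no_grad
    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        """Return a (batch_size, chunk_size, action_dim) tensor of actions."""
        batch = self.normalize_inputs(batch)
        actions = self.net(batch)  # hypothetical backbone call
        return self.unnormalize_outputs({"action": actions})["action"]

    @torch.no_grad
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        # Refill the queue from a fresh chunk when it runs dry, then pop one step.
        if len(self._action_queue) == 0:
            actions = self.predict_action_chunk(batch)
            # (batch_size, chunk_size, action_dim) -> one (batch_size, action_dim) entry per step.
            self._action_queue.extend(actions.transpose(0, 1))
        return self._action_queue.popleft()
```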
diff --git a/lerobot/common/policies/smolvla/modeling_smolvla.py b/lerobot/common/policies/smolvla/modeling_smolvla.py
index 5e0a9622..36199984 100644
--- a/lerobot/common/policies/smolvla/modeling_smolvla.py
+++ b/lerobot/common/policies/smolvla/modeling_smolvla.py
@@ -383,6 +383,45 @@ class SmolVLAPolicy(PreTrainedPolicy):
     def get_optim_params(self) -> dict:
         return self.parameters()
 
+    def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+        for k in batch:
+            if k in self._queues:
+                batch[k] = torch.stack(list(self._queues[k]), dim=1)
+
+        images, img_masks = self.prepare_images(batch)
+        state = self.prepare_state(batch)
+        lang_tokens, lang_masks = self.prepare_language(batch)
+
+        actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
+
+        # Unpad actions
+        original_action_dim = self.config.action_feature.shape[0]
+        actions = actions[:, :, :original_action_dim]
+
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+
+        if self.config.adapt_to_pi_aloha:
+            actions = self._pi_aloha_encode_actions(actions)
+
+        return actions
+
+    def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        if self.config.adapt_to_pi_aloha:
+            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
+
+        batch = self.normalize_inputs(batch)
+
+        return batch
+
+    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+        self.eval()
+
+        batch = self._prepare_batch(batch)
+        self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
+
+        actions = self._get_action_chunk(batch, noise)
+        return actions
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.
@@ -392,38 +431,18 @@ class SmolVLAPolicy(PreTrainedPolicy):
         queue is empty.
         """
         self.eval()
-
-        if self.config.adapt_to_pi_aloha:
-            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
-
-        batch = self.normalize_inputs(batch)
-
+        batch = self._prepare_batch(batch)
         self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
+
         # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
         # querying the policy.
         if len(self._queues[ACTION]) == 0:
-            for k in batch:
-                if k in self._queues:
-                    batch[k] = torch.stack(list(self._queues[k]), dim=1)
-            images, img_masks = self.prepare_images(batch)
-            state = self.prepare_state(batch)
-            lang_tokens, lang_masks = self.prepare_language(batch)
+            actions = self._get_action_chunk(batch, noise)
 
-            actions = self.model.sample_actions(
-                images, img_masks, lang_tokens, lang_masks, state, noise=noise
-            )
-            # Unpad actions
-            original_action_dim = self.config.action_feature.shape[0]
-            actions = actions[:, :, :original_action_dim]
-
-            actions = self.unnormalize_outputs({"action": actions})["action"]
-
-            if self.config.adapt_to_pi_aloha:
-                actions = self._pi_aloha_encode_actions(actions)
-
-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
+            # `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
             # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
             self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
+
         return self._queues[ACTION].popleft()
 
     def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
index 476e6dec..4bb564f8 100644
--- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -35,7 +35,7 @@ import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor
 
-from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
+from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
@@ -110,52 +110,58 @@ class TDMPCPolicy(PreTrainedPolicy):
         # CEM for the next step.
         self._prev_mean: torch.Tensor | None = None
 
+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
+
+        # Remove the time dimensions as it is not handled yet.
+        for key in batch:
+            assert batch[key].shape[1] == 1
+            batch[key] = batch[key][:, 0]
+
+        # NOTE: Order of observations matters here.
+        encode_keys = []
+        if self.config.image_features:
+            encode_keys.append(OBS_IMAGE)
+        if self.config.env_state_feature:
+            encode_keys.append(OBS_ENV_STATE)
+        encode_keys.append(OBS_STATE)
+        z = self.model.encode({k: batch[k] for k in encode_keys})
+        if self.config.use_mpc:  # noqa: SIM108
+            actions = self.plan(z)  # (horizon, batch, action_dim)
+        else:
+            # Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
+            # sequence dimension like in the MPC branch.
+            actions = self.model.pi(z).unsqueeze(0)
+
+        actions = torch.clamp(actions, -1, +1)
+
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+        return actions
+
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations."""
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.image"] = batch[next(iter(self.config.image_features))]
+            batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
 
         self._queues = populate_queues(self._queues, batch)
 
         # When the action queue is depleted, populate it again by querying the policy.
-        if len(self._queues["action"]) == 0:
-            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
-
-            # Remove the time dimensions as it is not handled yet.
-            for key in batch:
-                assert batch[key].shape[1] == 1
-                batch[key] = batch[key][:, 0]
-
-            # NOTE: Order of observations matters here.
-            encode_keys = []
-            if self.config.image_features:
-                encode_keys.append("observation.image")
-            if self.config.env_state_feature:
-                encode_keys.append("observation.environment_state")
-            encode_keys.append("observation.state")
-            z = self.model.encode({k: batch[k] for k in encode_keys})
-            if self.config.use_mpc:  # noqa: SIM108
-                actions = self.plan(z)  # (horizon, batch, action_dim)
-            else:
-                # Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
-                # sequence dimension like in the MPC branch.
-                actions = self.model.pi(z).unsqueeze(0)
-
-            actions = torch.clamp(actions, -1, +1)
-
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+        if len(self._queues[ACTION]) == 0:
+            actions = self.predict_action_chunk(batch)
 
             if self.config.n_action_repeats > 1:
                 for _ in range(self.config.n_action_repeats):
-                    self._queues["action"].append(actions[0])
+                    self._queues[ACTION].append(actions[0])
             else:
                 # Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action.
-                self._queues["action"].extend(actions[: self.config.n_action_steps])
+                self._queues[ACTION].extend(actions[: self.config.n_action_steps])
 
-        action = self._queues["action"].popleft()
+        action = self._queues[ACTION].popleft()
         return action
 
     @torch.no_grad()
@@ -312,7 +318,7 @@ class TDMPCPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.image"] = batch[next(iter(self.config.image_features))]
+            batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
         batch = self.normalize_targets(batch)
 
         info = {}
@@ -322,15 +328,15 @@ class TDMPCPolicy(PreTrainedPolicy):
             if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
                 batch[key] = batch[key].transpose(1, 0)
 
-        action = batch["action"]  # (t, b, action_dim)
-        reward = batch["next.reward"]  # (t, b)
+        action = batch[ACTION]  # (t, b, action_dim)
+        reward = batch[REWARD]  # (t, b)
         observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
 
         # Apply random image augmentations.
         if self.config.image_features and self.config.max_random_shift_ratio > 0:
-            observations["observation.image"] = flatten_forward_unflatten(
+            observations[OBS_IMAGE] = flatten_forward_unflatten(
                 partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
-                observations["observation.image"],
+                observations[OBS_IMAGE],
             )
 
         # Get the current observation for predicting trajectories, and all future observations for use in
@@ -340,7 +346,7 @@ class TDMPCPolicy(PreTrainedPolicy):
             current_observation[k] = observations[k][0]
             next_observations[k] = observations[k][1:]
         horizon, batch_size = next_observations[
-            "observation.image" if self.config.image_features else "observation.environment_state"
+            OBS_IMAGE if self.config.image_features else OBS_ENV_STATE
         ].shape[:2]
 
         # Run latent rollout using the latent dynamics model and policy model.
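One convention worth calling out while reviewing: the chunk returned by `predict_action_chunk` is batch-major for ACT, Diffusion, SmolVLA and VQ-BeT, so those policies transpose before extending their action queues, whereas TDMPC's `plan()` already yields a time-major `(horizon, batch, action_dim)` tensor and is extended as-is. A small shape sketch (sizes are arbitrary placeholders, not from the diff):

```python
from collections import deque

import torch

batch_size, n_action_steps, action_dim = 2, 5, 7

# Batch-major chunk, as returned by e.g. ACTPolicy.predict_action_chunk.
chunk = torch.zeros(batch_size, n_action_steps, action_dim)

queue = deque(maxlen=n_action_steps)
# transpose(0, 1) -> (n_action_steps, batch_size, action_dim), so each queue entry
# is the (batch_size, action_dim) action for one timestep.
queue.extend(chunk.transpose(0, 1))
assert queue[0].shape == (batch_size, action_dim)

# TDMPC's predict_action_chunk already returns a time-major (horizon, batch, action_dim)
# tensor from plan(), so it is extended without a transpose.
tdmpc_chunk = torch.zeros(n_action_steps, batch_size, action_dim)
queue.clear()
queue.extend(tdmpc_chunk[:n_action_steps])
assert queue[0].shape == (batch_size, action_dim)
```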
diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py
index 44006a5b..a76bea2a 100644
--- a/lerobot/common/policies/vqbet/modeling_vqbet.py
+++ b/lerobot/common/policies/vqbet/modeling_vqbet.py
@@ -27,6 +27,7 @@ import torch.nn.functional as F  # noqa: N812
 import torchvision
 from torch import Tensor, nn
 
+from lerobot.common.constants import ACTION, OBS_IMAGES, OBS_STATE
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
@@ -118,11 +119,18 @@ class VQBeTPolicy(PreTrainedPolicy):
         queues are populated during rollout of the policy, they contain the n latest observations and actions
         """
         self._queues = {
-            "observation.images": deque(maxlen=self.config.n_obs_steps),
-            "observation.state": deque(maxlen=self.config.n_obs_steps),
-            "action": deque(maxlen=self.config.action_chunk_size),
+            OBS_IMAGES: deque(maxlen=self.config.n_obs_steps),
+            OBS_STATE: deque(maxlen=self.config.n_obs_steps),
+            ACTION: deque(maxlen=self.config.action_chunk_size),
         }
 
+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
+        actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+        return actions
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
@@ -144,23 +152,19 @@ class VQBeTPolicy(PreTrainedPolicy):
                 stacklevel=1,
             )
-        if len(self._queues["action"]) == 0:
-            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
-            actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
-
-            # the dimension of returned action is (batch_size, action_chunk_size, action_dim)
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+        if len(self._queues[ACTION]) == 0:
+            actions = self.predict_action_chunk(batch)
 
             # since the data in the action queue's dimension is (action_chunk_size, batch_size, action_dim), we transpose the action and fill the queue
-            self._queues["action"].extend(actions.transpose(0, 1))
+            self._queues[ACTION].extend(actions.transpose(0, 1))
 
-        action = self._queues["action"].popleft()
+        action = self._queues[ACTION].popleft()
         return action
 
     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-        batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
+        batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
         batch = self.normalize_targets(batch)
         # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181)
         if not self.vqbet.action_head.vqvae_model.discretized.item():
@@ -168,7 +172,7 @@ class VQBeTPolicy(PreTrainedPolicy):
             # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
             # n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
             loss, n_different_codes, n_different_combinations, recon_l1_error = (
-                self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
+                self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch[ACTION])
             )
             return loss, {
                 "n_different_codes": n_different_codes,
@@ -404,7 +408,7 @@ class VQBeTModel(nn.Module):
             )
         # else, it calculate overall loss (bin prediction loss, and offset loss)
         else:
-            output = batch["action"][:, self.select_target_actions_indices]
+            output = batch[ACTION][:, self.select_target_actions_indices]
             loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
 
         return action_head_output, loss
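A possible follow-up (not included in this diff) is a small regression check asserting that the new abstract method is declared on every policy touched here, so a future policy cannot silently skip it. The test name and the subset of classes below are illustrative only:

```python
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.sac.modeling_sac import SACPolicy


def test_predict_action_chunk_is_declared():
    # The base class now marks the method abstract...
    assert "predict_action_chunk" in PreTrainedPolicy.__abstractmethods__
    # ...and each concrete policy defines it directly (chunking policies return a tensor,
    # non-chunking ones such as SACPolicy raise NotImplementedError when called).
    for cls in (ACTPolicy, PI0Policy, SACPolicy):
        assert "predict_action_chunk" in cls.__dict__
```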