From a28f02ecb3fe94d587a15ef7755da1f25f82765f Mon Sep 17 00:00:00 2001
From: Dhruva <51377003+utterwqlnut@users.noreply.github.com>
Date: Fri, 6 Jun 2025 03:25:51 -0400
Subject: [PATCH] replaced OBS_ROBOT with OBS_STATE constant (#1211)

---
 lerobot/common/policies/pi0fast/modeling_pi0fast.py | 12 ++++++------
 lerobot/common/policies/smolvla/modeling_smolvla.py | 12 ++++++------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py
index f19e8c836..4996b1a08 100644
--- a/lerobot/common/policies/pi0fast/modeling_pi0fast.py
+++ b/lerobot/common/policies/pi0fast/modeling_pi0fast.py
@@ -56,7 +56,7 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe
 from transformers.cache_utils import HybridCache, StaticCache
 from transformers.models.auto import CONFIG_MAPPING
 
-from lerobot.common.constants import ACTION, OBS_ROBOT
+from lerobot.common.constants import ACTION, OBS_STATE
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
 from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -203,7 +203,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
         self.eval()
 
         if self.config.adapt_to_pi_aloha:
-            batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
+            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
 
         batch = self.normalize_inputs(batch)
 
@@ -231,7 +231,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         if self.config.adapt_to_pi_aloha:
-            batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
+            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
             batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
         batch = self.normalize_inputs(batch)
         batch = self.normalize_targets(batch)
@@ -677,12 +677,12 @@ class PI0FAST(nn.Module):
         return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
 
     def forward(self, batch: dict[str, Tensor]):
-        device = batch[OBS_ROBOT].device
+        device = batch[OBS_STATE].device
         # TODO: keep like this or move to the policy .forward
         images, img_masks = self.prepare_images(batch)
 
         padded_outs = self.create_input_tokens(
-            state=batch[OBS_ROBOT],
+            state=batch[OBS_STATE],
             lang_text=batch["task"],
             actions=batch[ACTION],
         )
@@ -849,7 +849,7 @@ class PI0FAST(nn.Module):
         # TODO: keep like this or move to the policy .forward
         images, img_masks = self.prepare_images(batch)
 
-        padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)
+        padded_outs = self.create_input_tokens(state=batch[OBS_STATE], lang_text=batch["task"], actions=None)
         embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
             images,
             img_masks,
diff --git a/lerobot/common/policies/smolvla/modeling_smolvla.py b/lerobot/common/policies/smolvla/modeling_smolvla.py
index 008ba8380..6ac2d3e7e 100644
--- a/lerobot/common/policies/smolvla/modeling_smolvla.py
+++ b/lerobot/common/policies/smolvla/modeling_smolvla.py
@@ -60,7 +60,7 @@ import torch.nn.functional as F  # noqa: N812
 from torch import Tensor, nn
 from transformers import AutoProcessor
 
-from lerobot.common.constants import ACTION, OBS_ROBOT
+from lerobot.common.constants import ACTION, OBS_STATE
 from lerobot.common.policies.normalize import (
     Normalize,
     Unnormalize,
@@ -278,7 +278,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
         self.eval()
 
         if self.config.adapt_to_pi_aloha:
-            batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
+            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
 
         batch = self.normalize_inputs(batch)
 
@@ -313,7 +313,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
     def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
         """Do a full training forward pass to compute the loss"""
         if self.config.adapt_to_pi_aloha:
-            batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
+            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
             batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
         batch = self.normalize_inputs(batch)
         batch = self.normalize_targets(batch)
@@ -385,10 +385,10 @@ class SmolVLAPolicy(PreTrainedPolicy):
     def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
         """Tokenize the text input"""
-        device = batch[OBS_ROBOT].device
+        device = batch[OBS_STATE].device
         tasks = batch["task"]
         if len(tasks) == 1:
-            tasks = [tasks[0] for _ in range(batch[OBS_ROBOT].shape[0])]
+            tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
 
         tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
 
         tokenized_prompt = self.language_tokenizer.__call__(
@@ -432,7 +432,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
 
     def prepare_state(self, batch):
         """Pad state"""
-        state = batch[OBS_ROBOT][:, -1, :] if batch[OBS_ROBOT].ndim > 2 else batch[OBS_ROBOT]
+        state = batch[OBS_STATE][:, -1, :] if batch[OBS_STATE].ndim > 2 else batch[OBS_STATE]
        state = pad_vector(state, self.config.max_state_dim)
        return state
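
Note (not part of the patch): a minimal sketch of the batch-key pattern this rename touches, assuming OBS_STATE in lerobot.common.constants resolves to the same "observation.state" feature key that OBS_ROBOT did; the tensor shapes below are illustrative placeholders, not values taken from the patch.

# Minimal sketch: builds a dummy batch keyed by the renamed constant and
# mirrors the lookups the patch rewrites from batch[OBS_ROBOT] to
# batch[OBS_STATE]. Assumes OBS_STATE maps to the same feature key as the
# old OBS_ROBOT; shapes are illustrative only.
import torch

from lerobot.common.constants import ACTION, OBS_STATE

batch = {
    OBS_STATE: torch.zeros(2, 14),   # (batch, state_dim) proprioceptive state
    ACTION: torch.zeros(2, 50, 14),  # (batch, chunk_size, action_dim)
    "task": ["pick up the cube", "stack the blocks"],
}

# Same access patterns as the patched methods:
device = batch[OBS_STATE].device
state = batch[OBS_STATE][:, -1, :] if batch[OBS_STATE].ndim > 2 else batch[OBS_STATE]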