replaced OBS_ROBOT with OBS_STATE constant (#1211)

This commit is contained in:
Dhruva
2025-06-06 03:25:51 -04:00
committed by GitHub
parent 09343acce7
commit a28f02ecb3
2 changed files with 12 additions and 12 deletions

View File

@@ -56,7 +56,7 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe
from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING
from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.constants import ACTION, OBS_STATE
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -203,7 +203,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
self.eval()
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
@@ -231,7 +231,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
@@ -677,12 +677,12 @@ class PI0FAST(nn.Module):
return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
def forward(self, batch: dict[str, Tensor]):
device = batch[OBS_ROBOT].device
device = batch[OBS_STATE].device
# TODO: keep like this or move to the policy .forward
images, img_masks = self.prepare_images(batch)
padded_outs = self.create_input_tokens(
state=batch[OBS_ROBOT],
state=batch[OBS_STATE],
lang_text=batch["task"],
actions=batch[ACTION],
)
@@ -849,7 +849,7 @@ class PI0FAST(nn.Module):
# TODO: keep like this or move to the policy .forward
images, img_masks = self.prepare_images(batch)
padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)
padded_outs = self.create_input_tokens(state=batch[OBS_STATE], lang_text=batch["task"], actions=None)
embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
images,
img_masks,

View File

@@ -60,7 +60,7 @@ import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoProcessor
from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.constants import ACTION, OBS_STATE
from lerobot.common.policies.normalize import (
Normalize,
Unnormalize,
@@ -278,7 +278,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
self.eval()
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
@@ -313,7 +313,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
@@ -385,10 +385,10 @@ class SmolVLAPolicy(PreTrainedPolicy):
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_ROBOT].device
device = batch[OBS_STATE].device
tasks = batch["task"]
if len(tasks) == 1:
tasks = [tasks[0] for _ in range(batch[OBS_ROBOT].shape[0])]
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__(
@@ -432,7 +432,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
def prepare_state(self, batch):
"""Pad state"""
state = batch[OBS_ROBOT][:, -1, :] if batch[OBS_ROBOT].ndim > 2 else batch[OBS_ROBOT]
state = batch[OBS_STATE][:, -1, :] if batch[OBS_STATE].ndim > 2 else batch[OBS_STATE]
state = pad_vector(state, self.config.max_state_dim)
return state