replaced OBS_ROBOT with OBS_STATE constant (#1211)
This commit is contained in:
@@ -56,7 +56,7 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe
|
|||||||
from transformers.cache_utils import HybridCache, StaticCache
|
from transformers.cache_utils import HybridCache, StaticCache
|
||||||
from transformers.models.auto import CONFIG_MAPPING
|
from transformers.models.auto import CONFIG_MAPPING
|
||||||
|
|
||||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
from lerobot.common.constants import ACTION, OBS_STATE
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
@@ -203,7 +203,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
|||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
if self.config.adapt_to_pi_aloha:
|
if self.config.adapt_to_pi_aloha:
|
||||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
|
||||||
@@ -231,7 +231,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
if self.config.adapt_to_pi_aloha:
|
if self.config.adapt_to_pi_aloha:
|
||||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
batch = self.normalize_targets(batch)
|
batch = self.normalize_targets(batch)
|
||||||
@@ -677,12 +677,12 @@ class PI0FAST(nn.Module):
|
|||||||
return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
|
return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]):
|
def forward(self, batch: dict[str, Tensor]):
|
||||||
device = batch[OBS_ROBOT].device
|
device = batch[OBS_STATE].device
|
||||||
# TODO: keep like this or move to the policy .forward
|
# TODO: keep like this or move to the policy .forward
|
||||||
images, img_masks = self.prepare_images(batch)
|
images, img_masks = self.prepare_images(batch)
|
||||||
|
|
||||||
padded_outs = self.create_input_tokens(
|
padded_outs = self.create_input_tokens(
|
||||||
state=batch[OBS_ROBOT],
|
state=batch[OBS_STATE],
|
||||||
lang_text=batch["task"],
|
lang_text=batch["task"],
|
||||||
actions=batch[ACTION],
|
actions=batch[ACTION],
|
||||||
)
|
)
|
||||||
@@ -849,7 +849,7 @@ class PI0FAST(nn.Module):
|
|||||||
# TODO: keep like this or move to the policy .forward
|
# TODO: keep like this or move to the policy .forward
|
||||||
images, img_masks = self.prepare_images(batch)
|
images, img_masks = self.prepare_images(batch)
|
||||||
|
|
||||||
padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)
|
padded_outs = self.create_input_tokens(state=batch[OBS_STATE], lang_text=batch["task"], actions=None)
|
||||||
embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
|
embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
|
||||||
images,
|
images,
|
||||||
img_masks,
|
img_masks,
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ import torch.nn.functional as F # noqa: N812
|
|||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from transformers import AutoProcessor
|
from transformers import AutoProcessor
|
||||||
|
|
||||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
from lerobot.common.constants import ACTION, OBS_STATE
|
||||||
from lerobot.common.policies.normalize import (
|
from lerobot.common.policies.normalize import (
|
||||||
Normalize,
|
Normalize,
|
||||||
Unnormalize,
|
Unnormalize,
|
||||||
@@ -278,7 +278,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
|||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
if self.config.adapt_to_pi_aloha:
|
if self.config.adapt_to_pi_aloha:
|
||||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
|
||||||
@@ -313,7 +313,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
|||||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
||||||
"""Do a full training forward pass to compute the loss"""
|
"""Do a full training forward pass to compute the loss"""
|
||||||
if self.config.adapt_to_pi_aloha:
|
if self.config.adapt_to_pi_aloha:
|
||||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
batch = self.normalize_targets(batch)
|
batch = self.normalize_targets(batch)
|
||||||
@@ -385,10 +385,10 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||||
"""Tokenize the text input"""
|
"""Tokenize the text input"""
|
||||||
device = batch[OBS_ROBOT].device
|
device = batch[OBS_STATE].device
|
||||||
tasks = batch["task"]
|
tasks = batch["task"]
|
||||||
if len(tasks) == 1:
|
if len(tasks) == 1:
|
||||||
tasks = [tasks[0] for _ in range(batch[OBS_ROBOT].shape[0])]
|
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
|
||||||
|
|
||||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||||
tokenized_prompt = self.language_tokenizer.__call__(
|
tokenized_prompt = self.language_tokenizer.__call__(
|
||||||
@@ -432,7 +432,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
def prepare_state(self, batch):
|
def prepare_state(self, batch):
|
||||||
"""Pad state"""
|
"""Pad state"""
|
||||||
state = batch[OBS_ROBOT][:, -1, :] if batch[OBS_ROBOT].ndim > 2 else batch[OBS_ROBOT]
|
state = batch[OBS_STATE][:, -1, :] if batch[OBS_STATE].ndim > 2 else batch[OBS_STATE]
|
||||||
state = pad_vector(state, self.config.max_state_dim)
|
state = pad_vector(state, self.config.max_state_dim)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user