Fixes @torch.no_grad() usage (#1455)
* fix: decorator calls with parentheses
* fix no grad for normalize too

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
aec1b29d23
commit
a5e0aae13a
@@ -107,7 +107,7 @@ class ACTPolicy(PreTrainedPolicy):
|
|||||||
else:
|
else:
|
||||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Select a single action given environment observations.
|
"""Select a single action given environment observations.
|
||||||
|
|
||||||
@@ -132,7 +132,7 @@ class ACTPolicy(PreTrainedPolicy):
|
|||||||
self._action_queue.extend(actions.transpose(0, 1))
|
self._action_queue.extend(actions.transpose(0, 1))
|
||||||
return self._action_queue.popleft()
|
return self._action_queue.popleft()
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
|||||||
if self.config.env_state_feature:
|
if self.config.env_state_feature:
|
||||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
# stack n latest observations from the queue
|
# stack n latest observations from the queue
|
||||||
@@ -111,7 +111,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Select a single action given environment observations.
|
"""Select a single action given environment observations.
|
||||||
|
|
||||||
|
|||||||
@@ -149,7 +149,7 @@ class Normalize(nn.Module):
|
|||||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||||
|
|
||||||
# TODO(rcadene): should we remove torch.no_grad?
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
# TODO: Remove this shallow copy
|
# TODO: Remove this shallow copy
|
||||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||||
@@ -224,7 +224,7 @@ class Unnormalize(nn.Module):
|
|||||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||||
|
|
||||||
# TODO(rcadene): should we remove torch.no_grad?
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
|
|||||||
@@ -260,12 +260,12 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
raise NotImplementedError("Currently not implemented for PI0")
|
raise NotImplementedError("Currently not implemented for PI0")
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
"""Select a single action given environment observations.
|
"""Select a single action given environment observations.
|
||||||
|
|
||||||
|
|||||||
@@ -192,12 +192,12 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
|||||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
raise NotImplementedError("Currently not implemented for PI0FAST")
|
raise NotImplementedError("Currently not implemented for PI0FAST")
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Select a single action given environment observations.
|
"""Select a single action given environment observations.
|
||||||
|
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ class SACPolicy(
|
|||||||
"""Reset the policy"""
|
"""Reset the policy"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
|
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
|
||||||
|
|||||||
@@ -413,6 +413,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
@@ -422,7 +423,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
|||||||
actions = self._get_action_chunk(batch, noise)
|
actions = self._get_action_chunk(batch, noise)
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
"""Select a single action given environment observations.
|
"""Select a single action given environment observations.
|
||||||
|
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
# CEM for the next step.
|
# CEM for the next step.
|
||||||
self._prev_mean: torch.Tensor | None = None
|
self._prev_mean: torch.Tensor | None = None
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
|
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
|
||||||
|
|||||||
@@ -124,14 +124,14 @@ class VQBeTPolicy(PreTrainedPolicy):
|
|||||||
ACTION: deque(maxlen=self.config.action_chunk_size),
|
ACTION: deque(maxlen=self.config.action_chunk_size),
|
||||||
}
|
}
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||||
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
|
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
|
||||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Select a single action given environment observations.
|
"""Select a single action given environment observations.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user