Fixes @torch.no_grad() usage (#1455)

* fix: decorator calls with parentheses

* fix no grad for normalize too

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
Author: Francesco Capuano
Date: 2025-07-08 13:08:32 +02:00
Committed by: GitHub
Parent: aec1b29d23
Commit: a5e0aae13a
9 changed files with 16 additions and 15 deletions
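For context: torch.no_grad is a context-manager class, and the documented decorator form instantiates it first, i.e. @torch.no_grad(). Recent PyTorch releases also accept the bare @torch.no_grad as a decorator, but the parenthesized form is the one shown in the PyTorch docs and keeps all policies consistent. Below is a minimal, hypothetical sketch (TinyPolicy is made up for illustration and is not part of this repo) of what the corrected form does:

import torch
from torch import nn


class TinyPolicy(nn.Module):
    """Made-up stand-in for a policy class, for illustration only."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)

    @torch.no_grad()  # instantiate the context manager, then apply it as a decorator
    def select_action(self, obs: torch.Tensor) -> torch.Tensor:
        # Autograd is disabled for the whole call: no graph is recorded.
        return self.net(obs)


policy = TinyPolicy()
action = policy.select_action(torch.randn(1, 4))
print(action.requires_grad)  # False: the output is detached from any graph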


@@ -107,7 +107,7 @@ class ACTPolicy(PreTrainedPolicy):
         else:
             self._action_queue = deque([], maxlen=self.config.n_action_steps)

-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
@@ -132,7 +132,7 @@ class ACTPolicy(PreTrainedPolicy):
             self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()

-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         self.eval()


@@ -99,7 +99,7 @@ class DiffusionPolicy(PreTrainedPolicy):
         if self.config.env_state_feature:
             self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         # stack n latest observations from the queue
@@ -111,7 +111,7 @@ class DiffusionPolicy(PreTrainedPolicy):
         return actions

-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.


@@ -149,7 +149,7 @@ class Normalize(nn.Module):
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)

     # TODO(rcadene): should we remove torch.no_grad?
-    @torch.no_grad
+    @torch.no_grad()
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         # TODO: Remove this shallow copy
         batch = dict(batch)  # shallow copy avoids mutating the input batch
@@ -224,7 +224,7 @@ class Unnormalize(nn.Module):
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)

     # TODO(rcadene): should we remove torch.no_grad?
-    @torch.no_grad
+    @torch.no_grad()
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, ft in self.features.items():
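Aside (not part of the diff): the TODO above presumably exists because decorating forward with @torch.no_grad() means everything computed inside, including the normalized batch, comes out with requires_grad=False, so no gradient can flow back through the normalization step. A minimal, hypothetical illustration:

import torch

# Hypothetical per-feature statistics, standing in for the module's registered buffers.
mean = torch.zeros(3)
std = torch.ones(3)

@torch.no_grad()
def normalize(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the decorated forward above: runs entirely with autograd disabled.
    return (x - mean) / std

x = torch.randn(2, 3, requires_grad=True)
y = normalize(x)
print(y.requires_grad)  # False: the autograd graph is cut at this step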


@@ -260,12 +260,12 @@ class PI0Policy(PreTrainedPolicy):
     def get_optim_params(self) -> dict:
         return self.parameters()

-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         raise NotImplementedError("Currently not implemented for PI0")

-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.


@@ -192,12 +192,12 @@ class PI0FASTPolicy(PreTrainedPolicy):
             actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
         return actions

-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         raise NotImplementedError("Currently not implemented for PI0FAST")

-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.


@@ -76,7 +76,7 @@ class SACPolicy(
         """Reset the policy"""
         pass

-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")


@@ -413,6 +413,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
         return batch

+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         self.eval()
@@ -422,7 +423,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
         actions = self._get_action_chunk(batch, noise)
         return actions

-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.


@@ -110,7 +110,7 @@ class TDMPCPolicy(PreTrainedPolicy):
         # CEM for the next step.
         self._prev_mean: torch.Tensor | None = None

-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}


@@ -124,14 +124,14 @@ class VQBeTPolicy(PreTrainedPolicy):
             ACTION: deque(maxlen=self.config.action_chunk_size),
         }

-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
         actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
         actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
         return actions

-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.