[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-31 16:10:00 +00:00
committed by Adil Zouitine
parent 22da1739b1
commit f5cfd9fd48
5 changed files with 35 additions and 26 deletions

View File

@@ -198,7 +198,7 @@ class SACPolicy(
def grasp_critic_forward(self, observations, use_target=False, observation_features=None):
"""Forward pass through a grasp critic network
Args:
observations: Dictionary of observations
use_target: If True, use target critics, otherwise use ensemble critics
@@ -254,7 +254,7 @@ class SACPolicy(
observation_features=observation_features,
next_observation_features=next_observation_features,
)
if model == "grasp_critic":
# Extract grasp_critic-specific components
complementary_info: dict[str, Tensor] = batch["complementary_info"]
@@ -307,7 +307,7 @@ class SACPolicy(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def update_temperature(self):
self.temperature = self.log_alpha.exp().item()
@@ -369,8 +369,17 @@ class SACPolicy(
).sum()
return critics_loss
def compute_loss_grasp_critic(self, observations, actions, rewards, next_observations, done, observation_features=None, next_observation_features=None, complementary_info=None):
def compute_loss_grasp_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features=None,
next_observation_features=None,
complementary_info=None,
):
batch_size = rewards.shape[0]
grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2) # Map [-1, 0, 1] -> [0, 1, 2]
@@ -632,9 +641,7 @@ class GraspCritic(nn.Module):
self.parameters_to_optimize += list(self.output_layer.parameters())
def forward(
self,
observations: torch.Tensor,
observation_features: torch.Tensor | None = None
self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device