[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Adil Zouitine
parent
22da1739b1
commit
f5cfd9fd48
@@ -198,7 +198,7 @@ class SACPolicy(
|
||||
|
||||
def grasp_critic_forward(self, observations, use_target=False, observation_features=None):
|
||||
"""Forward pass through a grasp critic network
|
||||
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
@@ -254,7 +254,7 @@ class SACPolicy(
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
|
||||
if model == "grasp_critic":
|
||||
# Extract grasp_critic-specific components
|
||||
complementary_info: dict[str, Tensor] = batch["complementary_info"]
|
||||
@@ -307,7 +307,7 @@ class SACPolicy(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
|
||||
def update_temperature(self):
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
@@ -369,8 +369,17 @@ class SACPolicy(
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_grasp_critic(self, observations, actions, rewards, next_observations, done, observation_features=None, next_observation_features=None, complementary_info=None):
|
||||
|
||||
def compute_loss_grasp_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features=None,
|
||||
next_observation_features=None,
|
||||
complementary_info=None,
|
||||
):
|
||||
batch_size = rewards.shape[0]
|
||||
grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2) # Map [-1, 0, 1] -> [0, 1, 2]
|
||||
|
||||
@@ -632,9 +641,7 @@ class GraspCritic(nn.Module):
|
||||
self.parameters_to_optimize += list(self.output_layer.parameters())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None
|
||||
self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
|
||||
Reference in New Issue
Block a user