Added a caching function in the learner_server and modeling_sac in order to limit the number of forward passes through the pretrained encoder when it is frozen (sketched below).

Added tensordict dependencies
Updated the version of torch and torchvision
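
For context, the pattern is roughly the following: when the pretrained encoder is frozen, its features for a batch can be computed once under `torch.no_grad()` and then passed to the critic, actor, and temperature losses through the new `observation_features` arguments. This is a minimal sketch, not the commit's exact code; the helper name and the `freeze_vision_encoder` flag are assumptions.

```python
import torch


def get_cached_observation_features(policy, observations, next_observations):
    """Sketch of the caching helper (names are assumptions, not this commit's
    exact code): precompute frozen-encoder outputs once per batch so the loss
    functions can skip redundant forward passes through the pretrained encoder."""
    # Without a shared frozen encoder there is nothing to cache; returning None
    # makes each forward pass fall back to encoding the observations itself.
    if policy.actor.encoder is None or not getattr(policy.config, "freeze_vision_encoder", False):
        return None, None
    with torch.no_grad():  # the encoder is frozen, so no gradients are needed
        observation_features = policy.actor.encoder(observations)
        next_observation_features = policy.actor.encoder(next_observations)
    return observation_features, next_observation_features


# Usage in a learner update step (argument names match the diff below):
# obs_feats, next_obs_feats = get_cached_observation_features(policy, observations, next_observations)
# loss_critic = policy.compute_loss_critic(
#     observations, actions, rewards, next_observations, done,
#     observation_features=obs_feats, next_observation_features=next_obs_feats,
# )
# loss_actor = policy.compute_loss_actor(observations, observation_features=obs_feats)
# loss_temperature = policy.compute_loss_temperature(observations, observation_features=obs_feats)
```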

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Author: Michel Aractingi
Date: 2025-02-21 10:13:43 +00:00
parent e1d55c7a44
commit d3b84ecd6f

8 changed files with 66 additions and 42 deletions


@@ -153,7 +153,7 @@ class SACPolicy(
         return actions

     def critic_forward(
-        self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
+        self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, observation_features: Tensor | None = None
     ) -> Tensor:
         """Forward pass through a critic network ensemble
@@ -166,7 +166,7 @@ class SACPolicy(
             Tensor of Q-values from all critics
         """
         critics = self.critic_target if use_target else self.critic_ensemble
-        q_values = critics(observations, actions)
+        q_values = critics(observations, actions, observation_features)
         return q_values

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
@@ -180,14 +180,14 @@ class SACPolicy(
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )

-    def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
+    def compute_loss_critic(self, observations, actions, rewards, next_observations, done, observation_features: Tensor | None = None, next_observation_features: Tensor | None = None) -> Tensor:
         temperature = self.log_alpha.exp().item()
         with torch.no_grad():
-            next_action_preds, next_log_probs, _ = self.actor(next_observations)
+            next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)

             # 2- compute q targets
             q_targets = self.critic_forward(
-                observations=next_observations, actions=next_action_preds, use_target=True
+                observations=next_observations, actions=next_action_preds, use_target=True, observation_features=next_observation_features
             )
             # subsample critics to prevent overfitting if using a high UTD (update-to-data) ratio
@@ -204,7 +204,7 @@ class SACPolicy(
             td_target = rewards + (1 - done) * self.config.discount * min_q

         # 3- compute predicted qs
-        q_preds = self.critic_forward(observations, actions, use_target=False)
+        q_preds = self.critic_forward(observations, actions, use_target=False, observation_features=observation_features)

         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
@@ -219,20 +219,20 @@ class SACPolicy(
         ).sum()
         return critics_loss

-    def compute_loss_temperature(self, observations) -> Tensor:
+    def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
         """Compute the temperature loss"""
         # calculate temperature loss
         with torch.no_grad():
-            _, log_probs, _ = self.actor(observations)
+            _, log_probs, _ = self.actor(observations, observation_features)
         temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
         return temperature_loss

-    def compute_loss_actor(self, observations) -> Tensor:
+    def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
         temperature = self.log_alpha.exp().item()

-        actions_pi, log_probs, _ = self.actor(observations)
+        actions_pi, log_probs, _ = self.actor(observations, observation_features)

-        q_preds = self.critic_forward(observations, actions_pi, use_target=False)
+        q_preds = self.critic_forward(observations, actions_pi, use_target=False, observation_features=observation_features)
         min_q_preds = q_preds.min(dim=0)[0]

         actor_loss = ((temperature * log_probs) - min_q_preds).mean()
@@ -370,6 +370,7 @@ class CriticEnsemble(nn.Module):
         self,
         observations: dict[str, torch.Tensor],
         actions: torch.Tensor,
+        observation_features: torch.Tensor | None = None,
     ) -> torch.Tensor:
         device = get_device_from_parameters(self)
         # Move each tensor in observations to device
@@ -380,7 +381,7 @@ class CriticEnsemble(nn.Module):
         actions = self.output_normalization(actions)["action"]
         actions = actions.to(device)

-        obs_enc = observations if self.encoder is None else self.encoder(observations)
+        obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))

         inputs = torch.cat([obs_enc, actions], dim=-1)
         q_values = self.ensemble(inputs)  # [num_critics, B, 1]
@@ -441,9 +442,10 @@ class Policy(nn.Module):
     def forward(
         self,
         observations: torch.Tensor,
+        observation_features: torch.Tensor | None = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
-        obs_enc = observations if self.encoder is None else self.encoder(observations)
+        obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))

         # Get network outputs
         outputs = self.network(obs_enc)
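
The repeated expression in `CriticEnsemble.forward` and `Policy.forward` above encodes a three-way precedence: cached features win, then the encoder, then the raw observations. A self-contained sketch of that fallback, with illustrative names:

```python
from typing import Callable, Optional

import torch


def encode_observations(
    observations: torch.Tensor,
    encoder: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    observation_features: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # 1) Prefer features the caller precomputed (e.g. from a frozen encoder).
    if observation_features is not None:
        return observation_features
    # 2) Otherwise run the encoder if one exists.
    if encoder is not None:
        return encoder(observations)
    # 3) Fall back to the raw observations, which are already feature-shaped.
    return observations
```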