Added possibility to cache the embedding of the images when the encoder choice is pretrained and frozen

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-02-18 08:28:13 +00:00
parent befa1fe9af
commit 8469d13681
3 changed files with 74 additions and 32 deletions

View File

@@ -148,7 +148,7 @@ class SACPolicy(
return actions
def critic_forward(
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, features: Optional[Tensor] = None
) -> Tensor:
"""Forward pass through a critic network ensemble
@@ -161,7 +161,7 @@ class SACPolicy(
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions)
q_values = critics(observations, actions, features)
return q_values
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
@@ -175,14 +175,14 @@ class SACPolicy(
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
def compute_loss_critic(self, observations, actions, rewards, next_observations, done, obs_features=None, next_obs_features=None) -> Tensor:
temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations)
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_obs_features)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True
observations=next_observations, actions=next_action_preds, use_target=True, features=next_obs_features
)
# subsample critics to prevent overfitting if use high UTD (update to date)
@@ -199,7 +199,7 @@ class SACPolicy(
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False)
q_preds = self.critic_forward(observations, actions, use_target=False, features=obs_features)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
@@ -214,18 +214,18 @@ class SACPolicy(
).sum()
return critics_loss
def compute_loss_temperature(self, observations) -> Tensor:
def compute_loss_temperature(self, observations, obs_features=None) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations)
_, log_probs, _ = self.actor(observations, obs_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(self, observations) -> Tensor:
def compute_loss_actor(self, observations, obs_features=None) -> Tensor:
temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations)
actions_pi, log_probs, _ = self.actor(observations, obs_features)
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
min_q_preds = q_preds.min(dim=0)[0]
@@ -360,18 +360,19 @@ class CriticEnsemble(nn.Module):
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
features: Optional[torch.Tensor] = None,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
# Move observations to the correct device
observations = {k: v.to(device) for k, v in observations.items()}
# NOTE: We normalize actions it helps for sample efficiency
actions: dict[str, torch.tensor] = {"action": actions}
# NOTE: Normalization layer took dict in input and outputs a dict that why
# Normalize actions for sample efficiency
actions: dict[str, torch.Tensor] = {"action": actions}
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = observations if self.encoder is None else self.encoder(observations)
# Use precomputed features if provided; otherwise, encode observations.
obs_enc = features if features is not None else (observations if self.encoder is None else self.encoder(observations))
inputs = torch.cat([obs_enc, actions], dim=-1)
list_q_values = []
for network, output_layer in zip(self.network_list, self.output_layers, strict=False):
@@ -435,19 +436,20 @@ class Policy(nn.Module):
def forward(
self,
observations: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = observations if self.encoder is None else self.encoder(observations)
features: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Use precomputed features if provided; otherwise compute encoder representations.
obs_enc = features if features is not None else (observations if self.encoder is None else self.encoder(observations))
# Get network outputs
outputs = self.network(obs_enc)
means = self.mean_layer(outputs)
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
@@ -455,8 +457,8 @@ class Policy(nn.Module):
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
log_std = self.fixed_std.expand_as(means)
# uses tanh activation function to squash the action to be in the range of [-1, 1]
# Get distribution and sample actions
normal = torch.distributions.Normal(means, torch.exp(log_std))
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
log_probs = normal.log_prob(x_t) # Base log probability before Tanh

View File

@@ -52,10 +52,10 @@ policy:
n_action_steps: 1
shared_encoder: true
vision_encoder_name: null
# vision_encoder_name: "helper2424/resnet10"
# freeze_vision_encoder: true
freeze_vision_encoder: false
# vision_encoder_name: null
vision_encoder_name: "helper2424/resnet10"
freeze_vision_encoder: true
# freeze_vision_encoder: false
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]

View File

@@ -384,6 +384,21 @@ def add_actor_information_and_train(
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
# Precompute encoder features from the frozen vision encoder if enabled
obs_features, next_obs_features = None, None
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
with torch.no_grad():
obs_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_obs_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -391,6 +406,8 @@ def add_actor_information_and_train(
rewards=rewards,
next_observations=next_observations,
done=done,
obs_features=obs_features, # pass precomputed features
next_obs_features=next_obs_features, # for target computation
)
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -412,6 +429,21 @@ def add_actor_information_and_train(
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
# Precompute encoder features from the frozen vision encoder if enabled
obs_features, next_obs_features = None, None
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
with torch.no_grad():
obs_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_obs_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -419,6 +451,8 @@ def add_actor_information_and_train(
rewards=rewards,
next_observations=next_observations,
done=done,
obs_features=obs_features, # pass precomputed features
next_obs_features=next_obs_features, # for target computation
)
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -430,7 +464,10 @@ def add_actor_information_and_train(
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
with policy_lock:
loss_actor = policy.compute_loss_actor(observations=observations)
loss_actor = policy.compute_loss_actor(
observations=observations,
obs_features=obs_features, # reuse precomputed features here
)
optimizers["actor"].zero_grad()
loss_actor.backward()
@@ -438,7 +475,10 @@ def add_actor_information_and_train(
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(observations=observations)
loss_temperature = policy.compute_loss_temperature(
observations=observations,
obs_features=obs_features, # and for temperature loss as well
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
@@ -582,7 +622,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
# compile policy
policy = torch.compile(policy)
# policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)