Added the possibility to cache image embeddings when the chosen encoder is pretrained and frozen
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
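The idea: when the vision encoder is pretrained and frozen, its output for a given batch does not change within an optimization step, so the image embeddings can be computed once under torch.no_grad() and reused by the critic, actor, and temperature losses instead of re-running the encoder for each. A minimal sketch of the pattern (the helper name precompute_features is hypothetical, not part of the commit):

    import torch

    def precompute_features(encoder, observations):
        # The encoder is frozen, so no gradient graph is needed and the
        # embeddings can be shared by every loss computed on this batch.
        if encoder is None:
            return None
        with torch.no_grad():
            return encoder(observations)

The diff below threads an optional features argument through SACPolicy, CriticEnsemble, and Policy so such cached embeddings can be passed down.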
@@ -148,7 +148,7 @@ class SACPolicy(
         return actions

     def critic_forward(
-        self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
+        self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, features: Optional[Tensor] = None
     ) -> Tensor:
         """Forward pass through a critic network ensemble

@@ -161,7 +161,7 @@ class SACPolicy(
             Tensor of Q-values from all critics
         """
         critics = self.critic_target if use_target else self.critic_ensemble
-        q_values = critics(observations, actions)
+        q_values = critics(observations, actions, features)
         return q_values

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
@@ -175,14 +175,14 @@ class SACPolicy(
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )

-    def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
+    def compute_loss_critic(self, observations, actions, rewards, next_observations, done, obs_features=None, next_obs_features=None) -> Tensor:
         temperature = self.log_alpha.exp().item()
         with torch.no_grad():
-            next_action_preds, next_log_probs, _ = self.actor(next_observations)
+            next_action_preds, next_log_probs, _ = self.actor(next_observations, next_obs_features)

             # 2- compute q targets
             q_targets = self.critic_forward(
-                observations=next_observations, actions=next_action_preds, use_target=True
+                observations=next_observations, actions=next_action_preds, use_target=True, features=next_obs_features
             )

             # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
@@ -199,7 +199,7 @@ class SACPolicy(
             td_target = rewards + (1 - done) * self.config.discount * min_q

         # 3- compute predicted qs
-        q_preds = self.critic_forward(observations, actions, use_target=False)
+        q_preds = self.critic_forward(observations, actions, use_target=False, features=obs_features)

         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
@@ -214,18 +214,18 @@ class SACPolicy(
         ).sum()
         return critics_loss

-    def compute_loss_temperature(self, observations) -> Tensor:
+    def compute_loss_temperature(self, observations, obs_features=None) -> Tensor:
         """Compute the temperature loss"""
         # calculate temperature loss
         with torch.no_grad():
-            _, log_probs, _ = self.actor(observations)
+            _, log_probs, _ = self.actor(observations, obs_features)
         temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
         return temperature_loss

-    def compute_loss_actor(self, observations) -> Tensor:
+    def compute_loss_actor(self, observations, obs_features=None) -> Tensor:
         temperature = self.log_alpha.exp().item()

-        actions_pi, log_probs, _ = self.actor(observations)
+        actions_pi, log_probs, _ = self.actor(observations, obs_features)

         q_preds = self.critic_forward(observations, actions_pi, use_target=False)
         min_q_preds = q_preds.min(dim=0)[0]
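All three loss methods now accept optional precomputed features and forward them to self.actor and critic_forward; with the default (None) the behavior is unchanged. Note that compute_loss_actor still calls critic_forward without features, so the critic re-encodes observations on that path. A sketch of how one batch's cached embeddings serve every loss (simplified from the training loop later in this commit):

    with torch.no_grad():
        obs_features = policy.actor.encoder(observations)
        next_obs_features = policy.actor.encoder(next_observations)

    loss_critic = policy.compute_loss_critic(
        observations, actions, rewards, next_observations, done,
        obs_features=obs_features, next_obs_features=next_obs_features,
    )
    loss_actor = policy.compute_loss_actor(observations, obs_features=obs_features)
    loss_temperature = policy.compute_loss_temperature(observations, obs_features=obs_features)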
@@ -360,17 +360,18 @@ class CriticEnsemble(nn.Module):
         self,
         observations: dict[str, torch.Tensor],
         actions: torch.Tensor,
+        features: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         device = get_device_from_parameters(self)
-        # Move each tensor in observations to device
+        # Move observations to the correct device
         observations = {k: v.to(device) for k, v in observations.items()}
-        # NOTE: We normalize actions it helps for sample efficiency
-        actions: dict[str, torch.tensor] = {"action": actions}
-        # NOTE: Normalization layer took dict in input and outputs a dict that why
+        # Normalize actions for sample efficiency
+        actions: dict[str, torch.Tensor] = {"action": actions}
         actions = self.output_normalization(actions)["action"]
         actions = actions.to(device)

-        obs_enc = observations if self.encoder is None else self.encoder(observations)
+        # Use precomputed features if provided; otherwise, encode observations.
+        obs_enc = features if features is not None else (observations if self.encoder is None else self.encoder(observations))

         inputs = torch.cat([obs_enc, actions], dim=-1)
         list_q_values = []
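The fallthrough order is: use features if given, else run the encoder, else pass raw observations through. Because the encoder is frozen, cached and freshly computed Q-values should match; a quick sanity check one could run (assuming a critic_ensemble instance that has an encoder):

    with torch.no_grad():
        cached = critic_ensemble.encoder(observations)
    q_cached = critic_ensemble(observations, actions, features=cached)
    q_fresh = critic_ensemble(observations, actions)
    assert torch.allclose(q_cached, q_fresh, atol=1e-6)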
@@ -435,9 +436,10 @@ class Policy(nn.Module):
     def forward(
         self,
         observations: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Encode observations if encoder exists
-        obs_enc = observations if self.encoder is None else self.encoder(observations)
+        features: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # Use precomputed features if provided; otherwise compute encoder representations.
+        obs_enc = features if features is not None else (observations if self.encoder is None else self.encoder(observations))

         # Get network outputs
         outputs = self.network(obs_enc)
@@ -456,7 +458,7 @@ class Policy(nn.Module):
         else:
             log_std = self.fixed_std.expand_as(means)

-        # uses tanh activation function to squash the action to be in the range of [-1, 1]
+        # Get distribution and sample actions
         normal = torch.distributions.Normal(means, torch.exp(log_std))
         x_t = normal.rsample()  # Reparameterization trick (mean + std * N(0,1))
         log_probs = normal.log_prob(x_t)  # Base log probability before Tanh
@@ -52,10 +52,10 @@ policy:
   n_action_steps: 1

   shared_encoder: true
-  vision_encoder_name: null
-  # vision_encoder_name: "helper2424/resnet10"
-  # freeze_vision_encoder: true
-  freeze_vision_encoder: false
+  # vision_encoder_name: null
+  vision_encoder_name: "helper2424/resnet10"
+  freeze_vision_encoder: true
+  # freeze_vision_encoder: false
   input_shapes:
     # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
     observation.state: ["${env.state_dim}"]
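The config now selects the pretrained ResNet-10 encoder and freezes it. Both keys matter: the training loop below only caches embeddings when a vision encoder is configured and frozen, i.e. when this condition holds:

    use_cached_features = (
        policy.config.vision_encoder_name is not None
        and policy.config.freeze_vision_encoder
    )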
@@ -384,6 +384,21 @@ def add_actor_information_and_train(
         done = batch["done"]
         check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

+        # Precompute encoder features from the frozen vision encoder if enabled
+        obs_features, next_obs_features = None, None
+        if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
+            with torch.no_grad():
+                obs_features = (
+                    policy.actor.encoder(observations)
+                    if policy.actor.encoder is not None
+                    else None
+                )
+                next_obs_features = (
+                    policy.actor.encoder(next_observations)
+                    if policy.actor.encoder is not None
+                    else None
+                )
+
         with policy_lock:
             loss_critic = policy.compute_loss_critic(
                 observations=observations,
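Running the encoder under torch.no_grad() is safe here precisely because the encoder is frozen (its parameters receive no gradients), and it avoids building an autograd graph for the encoder forward pass. The same block is repeated verbatim for the second sampled batch below; a small helper like this hypothetical one (not part of the commit) could factor it out:

    def get_frozen_features(policy, observations, next_observations):
        # Returns (None, None) unless a frozen pretrained vision encoder
        # is configured and the actor actually has an encoder.
        if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
            return None, None
        if policy.actor.encoder is None:
            return None, None
        with torch.no_grad():
            return (
                policy.actor.encoder(observations),
                policy.actor.encoder(next_observations),
            )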
@@ -391,6 +406,8 @@ def add_actor_information_and_train(
                 rewards=rewards,
                 next_observations=next_observations,
                 done=done,
+                obs_features=obs_features,  # pass precomputed features
+                next_obs_features=next_obs_features,  # for target computation
             )
         optimizers["critic"].zero_grad()
         loss_critic.backward()
@@ -412,6 +429,21 @@ def add_actor_information_and_train(

         check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

+        # Precompute encoder features from the frozen vision encoder if enabled
+        obs_features, next_obs_features = None, None
+        if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
+            with torch.no_grad():
+                obs_features = (
+                    policy.actor.encoder(observations)
+                    if policy.actor.encoder is not None
+                    else None
+                )
+                next_obs_features = (
+                    policy.actor.encoder(next_observations)
+                    if policy.actor.encoder is not None
+                    else None
+                )
+
         with policy_lock:
             loss_critic = policy.compute_loss_critic(
                 observations=observations,
@@ -419,6 +451,8 @@ def add_actor_information_and_train(
                 rewards=rewards,
                 next_observations=next_observations,
                 done=done,
+                obs_features=obs_features,  # pass precomputed features
+                next_obs_features=next_obs_features,  # for target computation
             )
         optimizers["critic"].zero_grad()
         loss_critic.backward()
@@ -430,7 +464,10 @@ def add_actor_information_and_train(
         if optimization_step % cfg.training.policy_update_freq == 0:
             for _ in range(cfg.training.policy_update_freq):
                 with policy_lock:
-                    loss_actor = policy.compute_loss_actor(observations=observations)
+                    loss_actor = policy.compute_loss_actor(
+                        observations=observations,
+                        obs_features=obs_features,  # reuse precomputed features here
+                    )

                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
@@ -438,7 +475,10 @@ def add_actor_information_and_train(

                 training_infos["loss_actor"] = loss_actor.item()

-                loss_temperature = policy.compute_loss_temperature(observations=observations)
+                loss_temperature = policy.compute_loss_temperature(
+                    observations=observations,
+                    obs_features=obs_features,  # and for temperature loss as well
+                )
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
                 optimizers["temperature"].step()
@@ -582,7 +622,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
         pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )
     # compile policy
-    policy = torch.compile(policy)
+    # policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)

     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)