fix caching

This commit is contained in:
AdilZouitine
2025-04-04 14:29:38 +00:00
committed by Michel Aractingi
parent 1efaf02df9
commit 8bcf41761d
2 changed files with 117 additions and 119 deletions

View File

@@ -775,7 +775,9 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
         params=policy.actor.parameters_to_optimize,
         lr=cfg.policy.actor_lr,
     )
-    optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
+    optimizer_critic = torch.optim.Adam(
+        params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
+    )
     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
@@ -1024,12 +1026,8 @@ def get_observation_features(
         return None, None
     with torch.no_grad():
-        observation_features = (
-            policy.actor.encoder(observations) if policy.actor.encoder is not None else None
-        )
-        next_observation_features = (
-            policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
-        )
+        observation_features = policy.actor.encoder.get_image_features(observations)
+        next_observation_features = policy.actor.encoder.get_image_features(next_observations)
     return observation_features, next_observation_features