fix caching
This commit is contained in:
committed by
Michel Aractingi
parent
1efaf02df9
commit
8bcf41761d
@@ -775,7 +775,9 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
|
||||
params=policy.actor.parameters_to_optimize,
|
||||
lr=cfg.policy.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
|
||||
)
|
||||
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizer_grasp_critic = torch.optim.Adam(
|
||||
@@ -1024,12 +1026,8 @@ def get_observation_features(
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = (
|
||||
policy.actor.encoder(observations) if policy.actor.encoder is not None else None
|
||||
)
|
||||
next_observation_features = (
|
||||
policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
|
||||
)
|
||||
observation_features = policy.actor.encoder.get_image_features(observations)
|
||||
next_observation_features = policy.actor.encoder.get_image_features(next_observations)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
|
||||
Reference in New Issue
Block a user