forked from tangger/lerobot
committed by Michel Aractingi
parent 9e5f254db0
commit 2f7339b410
@@ -772,11 +772,10 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     """
     optimizer_actor = torch.optim.Adam(
         # NOTE: Handle the case of a shared encoder, where the encoder weights are not optimized with the gradient of the actor
         params=[
             p
             for n, p in policy.actor.named_parameters()
-            if not n.startswith("encoder") or not policy.config.shared_encoder
+            if not policy.config.shared_encoder or not n.startswith("encoder")
         ],
         lr=cfg.policy.actor_lr,
     )
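
Note: the reordered condition is logically equivalent (the two operands commute under `or`), but the new form reads in the intended order: when the encoder is not shared, keep every actor parameter; when it is shared, skip anything under `encoder` so the actor's gradients never touch the shared weights. A minimal self-contained sketch of the same filtering pattern follows; `ToyActor` and the bare `shared_encoder` flag are illustrative stand-ins, not lerobot's actual classes:

import torch
import torch.nn as nn

class ToyActor(nn.Module):
    """Illustrative stand-in for policy.actor: an encoder plus a head."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)  # parameter names start with "encoder."
        self.head = nn.Linear(8, 2)

actor = ToyActor()
shared_encoder = True  # stand-in for policy.config.shared_encoder

# Same predicate as the diff: keep a parameter unless the encoder is shared
# AND the parameter belongs to the encoder submodule.
actor_params = [
    p
    for n, p in actor.named_parameters()
    if not shared_encoder or not n.startswith("encoder")
]
optimizer_actor = torch.optim.Adam(params=actor_params, lr=3e-4)

kept = [n for n, _ in actor.named_parameters()
        if not shared_encoder or not n.startswith("encoder")]
print(kept)  # ['head.weight', 'head.bias'] -- encoder.* is left to the critic's optimizer

With shared_encoder = False the comprehension keeps all four parameters, encoder included.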
@@ -1028,8 +1027,8 @@ def get_observation_features(
         return None, None

     with torch.no_grad():
-        observation_features = policy.actor.encoder.get_base_image_features(observations, normalize=True)
-        next_observation_features = policy.actor.encoder.get_base_image_features(
+        observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
+        next_observation_features = policy.actor.encoder.get_cached_image_features(
             next_observations, normalize=True
         )
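
For context on the renamed call: when the vision encoder is frozen, its image features can be computed once per batch under torch.no_grad() and reused downstream, which is what the "cached" name signals. A rough sketch of that pattern, where ToyEncoder, the "pixels" key, and the feature shapes are hypothetical stand-ins for lerobot's encoder API:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyEncoder(nn.Module):
    """Hypothetical stand-in for the frozen vision encoder."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def get_cached_image_features(self, observations, normalize=True):
        # Image-backbone forward only; computed once per batch and shared
        # downstream, so no gradient needs to flow through it.
        feats = self.backbone(observations["pixels"]).flatten(1)
        return F.normalize(feats, dim=-1) if normalize else feats

encoder = ToyEncoder()
observations = {"pixels": torch.randn(4, 3, 64, 64)}
next_observations = {"pixels": torch.randn(4, 3, 64, 64)}

# Mirrors the hunk above: features are precomputed without gradients and
# handed to the actor/critic forward passes instead of raw pixels.
with torch.no_grad():
    observation_features = encoder.get_cached_image_features(observations, normalize=True)
    next_observation_features = encoder.get_cached_image_features(
        next_observations, normalize=True
    )
print(observation_features.shape)  # torch.Size([4, 65536])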