Merge branch 'user/alexander-soare/train_pusht' into user/alexander-soare/multistep_policy_and_serial_env

This commit is contained in:
Alexander Soare
2024-03-14 16:06:21 +00:00
5 changed files with 241 additions and 19 deletions

View File

@@ -7,7 +7,7 @@ import torch
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
class DiffusionPolicy(AbstractPolicy):
@@ -38,7 +38,7 @@ class DiffusionPolicy(AbstractPolicy):
self.cfg = cfg
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
rgb_model = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg_rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg_obs_encoder,