backup wip
This commit is contained in:
@@ -7,7 +7,7 @@ import torch
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
|
||||
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
|
||||
|
||||
|
||||
class DiffusionPolicy(AbstractPolicy):
|
||||
@@ -38,6 +38,10 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
self.cfg = cfg
|
||||
|
||||
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
||||
rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
|
||||
if cfg_obs_encoder.crop_shape is not None:
|
||||
rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
|
||||
rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
|
||||
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
|
||||
obs_encoder = MultiImageObsEncoder(
|
||||
rgb_model=rgb_model,
|
||||
|
||||
Reference in New Issue
Block a user