wip: still needs batch logic for act and tdmp

This commit is contained in:
Alexander Soare
2024-03-14 15:22:55 +00:00
parent 8c56770318
commit ba91976944
11 changed files with 240 additions and 100 deletions

View File

@@ -3,14 +3,14 @@ import time
import hydra
import torch
import torch.nn as nn
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
class DiffusionPolicy(nn.Module):
class DiffusionPolicy(AbstractPolicy):
def __init__(
self,
cfg,
@@ -44,6 +44,7 @@ class DiffusionPolicy(nn.Module):
**cfg_obs_encoder,
)
self.n_action_steps = n_action_steps # needed for the parent class
self.diffusion = DiffusionUnetImagePolicy(
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
@@ -93,21 +94,16 @@ class DiffusionPolicy(nn.Module):
)
@torch.no_grad()
def forward(self, observation, step_count):
def select_action(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count
# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
obs_dict = {
"image": observation["image"],
"agent_pos": observation["state"],
}
out = self.diffusion.predict_action(obs_dict)
action = out["action"].squeeze(0)
action = out["action"]
return action
def update(self, replay_buffer, step):