Merge remote-tracking branch 'upstream/main' into fix_pusht_diffusion

This commit is contained in:
Alexander Soare
2024-03-21 10:20:52 +00:00
15 changed files with 276 additions and 46 deletions

View File

@@ -1,11 +1,10 @@
from abc import ABC, abstractmethod
from collections import deque
import torch
from torch import Tensor, nn
class AbstractPolicy(nn.Module, ABC):
class AbstractPolicy(nn.Module):
"""Base policy which all policies should be derived from.
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
@@ -22,9 +21,9 @@ class AbstractPolicy(nn.Module, ABC):
self.n_action_steps = n_action_steps
self.clear_action_queue()
@abstractmethod
def update(self, replay_buffer, step):
"""One step of the policy's learning algorithm."""
raise NotImplementedError("Abstract method")
def save(self, fp):
torch.save(self.state_dict(), fp)
@@ -33,13 +32,13 @@ class AbstractPolicy(nn.Module, ABC):
d = torch.load(fp)
self.load_state_dict(d)
@abstractmethod
def select_actions(self, observation) -> Tensor:
"""Select an action (or trajectory of actions) based on an observation during rollout.
If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
"""
raise NotImplementedError("Abstract method")
def clear_action_queue(self):
"""This should be called whenever the environment is reset."""

View File

@@ -7,6 +7,7 @@ import torchvision.transforms as transforms
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.act.detr_vae import build
from lerobot.common.utils import get_safe_torch_device
def build_act_model_and_optimizer(cfg):
@@ -45,7 +46,7 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
super().__init__(n_action_steps)
self.cfg = cfg
self.n_action_steps = n_action_steps
self.device = device
self.device = get_safe_torch_device(device)
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
self.kl_weight = self.cfg.kl_weight
logging.info(f"KL Weight {self.kl_weight}")

View File

@@ -9,6 +9,7 @@ from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
from lerobot.common.utils import get_safe_torch_device
class DiffusionPolicy(AbstractPolicy):
@@ -66,9 +67,8 @@ class DiffusionPolicy(AbstractPolicy):
**kwargs,
)
self.device = torch.device(cfg_device)
if torch.cuda.is_available() and cfg_device == "cuda":
self.diffusion.cuda()
self.device = get_safe_torch_device(cfg_device)
self.diffusion.to(self.device)
self.ema_diffusion = None
self.ema = None

View File

@@ -10,6 +10,7 @@ import torch.nn as nn
import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.utils import get_safe_torch_device
FIRST_FRAME = 0
@@ -94,9 +95,10 @@ class TDMPC(AbstractPolicy):
self.action_dim = cfg.action_dim
self.cfg = cfg
self.device = torch.device(device)
self.device = get_safe_torch_device(device)
self.std = h.linear_schedule(cfg.std_schedule, 0)
self.model = TOLD(cfg).cuda() if torch.cuda.is_available() and device == "cuda" else TOLD(cfg)
self.model = TOLD(cfg)
self.model.to(self.device)
self.model_target = deepcopy(self.model)
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)