backup wip

This commit is contained in:
Alexander Soare
2024-03-19 18:50:04 +00:00
parent ea17f4ce50
commit 896a11f60e
16 changed files with 169 additions and 138 deletions

View File

@@ -12,6 +12,17 @@ class AbstractPolicy(nn.Module, ABC):
documentation for more information.
"""
def __init__(self, n_action_steps: int | None):
"""
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
adds that dimension.
"""
super().__init__()
self.n_action_steps = n_action_steps
if n_action_steps is not None:
self._action_queue = deque([], maxlen=n_action_steps)
@abstractmethod
def update(self, replay_buffer, step):
"""One step of the policy's learning algorithm."""
@@ -24,10 +35,11 @@ class AbstractPolicy(nn.Module, ABC):
self.load_state_dict(d)
@abstractmethod
def select_action(self, observation) -> Tensor:
def select_actions(self, observation) -> Tensor:
"""Select an action (or trajectory of actions) based on an observation during rollout.
Should return a (batch_size, n_action_steps, *) tensor of actions.
If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
"""
def forward(self, *args, **kwargs) -> Tensor:
@@ -41,18 +53,14 @@ class AbstractPolicy(nn.Module, ABC):
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
the subclass doesn't have to.
This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
the action trajectory horizon and * is the action dimensions.
2. Prior to the `select_action` method being called, theres is an `n_action_steps` instance attribute defined.
2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined.
"""
n_action_steps_attr = "n_action_steps"
if not hasattr(self, n_action_steps_attr):
raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
if not hasattr(self, "_action_queue"):
self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
if self.n_action_steps is None:
return self.select_actions(*args, **kwargs)
if len(self._action_queue) == 0:
# Each element in the queue has shape (B, *).
self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))
self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
return self._action_queue.popleft()

View File

@@ -42,7 +42,7 @@ def kl_divergence(mu, logvar):
class ActionChunkingTransformerPolicy(AbstractPolicy):
def __init__(self, cfg, device, n_action_steps=1):
super().__init__()
super().__init__(n_action_steps)
self.cfg = cfg
self.n_action_steps = n_action_steps
self.device = device
@@ -147,7 +147,10 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
return loss
@torch.no_grad()
def select_action(self, observation, step_count):
def select_actions(self, observation, step_count):
if observation["image"].shape[0] != 1:
raise NotImplementedError("Batch size > 1 not handled")
# TODO(rcadene): remove unused step_count
del step_count

View File

@@ -34,7 +34,7 @@ class DiffusionPolicy(AbstractPolicy):
# parameters passed to step
**kwargs,
):
super().__init__()
super().__init__(n_action_steps)
self.cfg = cfg
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
@@ -44,7 +44,6 @@ class DiffusionPolicy(AbstractPolicy):
**cfg_obs_encoder,
)
self.n_action_steps = n_action_steps # needed for the parent class
self.diffusion = DiffusionUnetImagePolicy(
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
@@ -94,7 +93,7 @@ class DiffusionPolicy(AbstractPolicy):
)
@torch.no_grad()
def select_action(self, observation, step_count):
def select_actions(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count

View File

@@ -1,4 +1,7 @@
def make_policy(cfg):
if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPC

View File

@@ -90,7 +90,7 @@ class TDMPC(AbstractPolicy):
"""Implementation of TD-MPC learning + inference."""
def __init__(self, cfg, device):
super().__init__()
super().__init__(None)
self.action_dim = cfg.action_dim
self.cfg = cfg
@@ -125,7 +125,10 @@ class TDMPC(AbstractPolicy):
self.model_target.load_state_dict(d["model_target"])
@torch.no_grad()
def select_action(self, observation, step_count):
def select_actions(self, observation, step_count):
if observation["image"].shape[0] != 1:
raise NotImplementedError("Batch size > 1 not handled")
t0 = step_count.item() == 0
obs = {
@@ -133,7 +136,8 @@ class TDMPC(AbstractPolicy):
"rgb": observation["image"].contiguous(),
"state": observation["state"].contiguous(),
}
action = self.act(obs, t0=t0, step=self.step.item())
# Note: unsqueeze needed because `act` still uses non-batch logic.
action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
return action
@torch.no_grad()
@@ -144,7 +148,7 @@ class TDMPC(AbstractPolicy):
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)
else:
a = self.model.pi(z, self.cfg.min_std * self.model.training)
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
return a
@torch.no_grad()