forked from tangger/lerobot
backup wip
This commit is contained in:
@@ -12,6 +12,17 @@ class AbstractPolicy(nn.Module, ABC):
|
||||
documentation for more information.
|
||||
"""
|
||||
|
||||
def __init__(self, n_action_steps: int | None):
|
||||
"""
|
||||
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
|
||||
action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
|
||||
adds that dimension.
|
||||
"""
|
||||
super().__init__()
|
||||
self.n_action_steps = n_action_steps
|
||||
if n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=n_action_steps)
|
||||
|
||||
@abstractmethod
|
||||
def update(self, replay_buffer, step):
|
||||
"""One step of the policy's learning algorithm."""
|
||||
@@ -24,10 +35,11 @@ class AbstractPolicy(nn.Module, ABC):
|
||||
self.load_state_dict(d)
|
||||
|
||||
@abstractmethod
|
||||
def select_action(self, observation) -> Tensor:
|
||||
def select_actions(self, observation) -> Tensor:
|
||||
"""Select an action (or trajectory of actions) based on an observation during rollout.
|
||||
|
||||
Should return a (batch_size, n_action_steps, *) tensor of actions.
|
||||
If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
|
||||
actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
|
||||
"""
|
||||
|
||||
def forward(self, *args, **kwargs) -> Tensor:
|
||||
@@ -41,18 +53,14 @@ class AbstractPolicy(nn.Module, ABC):
|
||||
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
|
||||
the subclass doesn't have to.
|
||||
|
||||
This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
|
||||
1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
|
||||
This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
|
||||
1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
|
||||
the action trajectory horizon and * is the action dimensions.
|
||||
2. Prior to the `select_action` method being called, theres is an `n_action_steps` instance attribute defined.
|
||||
2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined.
|
||||
"""
|
||||
n_action_steps_attr = "n_action_steps"
|
||||
if not hasattr(self, n_action_steps_attr):
|
||||
raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
|
||||
if not hasattr(self, "_action_queue"):
|
||||
self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
|
||||
if self.n_action_steps is None:
|
||||
return self.select_actions(*args, **kwargs)
|
||||
if len(self._action_queue) == 0:
|
||||
# Each element in the queue has shape (B, *).
|
||||
self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))
|
||||
|
||||
self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@@ -42,7 +42,7 @@ def kl_divergence(mu, logvar):
|
||||
|
||||
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
super().__init__()
|
||||
super().__init__(n_action_steps)
|
||||
self.cfg = cfg
|
||||
self.n_action_steps = n_action_steps
|
||||
self.device = device
|
||||
@@ -147,7 +147,10 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, observation, step_count):
|
||||
def select_actions(self, observation, step_count):
|
||||
if observation["image"].shape[0] != 1:
|
||||
raise NotImplementedError("Batch size > 1 not handled")
|
||||
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
# parameters passed to step
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(n_action_steps)
|
||||
self.cfg = cfg
|
||||
|
||||
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
||||
@@ -44,7 +44,6 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
**cfg_obs_encoder,
|
||||
)
|
||||
|
||||
self.n_action_steps = n_action_steps # needed for the parent class
|
||||
self.diffusion = DiffusionUnetImagePolicy(
|
||||
shape_meta=shape_meta,
|
||||
noise_scheduler=noise_scheduler,
|
||||
@@ -94,7 +93,7 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, observation, step_count):
|
||||
def select_actions(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
def make_policy(cfg):
|
||||
if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
|
||||
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPC
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ class TDMPC(AbstractPolicy):
|
||||
"""Implementation of TD-MPC learning + inference."""
|
||||
|
||||
def __init__(self, cfg, device):
|
||||
super().__init__()
|
||||
super().__init__(None)
|
||||
self.action_dim = cfg.action_dim
|
||||
|
||||
self.cfg = cfg
|
||||
@@ -125,7 +125,10 @@ class TDMPC(AbstractPolicy):
|
||||
self.model_target.load_state_dict(d["model_target"])
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, observation, step_count):
|
||||
def select_actions(self, observation, step_count):
|
||||
if observation["image"].shape[0] != 1:
|
||||
raise NotImplementedError("Batch size > 1 not handled")
|
||||
|
||||
t0 = step_count.item() == 0
|
||||
|
||||
obs = {
|
||||
@@ -133,7 +136,8 @@ class TDMPC(AbstractPolicy):
|
||||
"rgb": observation["image"].contiguous(),
|
||||
"state": observation["state"].contiguous(),
|
||||
}
|
||||
action = self.act(obs, t0=t0, step=self.step.item())
|
||||
# Note: unsqueeze needed because `act` still uses non-batch logic.
|
||||
action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
|
||||
return action
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -144,7 +148,7 @@ class TDMPC(AbstractPolicy):
|
||||
if self.cfg.mpc:
|
||||
a = self.plan(z, t0=t0, step=step)
|
||||
else:
|
||||
a = self.model.pi(z, self.cfg.min_std * self.model.training)
|
||||
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
|
||||
return a
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
Reference in New Issue
Block a user