forked from tangger/lerobot
Merge remote-tracking branch 'upstream/main' into unify_policy_api
This commit is contained in:
@@ -65,12 +65,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
|
||||
)
|
||||
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
|
||||
"""
|
||||
TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
configuration class is used.
|
||||
"""
|
||||
super().__init__()
|
||||
if getattr(cfg, "n_obs_steps", 1) != 1:
|
||||
if cfg is None:
|
||||
cfg = ActionChunkingTransformerConfig()
|
||||
if cfg.n_obs_steps != 1:
|
||||
raise ValueError(self._multiple_obs_steps_not_handled_msg)
|
||||
self.cfg = cfg
|
||||
|
||||
@@ -161,8 +165,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
if self.cfg.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
return self.select_action(self, batch)
|
||||
# def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
# return self.select_action(self, batch)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
|
||||
@@ -33,8 +33,6 @@ from lerobot.common.policies.utils import (
|
||||
populate_queues,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DiffusionPolicy(nn.Module):
|
||||
"""
|
||||
@@ -44,8 +42,17 @@ class DiffusionPolicy(nn.Module):
|
||||
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(self, cfg: DiffusionConfig, lr_scheduler_num_training_steps: int):
|
||||
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
|
||||
super().__init__()
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
configuration class is used.
|
||||
"""
|
||||
# TODO(alexander-soare): LR scheduler will be removed.
|
||||
assert lr_scheduler_num_training_steps > 0
|
||||
if cfg is None:
|
||||
cfg = DiffusionConfig()
|
||||
self.cfg = cfg
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
|
||||
Reference in New Issue
Block a user