revision 1

This commit is contained in:
Alexander Soare
2024-04-16 17:15:51 +01:00
parent 43a614c173
commit a9496fde39
4 changed files with 21 additions and 15 deletions

View File

@@ -65,12 +65,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
)
def __init__(self, cfg: ActionChunkingTransformerConfig):
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
"""
TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
"""
super().__init__()
if getattr(cfg, "n_obs_steps", 1) != 1:
if cfg is None:
cfg = ActionChunkingTransformerConfig()
if cfg.n_obs_steps != 1:
raise ValueError(self._multiple_obs_steps_not_handled_msg)
self.cfg = cfg

View File

@@ -33,8 +33,6 @@ from lerobot.common.policies.utils import (
populate_queues,
)
logger = logging.getLogger(__name__)
class DiffusionPolicy(nn.Module):
"""
@@ -44,8 +42,17 @@ class DiffusionPolicy(nn.Module):
name = "diffusion"
def __init__(self, cfg: DiffusionConfig, lr_scheduler_num_training_steps: int):
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
super().__init__()
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
"""
# TODO(alexander-soare): LR scheduler will be removed.
assert lr_scheduler_num_training_steps > 0
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
# queues are populated during rollout of the policy, they contain the n latest observations and actions