Refactor TD-MPC (#103)
Co-authored-by: Cadene <re.cadene@gmail.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
@@ -3,7 +3,7 @@ from dataclasses import dataclass, field
|
||||
|
||||
@dataclass
|
||||
class DiffusionConfig:
|
||||
"""Configuration class for Diffusion Policy.
|
||||
"""Configuration class for DiffusionPolicy.
|
||||
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
@@ -25,11 +25,12 @@ class DiffusionConfig:
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two availables
|
||||
modes are "mean_std" which substracts the mean and divide by the standard
|
||||
deviation and "min_max" which rescale in a [-1, 1] range.
|
||||
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
@@ -70,13 +71,13 @@ class DiffusionConfig:
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
input_shapes: dict[str, list[str]] = field(
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [2],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[str]] = field(
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
|
||||
@@ -43,15 +43,16 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
def __init__(
|
||||
self,
|
||||
config: DiffusionConfig | None = None,
|
||||
dataset_stats=None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
# TODO(alexander-soare): LR scheduler will be removed.
|
||||
if config is None:
|
||||
config = DiffusionConfig()
|
||||
self.config = config
|
||||
@@ -88,7 +89,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
}
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method handles caching a history of observations and an action trajectory generated by the
|
||||
@@ -136,7 +137,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
Reference in New Issue
Block a user