Remove warnings (#111)

- Replace `use_pretrained_backbone` with `pretrained_backbone_weights`
- Bump diffusers' minimum version `0.26.3` -> `0.27.2`
- Add ignore flags in CI's pytest
- Change Box observation spaces in simulation environments
- Set `version_base="1.2"` in Hydra initializations
- Bump einops' minimum version `0.7.0` -> `0.8.0`
This commit is contained in:
Simon Alibert
2024-04-29 00:31:33 +02:00
committed by GitHub
parent 55dc9f7f51
commit 791506dfb8
14 changed files with 52 additions and 36 deletions

View File

@@ -33,8 +33,8 @@ class ActionChunkingTransformerConfig:
deviation and "min_max" which rescale in a [-1, 1] range.
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
torchvision.
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
`None` means no pretrained weights.
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
convolution.
pre_norm: Whether to use "pre-norm" in the transformer blocks.
@@ -90,7 +90,7 @@ class ActionChunkingTransformerConfig:
# Architecture.
# Vision backbone.
vision_backbone: str = "resnet18"
use_pretrained_backbone: bool = True
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
replace_final_stride_with_dilation: int = False
# Transformer layers.
pre_norm: bool = False

View File

@@ -104,7 +104,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Backbone for image feature extraction.
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
pretrained=cfg.use_pretrained_backbone,
weights=cfg.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d,
)
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature

View File

@@ -35,8 +35,8 @@ class DiffusionConfig:
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
torchvision.
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
`None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
@@ -96,7 +96,7 @@ class DiffusionConfig:
vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] | None = (84, 84)
crop_is_random: bool = True
use_pretrained_backbone: bool = False
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
spatial_softmax_num_keypoints: int = 32
# Unet.

View File

@@ -378,13 +378,13 @@ class _RgbEncoder(nn.Module):
# Set up backbone.
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
pretrained=cfg.use_pretrained_backbone
weights=cfg.pretrained_backbone_weights
)
# Note: This assumes that the layer4 feature map is children()[-3]
# TODO(alexander-soare): Use a safer alternative.
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if cfg.use_group_norm:
if cfg.use_pretrained_backbone:
if cfg.pretrained_backbone_weights:
raise ValueError(
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
)

View File

@@ -0,0 +1,12 @@
import warnings
import imageio
def write_video(video_path, stacked_frames, fps):
# Filter out DeprecationWarnings raised from pkg_resources
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
)
imageio.mimsave(video_path, stacked_frames, fps=fps)

View File

@@ -92,7 +92,8 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
hydra.core.global_hydra.GlobalHydra.instance().clear()
# Hydra needs a path relative to this file.
hydra.initialize(
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent))
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)),
version_base="1.2",
)
cfg = hydra.compose(Path(config_path).stem, overrides)
return cfg