ready for review
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
import copy
|
import copy
|
||||||
from typing import Dict, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import timm
|
import timm
|
||||||
import torch
|
import torch
|
||||||
@@ -46,7 +46,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
|
|||||||
share_rgb_model: bool = False,
|
share_rgb_model: bool = False,
|
||||||
# renormalize rgb input with imagenet normalization
|
# renormalize rgb input with imagenet normalization
|
||||||
# assuming input in [0,1]
|
# assuming input in [0,1]
|
||||||
imagenet_norm: bool = False,
|
norm_mean_std: Optional[tuple[float, float]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Assumes rgb input: B,C,H,W
|
Assumes rgb input: B,C,H,W
|
||||||
@@ -120,13 +120,9 @@ class MultiImageObsEncoder(ModuleAttrMixin):
|
|||||||
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
|
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
|
||||||
# configure normalizer
|
# configure normalizer
|
||||||
this_normalizer = nn.Identity()
|
this_normalizer = nn.Identity()
|
||||||
if imagenet_norm:
|
if norm_mean_std is not None:
|
||||||
# TODO(rcadene): move normalizer to dataset and env
|
|
||||||
this_normalizer = torchvision.transforms.Normalize(
|
this_normalizer = torchvision.transforms.Normalize(
|
||||||
# Note: This matches the normalization in the original impl. for PushT Image. This may not be
|
mean=norm_mean_std[0], std=norm_mean_std[1]
|
||||||
# the case for other tasks.
|
|
||||||
mean=[127.5, 127.5, 127.5],
|
|
||||||
std=[127.5, 127.5, 127.5],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
|
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ obs_encoder:
|
|||||||
# random_crop: True
|
# random_crop: True
|
||||||
use_group_norm: True
|
use_group_norm: True
|
||||||
share_rgb_model: False
|
share_rgb_model: False
|
||||||
imagenet_norm: True
|
norm_mean_std: [0.5, 0.5] # [mean, std]; for PushT the original implementation normalizes inputs from [0, 1] to [-1, 1] (this may not be the case for robomimic envs)
|
||||||
|
|
||||||
rgb_model:
|
rgb_model:
|
||||||
model_name: resnet18
|
model_name: resnet18
|
||||||
|
|||||||
Reference in New Issue
Block a user