[Port HIL-SERL] Add HF vision encoder option in SAC (#651)

Added support with custom pretrained vision encoder to the modeling sac implementation. Great job @ChorntonYoel !
This commit is contained in:
Yoel
2025-01-31 09:42:13 +01:00
committed by Michel Aractingi
parent 7c89bd1018
commit f1c8bfe01e
4 changed files with 123 additions and 47 deletions

View File

@@ -150,6 +150,10 @@ def get_image_transforms(
sharpness_min_max: tuple[float, float] | None = None,
max_num_transforms: int | None = None,
random_order: bool = False,
interpolation: str | None = None,
image_size: tuple[int, int] | None = None,
image_mean: list[float] | None = None,
image_std: list[float] | None = None,
):
def check_value(name, weight, min_max):
if min_max is not None:
@@ -170,6 +174,18 @@ def get_image_transforms(
weights = []
transforms = []
if image_size is not None:
interpolations = [interpolation.value for interpolation in v2.InterpolationMode]
if interpolation is None:
# Use BICUBIC as default interpolation
interpolation_mode = v2.InterpolationMode.BICUBIC
elif interpolation in interpolations:
interpolation_mode = v2.InterpolationMode(interpolation)
else:
raise ValueError("The interpolation passed is not supported")
# Weight for resizing is always 1
weights.append(1.0)
transforms.append(v2.Resize(size=(image_size[0], image_size[1]), interpolation=interpolation_mode))
if brightness_min_max is not None and brightness_weight > 0.0:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
@@ -185,6 +201,15 @@ def get_image_transforms(
if sharpness_min_max is not None and sharpness_weight > 0.0:
weights.append(sharpness_weight)
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
if image_mean is not None and image_std is not None:
# Weight for normalization is always 1
weights.append(1.0)
transforms.append(
v2.Normalize(
mean=image_mean,
std=image_std,
)
)
n_subset = len(transforms)
if max_num_transforms is not None: