[HIL-SERL] Update CI to allow installation of prerelease versions for lerobot (#1018)
Co-authored-by: imstevenpmwork <steven.palma@huggingface.co>
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -101,7 +101,7 @@ jobs:
|
|||||||
python-version: "3.10"
|
python-version: "3.10"
|
||||||
|
|
||||||
- name: Install lerobot
|
- name: Install lerobot
|
||||||
run: uv sync --extra "test" --prerelease=allow
|
run: uv sync --extra "test"
|
||||||
|
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -318,7 +318,7 @@ class LeRobotDatasetMetadata:
|
|||||||
obj.root.mkdir(parents=True, exist_ok=False)
|
obj.root.mkdir(parents=True, exist_ok=False)
|
||||||
|
|
||||||
if robot is not None:
|
if robot is not None:
|
||||||
features = {**(features or {}), **get_features_from_robot(robot)}
|
features = get_features_from_robot(robot, use_videos)
|
||||||
robot_type = robot.robot_type
|
robot_type = robot.robot_type
|
||||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
@@ -821,9 +821,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
if self.features[key]["dtype"] in ["image", "video"]:
|
if self.features[key]["dtype"] in ["image", "video"]:
|
||||||
img_path = self._get_image_file_path(
|
img_path = self._get_image_file_path(
|
||||||
episode_index=self.episode_buffer["episode_index"],
|
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||||
image_key=key,
|
|
||||||
frame_index=frame_index,
|
|
||||||
)
|
)
|
||||||
if frame_index == 0:
|
if frame_index == 0:
|
||||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -869,10 +867,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
# index, episode_index, task_index are already processed above, and image and video
|
# index, episode_index, task_index are already processed above, and image and video
|
||||||
# are processed separately by storing image path and frame info as meta data
|
# are processed separately by storing image path and frame info as meta data
|
||||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in [
|
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||||
"image",
|
|
||||||
"video",
|
|
||||||
]:
|
|
||||||
continue
|
continue
|
||||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||||
|
|
||||||
|
|||||||
@@ -37,35 +37,29 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||||||
"""
|
"""
|
||||||
# map to expected inputs for the policy
|
# map to expected inputs for the policy
|
||||||
return_observations = {}
|
return_observations = {}
|
||||||
# TODO: You have to merge all tensors from agent key and extra key
|
if "pixels" in observations:
|
||||||
# You don't keep sensor param key in the observation
|
if isinstance(observations["pixels"], dict):
|
||||||
# And you keep sensor data rgb
|
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||||
for key, img in observations.items():
|
else:
|
||||||
if "images" not in key:
|
imgs = {"observation.image": observations["pixels"]}
|
||||||
continue
|
|
||||||
|
|
||||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
for imgkey, img in imgs.items():
|
||||||
if not torch.is_tensor(img):
|
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||||
img = torch.from_numpy(img)
|
img = torch.from_numpy(img)
|
||||||
|
|
||||||
if img.ndim == 3:
|
# sanity check that images are channel last
|
||||||
img = img.unsqueeze(0)
|
_, h, w, c = img.shape
|
||||||
|
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||||
|
|
||||||
# sanity check that images are channel last
|
# sanity check that images are uint8
|
||||||
_, h, w, c = img.shape
|
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
|
||||||
|
|
||||||
# sanity check that images are uint8
|
# convert to channel first of type float32 in range [0,1]
|
||||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||||
|
img = img.type(torch.float32)
|
||||||
|
img /= 255
|
||||||
|
|
||||||
# convert to channel first of type float32 in range [0,1]
|
return_observations[imgkey] = img
|
||||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
|
||||||
img = img.type(torch.float32)
|
|
||||||
img /= 255
|
|
||||||
|
|
||||||
return_observations[key] = img
|
|
||||||
# obs state agent qpos and qvel
|
|
||||||
# image
|
|
||||||
|
|
||||||
if "environment_state" in observations:
|
if "environment_state" in observations:
|
||||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||||
@@ -74,8 +68,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||||||
|
|
||||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||||
# requirement for "agent_pos"
|
# requirement for "agent_pos"
|
||||||
# return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||||
return_observations["observation.state"] = observations["observation.state"].float()
|
|
||||||
return return_observations
|
return return_observations
|
||||||
|
|
||||||
|
|
||||||
@@ -93,7 +86,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
|||||||
else:
|
else:
|
||||||
feature = ft
|
feature = ft
|
||||||
|
|
||||||
policy_key = env_cfg.features_map.get(key, key)
|
policy_key = env_cfg.features_map[key]
|
||||||
policy_features[policy_key] = feature
|
policy_features[policy_key] = feature
|
||||||
|
|
||||||
return policy_features
|
return policy_features
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ class RecordControlConfig(ControlConfig):
|
|||||||
# Resume recording on an existing dataset.
|
# Resume recording on an existing dataset.
|
||||||
resume: bool = False
|
resume: bool = False
|
||||||
# Reset follower arms to an initial configuration.
|
# Reset follower arms to an initial configuration.
|
||||||
reset_follower_arms: bool = True
|
reset_follower_arms: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||||
|
|||||||
@@ -129,22 +129,16 @@ def predict_action(observation, policy, device, use_amp):
|
|||||||
return action
|
return action
|
||||||
|
|
||||||
|
|
||||||
def init_keyboard_listener(assign_rewards=False):
|
def init_keyboard_listener():
|
||||||
"""
|
"""
|
||||||
Initializes a keyboard listener to enable early termination of an episode
|
Initializes a keyboard listener to enable early termination of an episode
|
||||||
or environment reset by pressing the right arrow key ('->'). This may require
|
or environment reset by pressing the right arrow key ('->'). This may require
|
||||||
sudo permissions to allow the terminal to monitor keyboard events.
|
sudo permissions to allow the terminal to monitor keyboard events.
|
||||||
|
|
||||||
Args:
|
|
||||||
assign_rewards (bool): If True, allows annotating the collected trajectory
|
|
||||||
with a binary reward at the end of the episode to indicate success.
|
|
||||||
"""
|
"""
|
||||||
events = {}
|
events = {}
|
||||||
events["exit_early"] = False
|
events["exit_early"] = False
|
||||||
events["rerecord_episode"] = False
|
events["rerecord_episode"] = False
|
||||||
events["stop_recording"] = False
|
events["stop_recording"] = False
|
||||||
if assign_rewards:
|
|
||||||
events["next.reward"] = 0
|
|
||||||
|
|
||||||
if is_headless():
|
if is_headless():
|
||||||
logging.warning(
|
logging.warning(
|
||||||
@@ -169,12 +163,6 @@ def init_keyboard_listener(assign_rewards=False):
|
|||||||
print("Escape key pressed. Stopping data recording...")
|
print("Escape key pressed. Stopping data recording...")
|
||||||
events["stop_recording"] = True
|
events["stop_recording"] = True
|
||||||
events["exit_early"] = True
|
events["exit_early"] = True
|
||||||
elif assign_rewards and key == keyboard.Key.space:
|
|
||||||
events["next.reward"] = 1 if events["next.reward"] == 0 else 0
|
|
||||||
print(
|
|
||||||
"Space key pressed. Assigning new reward to the subsequent frames. New reward:",
|
|
||||||
events["next.reward"],
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error handling key press: {e}")
|
print(f"Error handling key press: {e}")
|
||||||
|
|||||||
@@ -276,7 +276,7 @@ def record(
|
|||||||
|
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
robot.connect()
|
robot.connect()
|
||||||
listener, events = init_keyboard_listener(assign_rewards=cfg.assign_rewards)
|
listener, events = init_keyboard_listener()
|
||||||
|
|
||||||
# Execute a few seconds without recording to:
|
# Execute a few seconds without recording to:
|
||||||
# 1. teleoperate the robot to move it in starting position if no policy provided,
|
# 1. teleoperate the robot to move it in starting position if no policy provided,
|
||||||
|
|||||||
@@ -201,14 +201,9 @@ def record(
|
|||||||
resume: bool = False,
|
resume: bool = False,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
run_compute_stats: bool = True,
|
run_compute_stats: bool = True,
|
||||||
assign_rewards: bool = False,
|
|
||||||
) -> LeRobotDataset:
|
) -> LeRobotDataset:
|
||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
|
|
||||||
extra_features = (
|
|
||||||
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
|
|
||||||
)
|
|
||||||
|
|
||||||
policy = None
|
policy = None
|
||||||
if pretrained_policy_name_or_path is not None:
|
if pretrained_policy_name_or_path is not None:
|
||||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||||
@@ -221,7 +216,7 @@ def record(
|
|||||||
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
|
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
|
||||||
|
|
||||||
# initialize listener before sim env
|
# initialize listener before sim env
|
||||||
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
|
listener, events = init_keyboard_listener()
|
||||||
|
|
||||||
# create sim env
|
# create sim env
|
||||||
env = env()
|
env = env()
|
||||||
@@ -269,7 +264,6 @@ def record(
|
|||||||
"shape": env.action_space.shape,
|
"shape": env.action_space.shape,
|
||||||
"names": None,
|
"names": None,
|
||||||
}
|
}
|
||||||
features = {**features, **extra_features}
|
|
||||||
|
|
||||||
# Create empty dataset or load existing saved episodes
|
# Create empty dataset or load existing saved episodes
|
||||||
sanity_check_dataset_name(repo_id, policy)
|
sanity_check_dataset_name(repo_id, policy)
|
||||||
@@ -321,13 +315,6 @@ def record(
|
|||||||
"timestamp": env_timestamp,
|
"timestamp": env_timestamp,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Overwrite environment reward with manually assigned reward
|
|
||||||
if assign_rewards:
|
|
||||||
frame["next.reward"] = events["next.reward"]
|
|
||||||
|
|
||||||
# Should success always be false to match what we do in control_utils?
|
|
||||||
frame["next.success"] = False
|
|
||||||
|
|
||||||
for key in image_keys:
|
for key in image_keys:
|
||||||
if not key.startswith("observation.image"):
|
if not key.startswith("observation.image"):
|
||||||
frame["observation.image." + key] = observation[key]
|
frame["observation.image." + key] = observation[key]
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ dependencies = [
|
|||||||
"datasets>=2.19.0",
|
"datasets>=2.19.0",
|
||||||
"deepdiff>=7.0.1",
|
"deepdiff>=7.0.1",
|
||||||
"diffusers>=0.27.2",
|
"diffusers>=0.27.2",
|
||||||
"draccus>=0.10.0",
|
"draccus==0.10.0",
|
||||||
"einops>=0.8.0",
|
"einops>=0.8.0",
|
||||||
"flask>=3.0.3",
|
"flask>=3.0.3",
|
||||||
"gdown>=5.1.0",
|
"gdown>=5.1.0",
|
||||||
@@ -70,7 +70,7 @@ dependencies = [
|
|||||||
"pyzmq>=26.2.1",
|
"pyzmq>=26.2.1",
|
||||||
"rerun-sdk>=0.21.0",
|
"rerun-sdk>=0.21.0",
|
||||||
"termcolor>=2.4.0",
|
"termcolor>=2.4.0",
|
||||||
"torch>=2.2.1",
|
"torch>=2.2.1,<=2.6.0",
|
||||||
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||||
"torchmetrics>=1.6.0",
|
"torchmetrics>=1.6.0",
|
||||||
"torchvision>=0.21.0",
|
"torchvision>=0.21.0",
|
||||||
@@ -89,7 +89,6 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
|||||||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||||
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"]
|
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"]
|
||||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||||
mani_skill = ["mani-skill==3.0.0b20"]
|
|
||||||
pi0 = ["transformers>=4.48.0"]
|
pi0 = ["transformers>=4.48.0"]
|
||||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||||
stretch = [
|
stretch = [
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
|||||||
ClassifierConfig,
|
ClassifierConfig,
|
||||||
ClassifierOutput,
|
ClassifierOutput,
|
||||||
)
|
)
|
||||||
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
from tests.utils import require_package
|
from tests.utils import require_package
|
||||||
|
|
||||||
|
|
||||||
@@ -27,19 +28,39 @@ def test_binary_classifier_with_default_params():
|
|||||||
)
|
)
|
||||||
|
|
||||||
config = ClassifierConfig()
|
config = ClassifierConfig()
|
||||||
|
config.input_features = {
|
||||||
|
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||||
|
}
|
||||||
|
config.output_features = {
|
||||||
|
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||||
|
}
|
||||||
|
config.normalization_mapping = {
|
||||||
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
|
"REWARD": NormalizationMode.IDENTITY,
|
||||||
|
}
|
||||||
|
config.num_cameras = 1
|
||||||
classifier = Classifier(config)
|
classifier = Classifier(config)
|
||||||
|
|
||||||
batch_size = 10
|
batch_size = 10
|
||||||
|
|
||||||
input = torch.rand(batch_size, 3, 224, 224)
|
input = {
|
||||||
output = classifier(input)
|
"observation.image": torch.rand((batch_size, 3, 224, 224)),
|
||||||
|
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
|
||||||
|
}
|
||||||
|
|
||||||
|
images, labels = classifier.extract_images_and_labels(input)
|
||||||
|
assert len(images) == 1
|
||||||
|
assert images[0].shape == torch.Size([batch_size, 3, 224, 224])
|
||||||
|
assert labels.shape == torch.Size([batch_size])
|
||||||
|
|
||||||
|
output = classifier.predict(images)
|
||||||
|
|
||||||
assert output is not None
|
assert output is not None
|
||||||
assert output.logits.shape == torch.Size([batch_size])
|
assert output.logits.size() == torch.Size([batch_size])
|
||||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||||
assert output.probabilities.shape == torch.Size([batch_size])
|
assert output.probabilities.shape == torch.Size([batch_size])
|
||||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||||
assert output.hidden_states.shape == torch.Size([batch_size, 2048])
|
assert output.hidden_states.shape == torch.Size([batch_size, 512])
|
||||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||||
|
|
||||||
|
|
||||||
@@ -50,20 +71,37 @@ def test_multiclass_classifier():
|
|||||||
)
|
)
|
||||||
|
|
||||||
num_classes = 5
|
num_classes = 5
|
||||||
config = ClassifierConfig(num_classes=num_classes)
|
config = ClassifierConfig()
|
||||||
|
config.input_features = {
|
||||||
|
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||||
|
}
|
||||||
|
config.output_features = {
|
||||||
|
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
|
||||||
|
}
|
||||||
|
config.num_cameras = 1
|
||||||
|
config.num_classes = num_classes
|
||||||
classifier = Classifier(config)
|
classifier = Classifier(config)
|
||||||
|
|
||||||
batch_size = 10
|
batch_size = 10
|
||||||
|
|
||||||
input = torch.rand(batch_size, 3, 224, 224)
|
input = {
|
||||||
output = classifier(input)
|
"observation.image": torch.rand((batch_size, 3, 224, 224)),
|
||||||
|
"next.reward": torch.rand((batch_size, num_classes)),
|
||||||
|
}
|
||||||
|
|
||||||
|
images, labels = classifier.extract_images_and_labels(input)
|
||||||
|
assert len(images) == 1
|
||||||
|
assert images[0].shape == torch.Size([batch_size, 3, 224, 224])
|
||||||
|
assert labels.shape == torch.Size([batch_size, num_classes])
|
||||||
|
|
||||||
|
output = classifier.predict(images)
|
||||||
|
|
||||||
assert output is not None
|
assert output is not None
|
||||||
assert output.logits.shape == torch.Size([batch_size, num_classes])
|
assert output.logits.shape == torch.Size([batch_size, num_classes])
|
||||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||||
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
|
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
|
||||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||||
assert output.hidden_states.shape == torch.Size([batch_size, 2048])
|
assert output.hidden_states.shape == torch.Size([batch_size, 512])
|
||||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||||
|
|
||||||
|
|
||||||
@@ -87,9 +125,9 @@ def test_explicit_device_setup():
|
|||||||
Classifier,
|
Classifier,
|
||||||
)
|
)
|
||||||
|
|
||||||
config = ClassifierConfig(device="meta")
|
config = ClassifierConfig(device="cpu")
|
||||||
assert config.device == "meta"
|
assert config.device == "cpu"
|
||||||
|
|
||||||
classifier = Classifier(config)
|
classifier = Classifier(config)
|
||||||
for p in classifier.parameters():
|
for p in classifier.parameters():
|
||||||
assert p.device == torch.device("meta")
|
assert p.device == torch.device("cpu")
|
||||||
|
|||||||
@@ -1,310 +0,0 @@
|
|||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from hydra import compose, initialize_config_dir
|
|
||||||
from torch import nn
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
|
|
||||||
ClassifierConfig,
|
|
||||||
)
|
|
||||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
|
||||||
from lerobot.scripts.train_hilserl_classifier import (
|
|
||||||
create_balanced_sampler,
|
|
||||||
train,
|
|
||||||
train_epoch,
|
|
||||||
validate,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MockDataset(Dataset):
|
|
||||||
def __init__(self, data):
|
|
||||||
self.data = data
|
|
||||||
self.meta = MagicMock()
|
|
||||||
self.meta.stats = {}
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
return self.data[idx]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.data)
|
|
||||||
|
|
||||||
|
|
||||||
def make_dummy_model():
|
|
||||||
model_config = ClassifierConfig(
|
|
||||||
num_classes=2,
|
|
||||||
model_name="hf-tiny-model-private/tiny-random-ResNetModel",
|
|
||||||
num_cameras=1,
|
|
||||||
)
|
|
||||||
model = Classifier(config=model_config)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_balanced_sampler():
|
|
||||||
# Mock dataset with imbalanced classes
|
|
||||||
data = [
|
|
||||||
{"label": 0},
|
|
||||||
{"label": 0},
|
|
||||||
{"label": 1},
|
|
||||||
{"label": 0},
|
|
||||||
{"label": 1},
|
|
||||||
{"label": 1},
|
|
||||||
{"label": 1},
|
|
||||||
{"label": 1},
|
|
||||||
]
|
|
||||||
dataset = MockDataset(data)
|
|
||||||
cfg = MagicMock()
|
|
||||||
cfg.training.label_key = "label"
|
|
||||||
|
|
||||||
sampler = create_balanced_sampler(dataset, cfg)
|
|
||||||
|
|
||||||
# Get weights from the sampler
|
|
||||||
weights = sampler.weights.float()
|
|
||||||
|
|
||||||
# Check that samples have appropriate weights
|
|
||||||
labels = [item["label"] for item in data]
|
|
||||||
class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32)
|
|
||||||
class_weights = 1.0 / class_counts
|
|
||||||
expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32)
|
|
||||||
|
|
||||||
# Test that the weights are correct
|
|
||||||
assert torch.allclose(weights, expected_weights)
|
|
||||||
|
|
||||||
|
|
||||||
def test_train_epoch():
|
|
||||||
model = make_dummy_model()
|
|
||||||
# Mock components
|
|
||||||
model.train = MagicMock()
|
|
||||||
|
|
||||||
train_loader = [
|
|
||||||
{
|
|
||||||
"image": torch.rand(2, 3, 224, 224),
|
|
||||||
"label": torch.tensor([0.0, 1.0]),
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
criterion = nn.BCEWithLogitsLoss()
|
|
||||||
optimizer = MagicMock()
|
|
||||||
grad_scaler = MagicMock()
|
|
||||||
device = torch.device("cpu")
|
|
||||||
logger = MagicMock()
|
|
||||||
step = 0
|
|
||||||
cfg = MagicMock()
|
|
||||||
cfg.training.image_keys = ["image"]
|
|
||||||
cfg.training.label_key = "label"
|
|
||||||
cfg.training.use_amp = False
|
|
||||||
|
|
||||||
# Call the function under test
|
|
||||||
train_epoch(
|
|
||||||
model,
|
|
||||||
train_loader,
|
|
||||||
criterion,
|
|
||||||
optimizer,
|
|
||||||
grad_scaler,
|
|
||||||
device,
|
|
||||||
logger,
|
|
||||||
step,
|
|
||||||
cfg,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check that model.train() was called
|
|
||||||
model.train.assert_called_once()
|
|
||||||
|
|
||||||
# Check that optimizer.zero_grad() was called
|
|
||||||
optimizer.zero_grad.assert_called()
|
|
||||||
|
|
||||||
# Check that logger.log_dict was called
|
|
||||||
logger.log_dict.assert_called()
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate():
|
|
||||||
model = make_dummy_model()
|
|
||||||
|
|
||||||
# Mock components
|
|
||||||
model.eval = MagicMock()
|
|
||||||
val_loader = [
|
|
||||||
{
|
|
||||||
"image": torch.rand(2, 3, 224, 224),
|
|
||||||
"label": torch.tensor([0.0, 1.0]),
|
|
||||||
}
|
|
||||||
]
|
|
||||||
criterion = nn.BCEWithLogitsLoss()
|
|
||||||
device = torch.device("cpu")
|
|
||||||
logger = MagicMock()
|
|
||||||
cfg = MagicMock()
|
|
||||||
cfg.training.image_keys = ["image"]
|
|
||||||
cfg.training.label_key = "label"
|
|
||||||
cfg.training.use_amp = False
|
|
||||||
|
|
||||||
# Call validate
|
|
||||||
accuracy, eval_info = validate(model, val_loader, criterion, device, logger, cfg)
|
|
||||||
|
|
||||||
# Check that model.eval() was called
|
|
||||||
model.eval.assert_called_once()
|
|
||||||
|
|
||||||
# Check accuracy/eval_info are calculated and of the correct type
|
|
||||||
assert isinstance(accuracy, float)
|
|
||||||
assert isinstance(eval_info, dict)
|
|
||||||
|
|
||||||
|
|
||||||
def test_train_epoch_multiple_cameras():
|
|
||||||
model_config = ClassifierConfig(
|
|
||||||
num_classes=2,
|
|
||||||
model_name="hf-tiny-model-private/tiny-random-ResNetModel",
|
|
||||||
num_cameras=2,
|
|
||||||
)
|
|
||||||
model = Classifier(config=model_config)
|
|
||||||
|
|
||||||
# Mock components
|
|
||||||
model.train = MagicMock()
|
|
||||||
|
|
||||||
train_loader = [
|
|
||||||
{
|
|
||||||
"image_1": torch.rand(2, 3, 224, 224),
|
|
||||||
"image_2": torch.rand(2, 3, 224, 224),
|
|
||||||
"label": torch.tensor([0.0, 1.0]),
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
criterion = nn.BCEWithLogitsLoss()
|
|
||||||
optimizer = MagicMock()
|
|
||||||
grad_scaler = MagicMock()
|
|
||||||
device = torch.device("cpu")
|
|
||||||
logger = MagicMock()
|
|
||||||
step = 0
|
|
||||||
cfg = MagicMock()
|
|
||||||
cfg.training.image_keys = ["image_1", "image_2"]
|
|
||||||
cfg.training.label_key = "label"
|
|
||||||
cfg.training.use_amp = False
|
|
||||||
|
|
||||||
# Call the function under test
|
|
||||||
train_epoch(
|
|
||||||
model,
|
|
||||||
train_loader,
|
|
||||||
criterion,
|
|
||||||
optimizer,
|
|
||||||
grad_scaler,
|
|
||||||
device,
|
|
||||||
logger,
|
|
||||||
step,
|
|
||||||
cfg,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check that model.train() was called
|
|
||||||
model.train.assert_called_once()
|
|
||||||
|
|
||||||
# Check that optimizer.zero_grad() was called
|
|
||||||
optimizer.zero_grad.assert_called()
|
|
||||||
|
|
||||||
# Check that logger.log_dict was called
|
|
||||||
logger.log_dict.assert_called()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("resume", [True, False])
|
|
||||||
@patch("lerobot.scripts.train_hilserl_classifier.init_hydra_config")
|
|
||||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_checkpoint_dir")
|
|
||||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
|
|
||||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
|
|
||||||
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
|
|
||||||
@patch("lerobot.scripts.train_hilserl_classifier.get_model")
|
|
||||||
def test_resume_function(
|
|
||||||
mock_get_model,
|
|
||||||
mock_dataset,
|
|
||||||
mock_logger,
|
|
||||||
mock_get_last_pretrained_model_dir,
|
|
||||||
mock_get_last_checkpoint_dir,
|
|
||||||
mock_init_hydra_config,
|
|
||||||
resume,
|
|
||||||
):
|
|
||||||
# Initialize Hydra
|
|
||||||
test_file_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy"))
|
|
||||||
assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}"
|
|
||||||
|
|
||||||
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
|
|
||||||
cfg = compose(
|
|
||||||
config_name="hilserl_classifier",
|
|
||||||
overrides=[
|
|
||||||
"device=cpu",
|
|
||||||
"seed=42",
|
|
||||||
f"output_dir={tempfile.mkdtemp()}",
|
|
||||||
"wandb.enable=False",
|
|
||||||
f"resume={resume}",
|
|
||||||
"dataset_repo_id=dataset_repo_id",
|
|
||||||
"train_split_proportion=0.8",
|
|
||||||
"training.num_workers=0",
|
|
||||||
"training.batch_size=2",
|
|
||||||
"training.image_keys=[image]",
|
|
||||||
"training.label_key=label",
|
|
||||||
"training.use_amp=False",
|
|
||||||
"training.num_epochs=1",
|
|
||||||
"eval.batch_size=2",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the init_hydra_config function to return cfg
|
|
||||||
mock_init_hydra_config.return_value = cfg
|
|
||||||
|
|
||||||
# Mock dataset
|
|
||||||
dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)])
|
|
||||||
mock_dataset.return_value = dataset
|
|
||||||
|
|
||||||
# Mock checkpoint handling
|
|
||||||
mock_checkpoint_dir = MagicMock(spec=Path)
|
|
||||||
mock_checkpoint_dir.exists.return_value = resume # Only exists if resuming
|
|
||||||
mock_get_last_checkpoint_dir.return_value = mock_checkpoint_dir
|
|
||||||
mock_get_last_pretrained_model_dir.return_value = Path(tempfile.mkdtemp())
|
|
||||||
|
|
||||||
# Mock logger
|
|
||||||
logger = MagicMock()
|
|
||||||
resumed_step = 1000
|
|
||||||
if resume:
|
|
||||||
logger.load_last_training_state.return_value = resumed_step
|
|
||||||
else:
|
|
||||||
logger.load_last_training_state.return_value = 0
|
|
||||||
mock_logger.return_value = logger
|
|
||||||
|
|
||||||
# Instantiate the model and set make_policy to return it
|
|
||||||
model = make_dummy_model()
|
|
||||||
mock_get_model.return_value = model
|
|
||||||
|
|
||||||
# Call train
|
|
||||||
train(cfg)
|
|
||||||
|
|
||||||
# Check that checkpoint handling methods were called
|
|
||||||
if resume:
|
|
||||||
mock_get_last_checkpoint_dir.assert_called_once_with(Path(cfg.output_dir))
|
|
||||||
mock_get_last_pretrained_model_dir.assert_called_once_with(Path(cfg.output_dir))
|
|
||||||
mock_checkpoint_dir.exists.assert_called_once()
|
|
||||||
logger.load_last_training_state.assert_called_once()
|
|
||||||
else:
|
|
||||||
mock_get_last_checkpoint_dir.assert_not_called()
|
|
||||||
mock_get_last_pretrained_model_dir.assert_not_called()
|
|
||||||
mock_checkpoint_dir.exists.assert_not_called()
|
|
||||||
logger.load_last_training_state.assert_not_called()
|
|
||||||
|
|
||||||
# Collect the steps from logger.log_dict calls
|
|
||||||
train_log_calls = logger.log_dict.call_args_list
|
|
||||||
|
|
||||||
# Extract the steps used in the train logging
|
|
||||||
steps = []
|
|
||||||
for call in train_log_calls:
|
|
||||||
mode = call.kwargs.get("mode", call.args[2] if len(call.args) > 2 else None)
|
|
||||||
if mode == "train":
|
|
||||||
step = call.kwargs.get("step", call.args[1] if len(call.args) > 1 else None)
|
|
||||||
steps.append(step)
|
|
||||||
|
|
||||||
expected_start_step = resumed_step if resume else 0
|
|
||||||
|
|
||||||
# Calculate expected_steps
|
|
||||||
train_size = int(cfg.train_split_proportion * len(dataset))
|
|
||||||
batch_size = cfg.training.batch_size
|
|
||||||
num_batches = (train_size + batch_size - 1) // batch_size
|
|
||||||
|
|
||||||
expected_steps = [expected_start_step + i for i in range(num_batches)]
|
|
||||||
|
|
||||||
assert steps == expected_steps, f"Expected steps {expected_steps}, got {steps}"
|
|
||||||
Reference in New Issue
Block a user