forked from tangger/lerobot
Port HIL SERL (#644)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Eugene Mironov <helper2424@gmail.com> Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com> Co-authored-by: Ke Wang <superwk1017@gmail.com> Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com> Co-authored-by: imstevenpmwork <steven.palma@huggingface.co> Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
This commit is contained in:
@@ -21,6 +21,7 @@ from lerobot.common.constants import (
|
||||
from lerobot.common.optim.optimizers import (
|
||||
AdamConfig,
|
||||
AdamWConfig,
|
||||
MultiAdamConfig,
|
||||
SGDConfig,
|
||||
load_optimizer_state,
|
||||
save_optimizer_state,
|
||||
@@ -33,13 +34,21 @@ from lerobot.common.optim.optimizers import (
|
||||
(AdamConfig, torch.optim.Adam),
|
||||
(AdamWConfig, torch.optim.AdamW),
|
||||
(SGDConfig, torch.optim.SGD),
|
||||
(MultiAdamConfig, dict),
|
||||
],
|
||||
)
|
||||
def test_optimizer_build(config_cls, expected_class, model_params):
|
||||
config = config_cls()
|
||||
optimizer = config.build(model_params)
|
||||
assert isinstance(optimizer, expected_class)
|
||||
assert optimizer.defaults["lr"] == config.lr
|
||||
if config_cls == MultiAdamConfig:
|
||||
params_dict = {"default": model_params}
|
||||
optimizer = config.build(params_dict)
|
||||
assert isinstance(optimizer, expected_class)
|
||||
assert isinstance(optimizer["default"], torch.optim.Adam)
|
||||
assert optimizer["default"].defaults["lr"] == config.lr
|
||||
else:
|
||||
optimizer = config.build(model_params)
|
||||
assert isinstance(optimizer, expected_class)
|
||||
assert optimizer.defaults["lr"] == config.lr
|
||||
|
||||
|
||||
def test_save_optimizer_state(optimizer, tmp_path):
|
||||
@@ -54,3 +63,180 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
|
||||
loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path)
|
||||
|
||||
torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_params_dict():
|
||||
return {
|
||||
"actor": [torch.nn.Parameter(torch.randn(10, 10))],
|
||||
"critic": [torch.nn.Parameter(torch.randn(5, 5))],
|
||||
"temperature": [torch.nn.Parameter(torch.randn(3, 3))],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_params, expected_values",
|
||||
[
|
||||
# Test 1: Basic configuration with different learning rates
|
||||
(
|
||||
{
|
||||
"lr": 1e-3,
|
||||
"weight_decay": 1e-4,
|
||||
"optimizer_groups": {
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
},
|
||||
{
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
||||
},
|
||||
),
|
||||
# Test 2: Different weight decays and beta values
|
||||
(
|
||||
{
|
||||
"lr": 1e-3,
|
||||
"weight_decay": 1e-4,
|
||||
"optimizer_groups": {
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-5},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-6},
|
||||
"temperature": {"lr": 2e-3, "betas": (0.95, 0.999)},
|
||||
},
|
||||
},
|
||||
{
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)},
|
||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)},
|
||||
},
|
||||
),
|
||||
# Test 3: Epsilon parameter customization
|
||||
(
|
||||
{
|
||||
"lr": 1e-3,
|
||||
"weight_decay": 1e-4,
|
||||
"optimizer_groups": {
|
||||
"actor": {"lr": 1e-4, "eps": 1e-6},
|
||||
"critic": {"lr": 5e-4, "eps": 1e-7},
|
||||
"temperature": {"lr": 2e-3, "eps": 1e-8},
|
||||
},
|
||||
},
|
||||
{
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7},
|
||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8},
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_multi_adam_configuration(base_params_dict, config_params, expected_values):
|
||||
# Create config with the given parameters
|
||||
config = MultiAdamConfig(**config_params)
|
||||
optimizers = config.build(base_params_dict)
|
||||
|
||||
# Verify optimizer count and keys
|
||||
assert len(optimizers) == len(expected_values)
|
||||
assert set(optimizers.keys()) == set(expected_values.keys())
|
||||
|
||||
# Check that all optimizers are Adam instances
|
||||
for opt in optimizers.values():
|
||||
assert isinstance(opt, torch.optim.Adam)
|
||||
|
||||
# Verify hyperparameters for each optimizer
|
||||
for name, expected in expected_values.items():
|
||||
optimizer = optimizers[name]
|
||||
for param, value in expected.items():
|
||||
assert optimizer.defaults[param] == value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multi_optimizers(base_params_dict):
|
||||
config = MultiAdamConfig(
|
||||
lr=1e-3,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
)
|
||||
return config.build(base_params_dict)
|
||||
|
||||
|
||||
def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
|
||||
# Save optimizer states
|
||||
save_optimizer_state(multi_optimizers, tmp_path)
|
||||
|
||||
# Verify that directories were created for each optimizer
|
||||
for name in multi_optimizers:
|
||||
assert (tmp_path / name).is_dir()
|
||||
assert (tmp_path / name / OPTIMIZER_STATE).is_file()
|
||||
assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file()
|
||||
|
||||
|
||||
def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path):
|
||||
# Option 1: Add a minimal backward pass to populate optimizer states
|
||||
for name, params in base_params_dict.items():
|
||||
if name in multi_optimizers:
|
||||
# Create a dummy loss and do backward
|
||||
dummy_loss = params[0].sum()
|
||||
dummy_loss.backward()
|
||||
# Perform an optimization step
|
||||
multi_optimizers[name].step()
|
||||
# Zero gradients for next steps
|
||||
multi_optimizers[name].zero_grad()
|
||||
|
||||
# Save optimizer states
|
||||
save_optimizer_state(multi_optimizers, tmp_path)
|
||||
|
||||
# Create new optimizers with the same config
|
||||
config = MultiAdamConfig(
|
||||
lr=1e-3,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
)
|
||||
new_optimizers = config.build(base_params_dict)
|
||||
|
||||
# Load optimizer states
|
||||
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
|
||||
|
||||
# Verify state dictionaries match
|
||||
for name in multi_optimizers:
|
||||
torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
|
||||
|
||||
|
||||
def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
|
||||
"""Test saving and loading optimizer states even when the state is empty (no backward pass)."""
|
||||
# Create config and build optimizers
|
||||
config = MultiAdamConfig(
|
||||
lr=1e-3,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
)
|
||||
optimizers = config.build(base_params_dict)
|
||||
|
||||
# Save optimizer states without any backward pass (empty state)
|
||||
save_optimizer_state(optimizers, tmp_path)
|
||||
|
||||
# Create new optimizers with the same config
|
||||
new_optimizers = config.build(base_params_dict)
|
||||
|
||||
# Load optimizer states
|
||||
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
|
||||
|
||||
# Verify hyperparameters match even with empty state
|
||||
for name, optimizer in optimizers.items():
|
||||
assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
|
||||
assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
|
||||
assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
|
||||
|
||||
# Verify state dictionaries match (they will be empty)
|
||||
torch.testing.assert_close(
|
||||
optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
|
||||
)
|
||||
|
||||
139
tests/policies/hilserl/test_modeling_classifier.py
Normal file
139
tests/policies/hilserl/test_modeling_classifier.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.common.policies.sac.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
def test_classifier_output():
|
||||
output = ClassifierOutput(
|
||||
logits=torch.tensor([1, 2, 3]),
|
||||
probabilities=torch.tensor([0.1, 0.2, 0.3]),
|
||||
hidden_states=None,
|
||||
)
|
||||
|
||||
assert (
|
||||
f"{output}"
|
||||
== "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
|
||||
)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_binary_classifier_with_default_params():
|
||||
from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
config.num_cameras = 1
|
||||
classifier = Classifier(config)
|
||||
|
||||
batch_size = 10
|
||||
|
||||
input = {
|
||||
"observation.image": torch.rand((batch_size, 3, 128, 128)),
|
||||
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
|
||||
}
|
||||
|
||||
images, labels = classifier.extract_images_and_labels(input)
|
||||
assert len(images) == 1
|
||||
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
|
||||
assert labels.shape == torch.Size([batch_size])
|
||||
|
||||
output = classifier.predict(images)
|
||||
|
||||
assert output is not None
|
||||
assert output.logits.size() == torch.Size([batch_size])
|
||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||
assert output.probabilities.shape == torch.Size([batch_size])
|
||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||
assert output.hidden_states.shape == torch.Size([batch_size, 256])
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_multiclass_classifier():
|
||||
from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
num_classes = 5
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
|
||||
}
|
||||
config.num_cameras = 1
|
||||
config.num_classes = num_classes
|
||||
classifier = Classifier(config)
|
||||
|
||||
batch_size = 10
|
||||
|
||||
input = {
|
||||
"observation.image": torch.rand((batch_size, 3, 128, 128)),
|
||||
"next.reward": torch.rand((batch_size, num_classes)),
|
||||
}
|
||||
|
||||
images, labels = classifier.extract_images_and_labels(input)
|
||||
assert len(images) == 1
|
||||
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
|
||||
assert labels.shape == torch.Size([batch_size, num_classes])
|
||||
|
||||
output = classifier.predict(images)
|
||||
|
||||
assert output is not None
|
||||
assert output.logits.shape == torch.Size([batch_size, num_classes])
|
||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
|
||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||
assert output.hidden_states.shape == torch.Size([batch_size, 256])
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_default_device():
|
||||
from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
assert config.device == "cpu"
|
||||
|
||||
classifier = Classifier(config)
|
||||
for p in classifier.parameters():
|
||||
assert p.device == torch.device("cpu")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_explicit_device_setup():
|
||||
from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig(device="cpu")
|
||||
assert config.device == "cpu"
|
||||
|
||||
classifier = Classifier(config)
|
||||
for p in classifier.parameters():
|
||||
assert p.device == torch.device("cpu")
|
||||
217
tests/policies/test_sac_config.py
Normal file
217
tests/policies/test_sac_config.py
Normal file
@@ -0,0 +1,217 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import (
|
||||
ActorLearnerConfig,
|
||||
ActorNetworkConfig,
|
||||
ConcurrencyConfig,
|
||||
CriticNetworkConfig,
|
||||
PolicyConfig,
|
||||
SACConfig,
|
||||
)
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def test_sac_config_default_initialization():
|
||||
config = SACConfig()
|
||||
|
||||
assert config.normalization_mapping == {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ENV": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
assert config.dataset_stats == {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
|
||||
# Basic parameters
|
||||
assert config.device == "cpu"
|
||||
assert config.storage_device == "cpu"
|
||||
assert config.discount == 0.99
|
||||
assert config.temperature_init == 1.0
|
||||
assert config.num_critics == 2
|
||||
|
||||
# Architecture specifics
|
||||
assert config.vision_encoder_name is None
|
||||
assert config.freeze_vision_encoder is True
|
||||
assert config.image_encoder_hidden_dim == 32
|
||||
assert config.shared_encoder is True
|
||||
assert config.num_discrete_actions is None
|
||||
assert config.image_embedding_pooling_dim == 8
|
||||
|
||||
# Training parameters
|
||||
assert config.online_steps == 1000000
|
||||
assert config.online_env_seed == 10000
|
||||
assert config.online_buffer_capacity == 100000
|
||||
assert config.offline_buffer_capacity == 100000
|
||||
assert config.async_prefetch is False
|
||||
assert config.online_step_before_learning == 100
|
||||
assert config.policy_update_freq == 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
assert config.num_subsample_critics is None
|
||||
assert config.critic_lr == 3e-4
|
||||
assert config.actor_lr == 3e-4
|
||||
assert config.temperature_lr == 3e-4
|
||||
assert config.critic_target_update_weight == 0.005
|
||||
assert config.utd_ratio == 1
|
||||
assert config.state_encoder_hidden_dim == 256
|
||||
assert config.latent_dim == 256
|
||||
assert config.target_entropy is None
|
||||
assert config.use_backup_entropy is True
|
||||
assert config.grad_clip_norm == 40.0
|
||||
|
||||
# Dataset stats defaults
|
||||
expected_dataset_stats = {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
assert config.dataset_stats == expected_dataset_stats
|
||||
|
||||
# Critic network configuration
|
||||
assert config.critic_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.critic_network_kwargs.activate_final is True
|
||||
assert config.critic_network_kwargs.final_activation is None
|
||||
|
||||
# Actor network configuration
|
||||
assert config.actor_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.actor_network_kwargs.activate_final is True
|
||||
|
||||
# Policy configuration
|
||||
assert config.policy_kwargs.use_tanh_squash is True
|
||||
assert config.policy_kwargs.std_min == 1e-5
|
||||
assert config.policy_kwargs.std_max == 10.0
|
||||
assert config.policy_kwargs.init_final == 0.05
|
||||
|
||||
# Discrete critic network configuration
|
||||
assert config.discrete_critic_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.discrete_critic_network_kwargs.activate_final is True
|
||||
assert config.discrete_critic_network_kwargs.final_activation is None
|
||||
|
||||
# Actor learner configuration
|
||||
assert config.actor_learner_config.learner_host == "127.0.0.1"
|
||||
assert config.actor_learner_config.learner_port == 50051
|
||||
assert config.actor_learner_config.policy_parameters_push_frequency == 4
|
||||
|
||||
# Concurrency configuration
|
||||
assert config.concurrency.actor == "threads"
|
||||
assert config.concurrency.learner == "threads"
|
||||
|
||||
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
|
||||
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
|
||||
assert isinstance(config.policy_kwargs, PolicyConfig)
|
||||
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
|
||||
assert isinstance(config.concurrency, ConcurrencyConfig)
|
||||
|
||||
|
||||
def test_critic_network_kwargs():
|
||||
config = CriticNetworkConfig()
|
||||
assert config.hidden_dims == [256, 256]
|
||||
assert config.activate_final is True
|
||||
assert config.final_activation is None
|
||||
|
||||
|
||||
def test_actor_network_kwargs():
|
||||
config = ActorNetworkConfig()
|
||||
assert config.hidden_dims == [256, 256]
|
||||
assert config.activate_final is True
|
||||
|
||||
|
||||
def test_policy_kwargs():
|
||||
config = PolicyConfig()
|
||||
assert config.use_tanh_squash is True
|
||||
assert config.std_min == 1e-5
|
||||
assert config.std_max == 10.0
|
||||
assert config.init_final == 0.05
|
||||
|
||||
|
||||
def test_actor_learner_config():
|
||||
config = ActorLearnerConfig()
|
||||
assert config.learner_host == "127.0.0.1"
|
||||
assert config.learner_port == 50051
|
||||
assert config.policy_parameters_push_frequency == 4
|
||||
|
||||
|
||||
def test_concurrency_config():
|
||||
config = ConcurrencyConfig()
|
||||
assert config.actor == "threads"
|
||||
assert config.learner == "threads"
|
||||
|
||||
|
||||
def test_sac_config_custom_initialization():
|
||||
config = SACConfig(
|
||||
device="cpu",
|
||||
discount=0.95,
|
||||
temperature_init=0.5,
|
||||
num_critics=3,
|
||||
)
|
||||
|
||||
assert config.device == "cpu"
|
||||
assert config.discount == 0.95
|
||||
assert config.temperature_init == 0.5
|
||||
assert config.num_critics == 3
|
||||
|
||||
|
||||
def test_validate_features():
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_missing_observation():
|
||||
config = SACConfig(
|
||||
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError, match="You must provide either 'observation.state' or an image observation"
|
||||
):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_missing_action():
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
|
||||
config.validate_features()
|
||||
541
tests/policies/test_sac_policy.py
Normal file
541
tests/policies/test_sac_policy.py
Normal file
@@ -0,0 +1,541 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.sac.modeling_sac import MLP, SACPolicy
|
||||
from lerobot.common.utils.random_utils import seeded_context, set_seed
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
try:
|
||||
import transformers # noqa: F401
|
||||
|
||||
TRANSFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
TRANSFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_random_seed():
|
||||
seed = 42
|
||||
set_seed(seed)
|
||||
|
||||
|
||||
def test_mlp_with_default_args():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||
|
||||
x = torch.randn(10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (256,)
|
||||
|
||||
|
||||
def test_mlp_with_batch_dim():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||
x = torch.randn(2, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (2, 256)
|
||||
|
||||
|
||||
def test_forward_with_empty_hidden_dims():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[])
|
||||
x = torch.randn(1, 10)
|
||||
assert mlp(x).shape == (1, 10)
|
||||
|
||||
|
||||
def test_mlp_with_dropout():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
|
||||
x = torch.randn(1, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (1, 11)
|
||||
|
||||
drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
|
||||
assert drop_out_layers_count == 2
|
||||
|
||||
|
||||
def test_mlp_with_custom_final_activation():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
|
||||
x = torch.randn(1, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (1, 256)
|
||||
assert (y >= -1).all() and (y <= 1).all()
|
||||
|
||||
|
||||
def test_sac_policy_with_default_args():
|
||||
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
|
||||
SACPolicy()
|
||||
|
||||
|
||||
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
return {
|
||||
"observation.image": torch.randn(batch_size, 3, 84, 84),
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor:
|
||||
return torch.randn(batch_size, action_dim)
|
||||
|
||||
|
||||
def create_default_train_batch(
|
||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
||||
) -> dict[str, Tensor]:
|
||||
return {
|
||||
"action": create_dummy_action(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": create_dummy_state(batch_size, state_dim),
|
||||
"next_state": create_dummy_state(batch_size, state_dim),
|
||||
"done": torch.randn(batch_size),
|
||||
}
|
||||
|
||||
|
||||
def create_train_batch_with_visual_input(
|
||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
||||
) -> dict[str, Tensor]:
|
||||
return {
|
||||
"action": create_dummy_action(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": create_dummy_with_visual_input(batch_size, state_dim),
|
||||
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
|
||||
"done": torch.randn(batch_size),
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
"observation.image": torch.randn(batch_size, 3, 84, 84),
|
||||
}
|
||||
|
||||
|
||||
def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Create optimizers for the SAC policy."""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=policy.config.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(
|
||||
params=[policy.log_alpha],
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
|
||||
if has_discrete_action:
|
||||
optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
params=policy.discrete_critic.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
def create_default_config(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
action_dim = continuous_action_dim
|
||||
if has_discrete_action:
|
||||
action_dim += 1
|
||||
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
|
||||
dataset_stats={
|
||||
"observation.state": {
|
||||
"min": [0.0] * state_dim,
|
||||
"max": [1.0] * state_dim,
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0] * continuous_action_dim,
|
||||
"max": [1.0] * continuous_action_dim,
|
||||
},
|
||||
},
|
||||
)
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def create_config_with_visual_input(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
config = create_default_config(
|
||||
state_dim=state_dim,
|
||||
continuous_action_dim=continuous_action_dim,
|
||||
has_discrete_action=has_discrete_action,
|
||||
)
|
||||
config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
|
||||
config.dataset_stats["observation.image"] = {
|
||||
"mean": torch.randn(3, 1, 1),
|
||||
"std": torch.randn(3, 1, 1),
|
||||
}
|
||||
|
||||
# Let make tests a little bit faster
|
||||
config.state_encoder_hidden_dim = 32
|
||||
config.latent_dim = 32
|
||||
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
    """Same flow as the default-config test, but with image observations in the batch."""
    config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
    policy = SACPolicy(config=config)

    batch = create_train_batch_with_visual_input(
        batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
    )

    policy.train()

    optimizers = make_optimizers(policy)

    # Train critic, actor and temperature once each, in that order.
    for model_name, loss_key in (
        ("critic", "loss_critic"),
        ("actor", "loss_actor"),
        ("temperature", "loss_temperature"),
    ):
        loss = policy.forward(batch, model=model_name)[loss_key]
        assert loss.item() is not None
        assert loss.shape == ()  # every loss must be a scalar
        loss.backward()
        optimizers[model_name].step()

    # Inference path with visual observations.
    policy.eval()
    with torch.no_grad():
        observation_batch = create_observation_batch_with_visual_input(
            batch_size=batch_size, state_dim=state_dim
        )
        selected_action = policy.select_action(observation_batch)
        assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
# Let's check best candidates for pretrained encoders
|
||||
@pytest.mark.parametrize(
    "batch_size,state_dim,action_dim,vision_encoder_name",
    [(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
)
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
def test_sac_policy_with_pretrained_encoder(
    batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
):
    """Check critic/actor losses when the policy wraps a pretrained vision encoder."""
    config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
    config.vision_encoder_name = vision_encoder_name
    policy = SACPolicy(config=config)
    policy.train()

    batch = create_train_batch_with_visual_input(
        batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
    )

    optimizers = make_optimizers(policy)

    # Critic: full loss -> backward -> optimizer step.
    critic_loss = policy.forward(batch, model="critic")["loss_critic"]
    assert critic_loss.item() is not None
    assert critic_loss.shape == ()
    critic_loss.backward()
    optimizers["critic"].step()

    # Actor: only verify that a scalar loss comes out.
    actor_loss = policy.forward(batch, model="actor")["loss_actor"]
    assert actor_loss.item() is not None
    assert actor_loss.shape == ()
|
||||
|
||||
|
||||
def test_sac_policy_with_shared_encoder():
    """Train critic and actor once when the observation encoder is shared between them."""
    batch_size = 2
    action_dim = 10
    state_dim = 10
    config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
    config.shared_encoder = True

    policy = SACPolicy(config=config)
    # Single train() call; the original called it twice, which was redundant.
    policy.train()

    batch = create_train_batch_with_visual_input(
        batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
    )

    optimizers = make_optimizers(policy)

    # Critic step.
    critic_loss = policy.forward(batch, model="critic")["loss_critic"]
    assert critic_loss.item() is not None
    assert critic_loss.shape == ()
    critic_loss.backward()
    optimizers["critic"].step()

    # Actor step (gradients flow through the shared encoder too).
    actor_loss = policy.forward(batch, model="actor")["loss_actor"]
    assert actor_loss.item() is not None
    assert actor_loss.shape == ()

    actor_loss.backward()
    optimizers["actor"].step()
|
||||
|
||||
|
||||
def test_sac_policy_with_discrete_critic():
    """Train SAC with an extra discrete action head and validate sampled discrete values."""
    batch_size = 2
    continuous_action_dim = 9
    full_action_dim = continuous_action_dim + 1  # the last action dimension is discrete
    state_dim = 10
    config = create_config_with_visual_input(
        state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
    )

    num_discrete_actions = 5
    config.num_discrete_actions = num_discrete_actions

    policy = SACPolicy(config=config)
    # Single train() call; the original called it twice, which was redundant.
    policy.train()

    batch = create_train_batch_with_visual_input(
        batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
    )

    optimizers = make_optimizers(policy, has_discrete_action=True)

    # Train critic, discrete critic and actor once each, in that order.
    for model_name, loss_key in (
        ("critic", "loss_critic"),
        ("discrete_critic", "loss_discrete_critic"),
        ("actor", "loss_actor"),
    ):
        loss = policy.forward(batch, model=model_name)[loss_key]
        assert loss.item() is not None
        assert loss.shape == ()  # every loss must be a scalar
        loss.backward()
        optimizers[model_name].step()

    policy.eval()
    with torch.no_grad():
        observation_batch = create_observation_batch_with_visual_input(
            batch_size=batch_size, state_dim=state_dim
        )
        selected_action = policy.select_action(observation_batch)
        assert selected_action.shape == (batch_size, full_action_dim)

        # The discrete head must only emit indices within [0, num_discrete_actions).
        discrete_actions = selected_action[:, -1].long()
        discrete_action_values = set(discrete_actions.tolist())

        assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
            f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
        )
|
||||
|
||||
|
||||
def test_sac_policy_with_default_entropy():
    """With 10 continuous action dims the derived default target entropy is -5.0."""
    policy = SACPolicy(config=create_default_config(continuous_action_dim=10, state_dim=10))
    assert policy.target_entropy == -5.0
|
||||
|
||||
|
||||
def test_sac_policy_default_target_entropy_with_discrete_action():
    """With a discrete action present and 6 continuous dims, target entropy is -3.0."""
    config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
    assert SACPolicy(config=config).target_entropy == -3.0
|
||||
|
||||
|
||||
def test_sac_policy_with_predefined_entropy():
    """An explicitly configured target entropy overrides the derived default."""
    config = create_default_config(state_dim=10, continuous_action_dim=6)
    config.target_entropy = -3.5

    assert SACPolicy(config=config).target_entropy == pytest.approx(-3.5)
|
||||
|
||||
|
||||
def test_sac_policy_update_temperature():
    """Temperature must track exp(log_alpha) after update_temperature()."""
    config = create_default_config(continuous_action_dim=10, state_dim=10)
    policy = SACPolicy(config=config)

    # A fresh policy starts with alpha == 1 (i.e. log_alpha == 0).
    assert policy.temperature == pytest.approx(1.0)

    # Force log_alpha to log(0.1); the refreshed temperature must be 0.1.
    policy.log_alpha.data = torch.tensor([math.log(0.1)])
    policy.update_temperature()
    assert policy.temperature == pytest.approx(0.1)
|
||||
|
||||
|
||||
def test_sac_policy_update_target_network():
    """With update weight 1.0 the target critic must copy the online critic exactly."""
    config = create_default_config(state_dim=10, continuous_action_dim=6)
    config.critic_target_update_weight = 1.0

    policy = SACPolicy(config=config)
    policy.train()

    # Overwrite every online-critic parameter with ones...
    for param in policy.critic_ensemble.parameters():
        param.data = torch.ones_like(param.data)

    # ...and expect the hard update (weight 1.0) to propagate them verbatim.
    policy.update_target_networks()
    for param in policy.critic_target.parameters():
        assert torch.allclose(param.data, torch.ones_like(param.data)), (
            f"Target network {param.data} is not equal to {torch.ones_like(param.data)}"
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_critics", [1, 3])
def test_sac_policy_with_critics_number_of_heads(num_critics: int):
    """The critic ensemble must instantiate exactly `num_critics` heads and still train."""
    batch_size = 2
    action_dim = 10
    state_dim = 10
    config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
    config.num_critics = num_critics

    policy = SACPolicy(config=config)
    # Single train() call; the original called it twice, which was redundant.
    policy.train()

    assert len(policy.critic_ensemble.critics) == num_critics

    batch = create_train_batch_with_visual_input(
        batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
    )

    optimizers = make_optimizers(policy)

    # One critic training step must still produce a scalar loss.
    critic_loss = policy.forward(batch, model="critic")["loss_critic"]
    assert critic_loss.item() is not None
    assert critic_loss.shape == ()
    critic_loss.backward()
    optimizers["critic"].step()
|
||||
|
||||
|
||||
def test_sac_policy_save_and_load(tmp_path):
    """Saving then loading a policy must preserve weights, losses and selected actions."""
    root = tmp_path / "test_sac_save_and_load"

    state_dim = 10
    action_dim = 10
    batch_size = 2

    config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
    policy = SACPolicy(config=config)
    policy.eval()
    policy.save_pretrained(root)
    loaded_policy = SACPolicy.from_pretrained(root, config=config)
    loaded_policy.eval()

    # Use the local dims instead of hard-coded literals for consistency.
    batch = create_default_train_batch(batch_size=1, state_dim=state_dim, action_dim=action_dim)

    with torch.no_grad():
        with seeded_context(12):
            # Collect policy values before saving
            critic_loss = policy.forward(batch, model="critic")["loss_critic"]
            actor_loss = policy.forward(batch, model="actor")["loss_actor"]
            temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]

            observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
            actions = policy.select_action(observation_batch)

        with seeded_context(12):
            # Collect policy values after loading (same seed -> same randomness)
            loaded_critic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"]
            loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"]
            loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"]

            loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
            loaded_actions = loaded_policy.select_action(loaded_observation_batch)

    # state_dict() rebuilds its dict on every call, so snapshot each one once.
    state_dict = policy.state_dict()
    loaded_state_dict = loaded_policy.state_dict()
    assert state_dict.keys() == loaded_state_dict.keys()
    for k in state_dict:
        assert torch.allclose(state_dict[k], loaded_state_dict[k], atol=1e-6)

    # Compare values before and after saving and loading; they should be the same.
    assert torch.allclose(critic_loss, loaded_critic_loss)
    assert torch.allclose(actor_loss, loaded_actor_loss)
    assert torch.allclose(temperature_loss, loaded_temperature_loss)
    assert torch.allclose(actions, loaded_actions)
|
||||
208
tests/rl/test_actor.py
Normal file
208
tests/rl/test_actor.py
Normal file
@@ -0,0 +1,208 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from concurrent import futures
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.common.utils.transition import Transition
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
def create_learner_service_stub():
    """Start a mock LearnerService gRPC server and return (stub, servicer, channel, server).

    The caller is responsible for tearing everything down with `close_service_stub`.
    """
    import grpc

    from lerobot.common.transport import services_pb2, services_pb2_grpc

    class MockLearnerService(services_pb2_grpc.LearnerServiceServicer):
        def __init__(self):
            self.ready_call_count = 0  # how many times Ready() was invoked
            self.should_fail = False  # flip to True to simulate an unavailable learner

        def Ready(self, request, context):  # noqa: N802
            self.ready_call_count += 1
            if self.should_fail:
                context.set_code(grpc.StatusCode.UNAVAILABLE)
                context.set_details("Service unavailable")
                raise grpc.RpcError("Service unavailable")
            return services_pb2.Empty()

    servicer = MockLearnerService()

    # Create a gRPC server and add our servicer to it.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server)
    port = server.add_insecure_port("[::]:0")  # bind to a free port chosen by OS
    server.start()  # non-blocking: handlers run on the thread pool

    # Create a client channel and stub connected to the server's port.
    channel = grpc.insecure_channel(f"localhost:{port}")
    return services_pb2_grpc.LearnerServiceStub(channel), servicer, channel, server
|
||||
|
||||
|
||||
def close_service_stub(channel, server):
    """Tear down a client channel and gRPC server made by `create_learner_service_stub`."""
    channel.close()
    server.stop(grace=None)  # None: terminate immediately, no grace period
|
||||
|
||||
|
||||
@require_package("grpc")
def test_establish_learner_connection_success():
    """A reachable learner service is detected within the allowed attempts."""
    # NOTE: the original placed this summary string after the import, where it was a
    # no-op expression rather than a docstring; it is now the function docstring.
    from lerobot.scripts.rl.actor import establish_learner_connection

    stub, _servicer, channel, server = create_learner_service_stub()

    shutdown_event = Event()

    # Test successful connection
    result = establish_learner_connection(stub, shutdown_event, attempts=5)

    assert result is True

    close_service_stub(channel, server)
|
||||
|
||||
|
||||
@require_package("grpc")
def test_establish_learner_connection_failure():
    """When the learner keeps failing, connection establishment gives up and returns False."""
    # Moved the summary string to the top so it is a real docstring (it was a no-op
    # expression after the import in the original).
    from lerobot.scripts.rl.actor import establish_learner_connection

    stub, servicer, channel, server = create_learner_service_stub()
    servicer.should_fail = True

    shutdown_event = Event()

    # Test failed connection
    with patch("time.sleep"):  # Speed up the test
        result = establish_learner_connection(stub, shutdown_event, attempts=2)

    assert result is False

    close_service_stub(channel, server)
|
||||
|
||||
|
||||
@require_package("grpc")
def test_push_transitions_to_transport_queue():
    """Transitions pushed to the transport queue must round-trip through serialization."""
    # (Docstring moved to the top; it previously sat after the imports as a no-op string.)
    from lerobot.common.transport.utils import bytes_to_transitions
    from lerobot.scripts.rl.actor import push_transitions_to_transport_queue
    from tests.transport.test_transport_utils import assert_transitions_equal

    # Create mock transitions
    transitions = []
    for i in range(3):
        transition = Transition(
            state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
            action=torch.randn(5),
            reward=torch.tensor(1.0 + i),
            done=torch.tensor(False),
            truncated=torch.tensor(False),
            next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
            complementary_info={"step": torch.tensor(i)},
        )
        transitions.append(transition)

    transitions_queue = Queue()

    # Test pushing transitions
    push_transitions_to_transport_queue(transitions, transitions_queue)

    # Verify the serialized payload can be retrieved and decoded back
    serialized_data = transitions_queue.get()
    assert isinstance(serialized_data, bytes)
    deserialized_transitions = bytes_to_transitions(serialized_data)
    assert len(deserialized_transitions) == len(transitions)
    for i, deserialized_transition in enumerate(deserialized_transitions):
        assert_transitions_equal(deserialized_transition, transitions[i])
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_transitions_stream():
    """The transitions stream must yield queued payloads in FIFO order."""
    # (Docstring moved to the top; it previously sat after the import as a no-op string.)
    from lerobot.scripts.rl.actor import transitions_stream

    shutdown_event = Event()
    transitions_queue = Queue()

    # Add test data to queue
    test_data = [b"transition_data_1", b"transition_data_2", b"transition_data_3"]
    for data in test_data:
        transitions_queue.put(data)

    # Collect streamed data
    streamed_data = []
    stream_generator = transitions_stream(shutdown_event, transitions_queue, 0.1)

    # Consume exactly as many messages as were queued, then stop the stream.
    for i, message in enumerate(stream_generator):
        streamed_data.append(message)
        if i >= len(test_data) - 1:
            shutdown_event.set()
            break

    # Verify we got all messages, in order.
    assert len(streamed_data) == len(test_data)
    assert [message.data for message in streamed_data] == test_data
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_interactions_stream():
    """Interaction messages must stream through the queue and deserialize unchanged."""
    # (Docstring moved to the top; it previously sat after the imports as a no-op string.)
    from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes
    from lerobot.scripts.rl.actor import interactions_stream

    shutdown_event = Event()
    interactions_queue = Queue()

    # Create test interaction data (similar structure to what would be sent)
    test_interactions = [
        {"episode_reward": 10.5, "step": 1, "policy_fps": 30.2},
        {"episode_reward": 15.2, "step": 2, "policy_fps": 28.7},
        {"episode_reward": 8.7, "step": 3, "policy_fps": 29.1},
    ]

    # Serialize and enqueue the interactions as the actor would.
    # (Plain loop: the original used a list comprehension purely for its side effect.)
    for interaction in test_interactions:
        interactions_queue.put(python_object_to_bytes(interaction))

    # Collect streamed data
    streamed_data = []
    stream_generator = interactions_stream(shutdown_event, interactions_queue, 0.1)

    # Process the items
    for i, message in enumerate(stream_generator):
        streamed_data.append(message)
        if i >= len(test_interactions) - 1:
            shutdown_event.set()
            break

    # Verify we got all the messages
    assert len(streamed_data) == len(test_interactions)

    # Verify the messages can be deserialized back to original data
    for i, message in enumerate(streamed_data):
        deserialized_interaction = bytes_to_python_object(message.data)
        assert deserialized_interaction == test_interactions[i]
|
||||
297
tests/rl/test_actor_learner.py
Normal file
297
tests/rl/test_actor_learner.py
Normal file
@@ -0,0 +1,297 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.utils.transition import Transition
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
def create_test_transitions(count: int = 3) -> list[Transition]:
    """Build `count` synthetic transitions; only the final one is marked done."""
    return [
        Transition(
            state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
            action=torch.randn(5),
            reward=torch.tensor(1.0 + step),
            done=torch.tensor(step == count - 1),  # last transition terminates the episode
            truncated=torch.tensor(False),
            next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
            complementary_info={"step": torch.tensor(step), "episode_id": step // 2},
        )
        for step in range(count)
    ]
|
||||
|
||||
|
||||
def create_test_interactions(count: int = 3) -> list[dict]:
    """Build `count` synthetic interaction-metrics dictionaries for integration tests."""
    return [
        {
            "episode_reward": 10.0 + idx * 5,
            "step": idx * 100,
            "policy_fps": 30.0 + idx,
            "intervention_rate": 0.1 * idx,
            "episode_length": 200 + idx * 50,
        }
        for idx in range(count)
    ]
|
||||
|
||||
|
||||
def find_free_port():
    """Ask the OS for a currently unused TCP port and return its number."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))  # port 0 lets the OS pick a free port
        sock.listen(1)
        return sock.getsockname()[1]
|
||||
|
||||
|
||||
@pytest.fixture
def cfg():
    """Training pipeline config wired for local actor/learner tests on a free port."""
    train_cfg = TrainRLServerPipelineConfig()

    learner_port = find_free_port()

    policy_cfg = SACConfig()
    policy_cfg.actor_learner_config.learner_host = "127.0.0.1"
    policy_cfg.actor_learner_config.learner_port = learner_port
    # Threads (not processes) keep the end-to-end tests lightweight.
    policy_cfg.concurrency.actor = "threads"
    policy_cfg.concurrency.learner = "threads"
    policy_cfg.actor_learner_config.queue_get_timeout = 0.1

    train_cfg.policy = policy_cfg

    return train_cfg
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(10)  # force cross-platform watchdog
def test_end_to_end_transitions_flow(cfg):
    """Test complete transitions flow from actor to learner."""
    # (Docstring moved to the top; it previously sat after the imports as a no-op string.)
    from lerobot.common.transport.utils import bytes_to_transitions
    from lerobot.scripts.rl.actor import (
        establish_learner_connection,
        learner_service_client,
        push_transitions_to_transport_queue,
        send_transitions,
    )
    from lerobot.scripts.rl.learner import start_learner
    from tests.transport.test_transport_utils import assert_transitions_equal

    transitions_actor_queue = Queue()
    transitions_learner_queue = Queue()

    interactions_queue = Queue()
    parameters_queue = Queue()
    shutdown_event = Event()

    # Run the learner end of the pipe in a background thread.
    learner_thread = threading.Thread(
        target=start_learner,
        args=(parameters_queue, transitions_learner_queue, interactions_queue, shutdown_event, cfg),
    )
    learner_thread.start()

    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )

    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    # Actor-side sender thread streams serialized transitions to the learner.
    send_transitions_thread = threading.Thread(
        target=send_transitions, args=(cfg, transitions_actor_queue, shutdown_event, learner_client, channel)
    )
    send_transitions_thread.start()

    input_transitions = create_test_transitions(count=5)

    push_transitions_to_transport_queue(input_transitions, transitions_actor_queue)

    # Wait for the transfer to happen before signalling shutdown.
    time.sleep(0.1)

    shutdown_event.set()

    # Wait for both ends to drain and stop cleanly.
    learner_thread.join()
    send_transitions_thread.join()
    channel.close()

    received_transitions = []
    while not transitions_learner_queue.empty():
        received_transitions.extend(bytes_to_transitions(transitions_learner_queue.get()))

    assert len(received_transitions) == len(input_transitions)
    for i, transition in enumerate(received_transitions):
        assert_transitions_equal(transition, input_transitions[i])
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(10)
def test_end_to_end_interactions_flow(cfg):
    """Test complete interactions flow from actor to learner."""
    # (Docstring moved to the top; it previously sat after the imports as a no-op string.)
    from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes
    from lerobot.scripts.rl.actor import (
        establish_learner_connection,
        learner_service_client,
        send_interactions,
    )
    from lerobot.scripts.rl.learner import start_learner

    # Queues for actor-learner communication
    interactions_actor_queue = Queue()
    interactions_learner_queue = Queue()

    # Other queues required by the learner
    parameters_queue = Queue()
    transitions_learner_queue = Queue()

    shutdown_event = Event()

    # Start the learner in a separate thread
    learner_thread = threading.Thread(
        target=start_learner,
        args=(parameters_queue, transitions_learner_queue, interactions_learner_queue, shutdown_event, cfg),
    )
    learner_thread.start()

    # Establish connection from actor to learner
    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )

    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    # Start the actor's interaction sending process in a separate thread
    send_interactions_thread = threading.Thread(
        target=send_interactions,
        args=(cfg, interactions_actor_queue, shutdown_event, learner_client, channel),
    )
    send_interactions_thread.start()

    # Create and push test interactions to the actor's queue
    input_interactions = create_test_interactions(count=5)
    for interaction in input_interactions:
        interactions_actor_queue.put(python_object_to_bytes(interaction))

    # Wait for the communication to happen
    time.sleep(0.1)

    # Signal shutdown and wait for threads to complete
    shutdown_event.set()
    learner_thread.join()
    send_interactions_thread.join()
    channel.close()

    # Verify that the learner received the interactions
    received_interactions = []
    while not interactions_learner_queue.empty():
        received_interactions.append(bytes_to_python_object(interactions_learner_queue.get()))

    assert len(received_interactions) == len(input_interactions)

    # Sort by a unique key to handle potential reordering in queues
    received_interactions.sort(key=lambda x: x["step"])
    input_interactions.sort(key=lambda x: x["step"])

    for received, expected in zip(received_interactions, input_interactions, strict=False):
        assert received == expected
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.parametrize("data_size", ["small", "large"])
@pytest.mark.timeout(10)
def test_end_to_end_parameters_flow(cfg, data_size):
    """Test complete parameter flow from learner to actor, with small and large data."""
    # (Docstring moved to the top; it previously sat after the imports as a no-op string.)
    from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes
    from lerobot.scripts.rl.actor import establish_learner_connection, learner_service_client, receive_policy
    from lerobot.scripts.rl.learner import start_learner

    # Actor's local queue to receive params
    parameters_actor_queue = Queue()
    # Learner's queue to send params from
    parameters_learner_queue = Queue()

    # Other queues required by the learner
    transitions_learner_queue = Queue()
    interactions_learner_queue = Queue()

    shutdown_event = Event()

    # Start the learner in a separate thread
    learner_thread = threading.Thread(
        target=start_learner,
        args=(
            parameters_learner_queue,
            transitions_learner_queue,
            interactions_learner_queue,
            shutdown_event,
            cfg,
        ),
    )
    learner_thread.start()

    # Establish connection from actor to learner
    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )

    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    # Start the actor's parameter receiving process in a separate thread
    receive_params_thread = threading.Thread(
        target=receive_policy,
        args=(cfg, parameters_actor_queue, shutdown_event, learner_client, channel),
    )
    receive_params_thread.start()

    # Create test parameters based on parametrization
    if data_size == "small":
        input_params = {"layer.weight": torch.randn(128, 64)}
    else:  # "large"
        # CHUNK_SIZE is 2MB, so this tensor (4MB) will force chunking
        input_params = {"large_layer.weight": torch.randn(1024, 1024)}

    # Simulate learner having new parameters to send
    parameters_learner_queue.put(state_to_bytes(input_params))

    # Wait for the actor to receive the parameters
    time.sleep(0.1)

    # Signal shutdown and wait for threads to complete
    shutdown_event.set()
    learner_thread.join()
    receive_params_thread.join()
    channel.close()

    # Verify that the actor received the parameters correctly
    received_params = bytes_to_state_dict(parameters_actor_queue.get())

    assert received_params.keys() == input_params.keys()
    for key in input_params:
        assert torch.allclose(received_params[key], input_params[key])
|
||||
374
tests/rl/test_learner_service.py
Normal file
374
tests/rl/test_learner_service.py
Normal file
@@ -0,0 +1,374 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import threading
|
||||
import time
|
||||
from concurrent import futures
|
||||
from multiprocessing import Event, Queue
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import require_package # our gRPC servicer class
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def learner_service_stub():
    """Yield a LearnerService stub backed by a live in-process gRPC server."""
    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 1

    client, channel, server = create_learner_service_stub(
        shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
    )

    # Hand the stub to the test body, then tear the server down afterwards.
    yield client

    close_learner_service_stub(channel, server)
|
||||
|
||||
|
||||
@require_package("grpc")
def create_learner_service_stub(
    shutdown_event: Event,
    parameters_queue: Queue,
    transitions_queue: Queue,
    interactions_queue: Queue,
    seconds_between_pushes: int,
    queue_get_timeout: float = 0.1,
):
    """Start a LearnerService gRPC server and return a connected stub.

    Returns a ``(stub, channel, server)`` triple. Callers are responsible for
    releasing the resources via ``close_learner_service_stub(channel, server)``.

    Note: the former in-body string ("Fixture to start ...") was a dead
    statement placed after the imports (so it never became a docstring) and
    wrongly described this helper as a fixture; it is now a real docstring.
    """
    import grpc

    from lerobot.common.transport import services_pb2_grpc  # generated from .proto
    from lerobot.scripts.rl.learner_service import LearnerService

    servicer = LearnerService(
        shutdown_event=shutdown_event,
        parameters_queue=parameters_queue,
        seconds_between_pushes=seconds_between_pushes,
        transition_queue=transitions_queue,
        interaction_message_queue=interactions_queue,
        queue_get_timeout=queue_get_timeout,
    )

    # Create a gRPC server and add our servicer to it.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server)
    port = server.add_insecure_port("[::]:0")  # bind to a free port chosen by the OS
    server.start()  # non-blocking call

    # Create a client channel and stub connected to the server's port.
    channel = grpc.insecure_channel(f"localhost:{port}")
    return services_pb2_grpc.LearnerServiceStub(channel), channel, server
|
||||
|
||||
|
||||
@require_package("grpc")
def close_learner_service_stub(channel, server):
    """Release the channel and server created by create_learner_service_stub.

    The client channel is closed first, then the server is stopped;
    grace=None stops the server without a grace period (per the gRPC API).
    """
    channel.close()
    server.stop(None)
|
||||
|
||||
|
||||
@require_package("grpc")  # added for consistency: every other gRPC test here carries it
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_ready_method(learner_service_stub):
    """Ready should answer the health-check handshake with an Empty message."""
    from lerobot.common.transport import services_pb2

    request = services_pb2.Empty()
    response = learner_service_stub.Ready(request)
    assert response == services_pb2.Empty()
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_send_interactions():
    """SendInteractions must reassemble chunked interaction messages.

    Chunks from TRANSFER_BEGIN through TRANSFER_END form one logical message;
    a standalone TRANSFER_END is a complete single-chunk message.

    Fixes: renamed the typo'd local generator (`mock_intercations_stream`) and
    dropped its dead `return services_pb2.Empty()` — a `return` value inside a
    generator only sets StopIteration.value and is ignored by the gRPC client.
    """
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()

    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 1
    client, channel, server = create_learner_service_stub(
        shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
    )

    list_of_interaction_messages = [
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"1"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"2"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"3"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"4"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"5"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"6"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"7"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"8"),
    ]

    def mock_interactions_stream():
        yield from list_of_interaction_messages

    response = client.SendInteractions(mock_interactions_stream())
    assert response == services_pb2.Empty()

    close_learner_service_stub(channel, server)

    # Extract the data from the interactions queue
    interactions = []
    while not interactions_queue.empty():
        interactions.append(interactions_queue.get())

    # Chunked pieces are concatenated; standalone END chunks pass through as-is.
    assert interactions == [b"123", b"4", b"5", b"678"]
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_send_transitions():
    """Test the SendTransitions method with various transition data."""
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 1

    client, channel, server = create_learner_service_stub(
        shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
    )

    # Two logical messages: one split into three chunks, one into two.
    begin = services_pb2.TransferState.TRANSFER_BEGIN
    middle = services_pb2.TransferState.TRANSFER_MIDDLE
    end = services_pb2.TransferState.TRANSFER_END
    chunk_specs = [
        (begin, b"transition_1"),
        (middle, b"transition_2"),
        (end, b"transition_3"),
        (begin, b"batch_1"),
        (end, b"batch_2"),
    ]

    def mock_transitions_stream():
        for state, payload in chunk_specs:
            yield services_pb2.Transition(transfer_state=state, data=payload)

    response = client.SendTransitions(mock_transitions_stream())
    assert response == services_pb2.Empty()

    close_learner_service_stub(channel, server)

    # Drain everything the servicer enqueued.
    transitions = []
    while not transitions_queue.empty():
        transitions.append(transitions_queue.get())

    # Should have assembled the chunked data
    assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"]
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_send_transitions_empty_stream():
    """Test SendTransitions with empty stream."""
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 1

    client, channel, server = create_learner_service_stub(
        shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
    )

    # An already-exhausted iterator plays the role of an empty request stream.
    response = client.SendTransitions(iter([]))
    assert response == services_pb2.Empty()

    close_learner_service_stub(channel, server)

    # Nothing was streamed, so nothing should have been enqueued.
    assert transitions_queue.empty()
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(10)  # force cross-platform watchdog
def test_stream_parameters():
    # StreamParameters should push queued parameter blobs to the client,
    # pacing pushes by `seconds_between_pushes`.
    import time

    from lerobot.common.transport import services_pb2

    """Test the StreamParameters method."""
    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 0.2  # Short delay for testing

    client, channel, server = create_learner_service_stub(
        shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
    )

    # Add test parameters to the queue
    test_params = [b"param_batch_1", b"param_batch_2"]
    for param in test_params:
        parameters_queue.put(param)

    # Start streaming parameters
    request = services_pb2.Empty()
    stream = client.StreamParameters(request)

    # Collect streamed parameters and timestamps
    received_params = []
    timestamps = []

    for response in stream:
        received_params.append(response.data)
        timestamps.append(time.time())

        # We should receive one last item
        break

    parameters_queue.put(b"param_batch_3")

    for response in stream:
        received_params.append(response.data)
        timestamps.append(time.time())

        # We should receive only one item
        break

    shutdown_event.set()
    close_learner_service_stub(channel, server)

    # NOTE(review): the first item observed is param_batch_2, not
    # param_batch_1 — this implies the service streams only the most recent
    # queued batch; confirm against LearnerService.StreamParameters.
    assert received_params == [b"param_batch_2", b"param_batch_3"]

    # Check the time difference between the two sends
    time_diff = timestamps[1] - timestamps[0]
    # Check if the time difference is close to the expected push frequency
    assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1)
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_stream_parameters_with_shutdown():
    from lerobot.common.transport import services_pb2

    """Test StreamParameters handles shutdown gracefully."""
    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 0.1
    queue_get_timeout = 0.001  # short get timeout so shutdown is noticed quickly

    client, channel, server = create_learner_service_stub(
        shutdown_event,
        parameters_queue,
        transitions_queue,
        interactions_queue,
        seconds_between_pushes,
        queue_get_timeout=queue_get_timeout,
    )

    # The b"stop" sentinel triggers shutdown mid-stream; the two later batches
    # must never be delivered.
    test_params = [b"param_batch_1", b"stop", b"param_batch_3", b"param_batch_4"]

    # create a thread that will put the parameters in the queue
    def producer():
        for param in test_params:
            parameters_queue.put(param)
            time.sleep(0.1)

    producer_thread = threading.Thread(target=producer)
    producer_thread.start()

    # Start streaming
    request = services_pb2.Empty()
    stream = client.StreamParameters(request)

    # Collect streamed parameters
    received_params = []

    # Note: no `break` — the loop ends only when the server closes the stream,
    # which is exactly what setting `shutdown_event` should cause.
    for response in stream:
        received_params.append(response.data)

        if response.data == b"stop":
            shutdown_event.set()

    producer_thread.join()
    close_learner_service_stub(channel, server)

    assert received_params == [b"param_batch_1", b"stop"]
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_stream_parameters_waits_and_retries_on_empty_queue():
    import threading
    import time

    from lerobot.common.transport import services_pb2

    """Test that StreamParameters waits and retries when the queue is empty."""
    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 0.05
    queue_get_timeout = 0.01

    client, channel, server = create_learner_service_stub(
        shutdown_event,
        parameters_queue,
        transitions_queue,
        interactions_queue,
        seconds_between_pushes,
        queue_get_timeout=queue_get_timeout,
    )

    # Start the stream while the queue is still empty — the server must retry,
    # not terminate.
    request = services_pb2.Empty()
    stream = client.StreamParameters(request)

    received_params = []

    def producer():
        # Let the consumer start and find an empty queue.
        # It will wait `seconds_between_pushes` (0.05s), then `get` will timeout after `queue_get_timeout` (0.01s).
        # Total time for the first empty loop is > 0.06s. We wait a bit longer to be safe.
        time.sleep(0.06)
        parameters_queue.put(b"param_after_wait")
        time.sleep(0.05)
        parameters_queue.put(b"param_after_wait_2")

    producer_thread = threading.Thread(target=producer)
    producer_thread.start()

    # The consumer will block here until the producer sends an item.
    for response in stream:
        received_params.append(response.data)
        if response.data == b"param_after_wait_2":
            break  # We only need one item for this test.

    shutdown_event.set()
    producer_thread.join()
    close_learner_service_stub(channel, server)

    # Both items arrive in order — proof the server retried after empty gets.
    assert received_params == [b"param_after_wait", b"param_after_wait_2"]
|
||||
571
tests/transport/test_transport_utils.py
Normal file
571
tests/transport/test_transport_utils.py
Normal file
@@ -0,0 +1,571 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
from multiprocessing import Event, Queue
|
||||
from pickle import UnpicklingError
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.utils.transition import Transition
|
||||
from tests.utils import require_cuda, require_package
|
||||
|
||||
|
||||
@require_package("grpc")
def test_bytes_buffer_size_empty_buffer():
    """Test with an empty buffer."""
    from lerobot.common.transport.utils import bytes_buffer_size

    buf = io.BytesIO()
    assert bytes_buffer_size(buf) == 0
    # The helper must leave the read position at the start of the buffer.
    assert buf.tell() == 0
|
||||
|
||||
|
||||
@require_package("grpc")
def test_bytes_buffer_size_small_buffer():
    """Test with a small buffer."""
    from lerobot.common.transport.utils import bytes_buffer_size

    payload = b"Hello, World!"
    buf = io.BytesIO(payload)
    assert bytes_buffer_size(buf) == 13  # len(b"Hello, World!")
    # Position must be rewound to the beginning afterwards.
    assert buf.tell() == 0
|
||||
|
||||
|
||||
@require_package("grpc")
def test_bytes_buffer_size_large_buffer():
    """Size must be reported correctly for a buffer spanning multiple chunks."""
    from lerobot.common.transport.utils import CHUNK_SIZE, bytes_buffer_size

    payload = b"x" * (CHUNK_SIZE * 2 + 1000)
    buf = io.BytesIO(payload)
    assert bytes_buffer_size(buf) == len(payload)
    assert buf.tell() == 0
|
||||
|
||||
|
||||
@require_package("grpc")
def test_send_bytes_in_chunks_empty_data():
    """Sending zero bytes should produce no chunks at all."""
    from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2

    produced = list(send_bytes_in_chunks(b"", services_pb2.InteractionMessage))
    assert produced == []
|
||||
|
||||
|
||||
@require_package("grpc")
def test_single_chunk_small_data():
    """Data smaller than CHUNK_SIZE travels as a single TRANSFER_END chunk."""
    from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2

    produced = list(send_bytes_in_chunks(b"Some data", services_pb2.InteractionMessage))

    assert len(produced) == 1
    (only,) = produced
    assert only.data == b"Some data"
    assert only.transfer_state == services_pb2.TransferState.TRANSFER_END
|
||||
|
||||
|
||||
@require_package("grpc")
def test_not_silent_mode():
    """silent=False (logging enabled) must not change the chunking result."""
    from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2

    produced = list(send_bytes_in_chunks(b"Some data", services_pb2.InteractionMessage, silent=False))

    assert len(produced) == 1
    assert produced[0].data == b"Some data"
|
||||
|
||||
|
||||
@require_package("grpc")
def test_send_bytes_in_chunks_large_data():
    """2*CHUNK_SIZE + 1000 bytes must split into BEGIN / MIDDLE / END chunks."""
    from lerobot.common.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2

    payload = b"x" * (CHUNK_SIZE * 2 + 1000)
    produced = list(send_bytes_in_chunks(payload, services_pb2.InteractionMessage))

    states = services_pb2.TransferState
    expected = [
        (b"x" * CHUNK_SIZE, states.TRANSFER_BEGIN),
        (b"x" * CHUNK_SIZE, states.TRANSFER_MIDDLE),
        (b"x" * 1000, states.TRANSFER_END),
    ]
    assert [(chunk.data, chunk.transfer_state) for chunk in produced] == expected
|
||||
|
||||
|
||||
@require_package("grpc")
def test_send_bytes_in_chunks_large_data_with_exact_chunk_size():
    """Exactly CHUNK_SIZE bytes still fit into one TRANSFER_END chunk."""
    from lerobot.common.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2

    payload = b"x" * CHUNK_SIZE
    produced = list(send_bytes_in_chunks(payload, services_pb2.InteractionMessage))

    assert len(produced) == 1
    assert produced[0].data == payload
    assert produced[0].transfer_state == services_pb2.TransferState.TRANSFER_END
|
||||
|
||||
|
||||
@require_package("grpc")
def test_receive_bytes_in_chunks_empty_data():
    """An exhausted chunk iterator should leave the output queue untouched."""
    from lerobot.common.transport.utils import receive_bytes_in_chunks

    out_queue = Queue()
    shutdown = Event()

    receive_bytes_in_chunks(iter([]), out_queue, shutdown)

    assert out_queue.empty()
|
||||
|
||||
|
||||
@require_package("grpc")
def test_receive_bytes_in_chunks_single_chunk():
    """A lone TRANSFER_END chunk is delivered as one complete message."""
    from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2

    out_queue = Queue()
    shutdown = Event()

    payload = b"Single chunk data"
    message = services_pb2.InteractionMessage(
        data=payload, transfer_state=services_pb2.TransferState.TRANSFER_END
    )

    receive_bytes_in_chunks(iter([message]), out_queue, shutdown)

    assert out_queue.get(timeout=0.01) == payload
    assert out_queue.empty()
|
||||
|
||||
|
||||
@require_package("grpc")
def test_receive_bytes_in_chunks_single_not_end_chunk():
    """A chunk without a terminating TRANSFER_END never reaches the queue."""
    from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2

    out_queue = Queue()
    shutdown = Event()

    message = services_pb2.InteractionMessage(
        data=b"Single chunk data", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE
    )

    receive_bytes_in_chunks(iter([message]), out_queue, shutdown)

    # Without an END marker the partial message is never flushed.
    assert out_queue.empty()
|
||||
|
||||
|
||||
@require_package("grpc")
def test_receive_bytes_in_chunks_multiple_chunks():
    from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2

    """Test receiving a multi-chunk message."""
    queue = Queue()
    shutdown_event = Event()

    # One logical message split across BEGIN / MIDDLE / END chunks.
    chunks = [
        services_pb2.InteractionMessage(
            data=b"First ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN
        ),
        services_pb2.InteractionMessage(
            data=b"Middle ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE
        ),
        services_pb2.InteractionMessage(data=b"Last", transfer_state=services_pb2.TransferState.TRANSFER_END),
    ]

    receive_bytes_in_chunks(iter(chunks), queue, shutdown_event)

    # Chunk payloads are concatenated into a single queued message.
    assert queue.get(timeout=0.01) == b"First Middle Last"
    assert queue.empty()
|
||||
|
||||
|
||||
@require_package("grpc")
def test_receive_bytes_in_chunks_multiple_messages():
    from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2

    """Test receiving multiple complete messages in sequence."""
    queue = Queue()
    shutdown_event = Event()

    # Each TRANSFER_END closes one logical message; three messages total here.
    chunks = [
        # First message - single chunk
        services_pb2.InteractionMessage(
            data=b"Message1", transfer_state=services_pb2.TransferState.TRANSFER_END
        ),
        # Second message - multi chunk
        services_pb2.InteractionMessage(
            data=b"Start2 ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN
        ),
        services_pb2.InteractionMessage(
            data=b"Middle2 ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE
        ),
        services_pb2.InteractionMessage(data=b"End2", transfer_state=services_pb2.TransferState.TRANSFER_END),
        # Third message - single chunk
        services_pb2.InteractionMessage(
            data=b"Message3", transfer_state=services_pb2.TransferState.TRANSFER_END
        ),
    ]

    receive_bytes_in_chunks(iter(chunks), queue, shutdown_event)

    # Should have three messages in queue, emitted in stream order.
    assert queue.get(timeout=0.01) == b"Message1"
    assert queue.get(timeout=0.01) == b"Start2 Middle2 End2"
    assert queue.get(timeout=0.01) == b"Message3"
    assert queue.empty()
|
||||
|
||||
|
||||
@require_package("grpc")
def test_receive_bytes_in_chunks_shutdown_during_receive():
    from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2

    """Test that shutdown event stops receiving mid-stream."""
    queue = Queue()
    shutdown_event = Event()
    # Shutdown is requested BEFORE any chunk is consumed, so the receiver
    # should bail out without enqueuing anything.
    shutdown_event.set()

    chunks = [
        services_pb2.InteractionMessage(
            data=b"First ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN
        ),
        services_pb2.InteractionMessage(
            data=b"Middle ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE
        ),
        services_pb2.InteractionMessage(data=b"Last", transfer_state=services_pb2.TransferState.TRANSFER_END),
    ]

    receive_bytes_in_chunks(iter(chunks), queue, shutdown_event)

    assert queue.empty()
|
||||
|
||||
|
||||
@require_package("grpc")
def test_receive_bytes_in_chunks_only_begin_chunk():
    """A BEGIN chunk with no matching END must be discarded, not queued."""
    from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2

    out_queue = Queue()
    shutdown = Event()

    incomplete = [
        services_pb2.InteractionMessage(
            data=b"Start", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN
        ),
        # deliberately no END chunk
    ]

    receive_bytes_in_chunks(iter(incomplete), out_queue, shutdown)

    assert out_queue.empty()
|
||||
|
||||
|
||||
@require_package("grpc")
def test_receive_bytes_in_chunks_missing_begin():
    from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2

    """Test receiving chunks starting with MIDDLE instead of BEGIN."""
    queue = Queue()
    shutdown_event = Event()

    chunks = [
        # Missing BEGIN
        services_pb2.InteractionMessage(
            data=b"Middle", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE
        ),
        services_pb2.InteractionMessage(data=b"End", transfer_state=services_pb2.TransferState.TRANSFER_END),
    ]

    receive_bytes_in_chunks(iter(chunks), queue, shutdown_event)

    # The implementation continues from where it is, so we should get partial data
    # (no error is raised for a malformed start-of-message).
    assert queue.get(timeout=0.01) == b"MiddleEnd"
    assert queue.empty()
|
||||
|
||||
|
||||
# Tests for state_to_bytes and bytes_to_state_dict
|
||||
@require_package("grpc")
def test_state_to_bytes_empty_dict():
    """An empty state dict must round-trip to an empty dict."""
    from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes

    round_tripped = bytes_to_state_dict(state_to_bytes({}))
    assert round_tripped == {}
|
||||
|
||||
|
||||
@require_package("grpc")
def test_bytes_to_state_dict_empty_data():
    """Deserializing zero bytes should fail with EOFError."""
    from lerobot.common.transport.utils import bytes_to_state_dict

    empty_payload = b""
    with pytest.raises(EOFError):
        bytes_to_state_dict(empty_payload)
|
||||
|
||||
|
||||
@require_package("grpc")
def test_state_to_bytes_simple_dict():
    """A small model-like state dict must survive a serialize/deserialize round trip."""
    from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes

    state_dict = {
        "layer1.weight": torch.randn(10, 5),
        "layer1.bias": torch.randn(10),
        "layer2.weight": torch.randn(1, 10),
        "layer2.bias": torch.randn(1),
    }

    payload = state_to_bytes(state_dict)
    assert isinstance(payload, bytes)
    assert len(payload) > 0

    reconstructed = bytes_to_state_dict(payload)

    # Same keys, numerically identical tensors.
    assert len(reconstructed) == len(state_dict)
    for key, tensor in state_dict.items():
        assert key in reconstructed
        assert torch.allclose(tensor, reconstructed[key])
|
||||
|
||||
|
||||
@require_package("grpc")
def test_state_to_bytes_various_dtypes():
    """Round-tripping must preserve dtype and values for common tensor dtypes."""
    from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes

    state_dict = {
        "float32": torch.randn(5, 5),
        "float64": torch.randn(3, 3).double(),
        "int32": torch.randint(0, 100, (4, 4), dtype=torch.int32),
        "int64": torch.randint(0, 100, (2, 2), dtype=torch.int64),
        "bool": torch.tensor([True, False, True]),
        "uint8": torch.randint(0, 255, (3, 3), dtype=torch.uint8),
    }

    reconstructed = bytes_to_state_dict(state_to_bytes(state_dict))

    for key, original in state_dict.items():
        restored = reconstructed[key]
        assert restored.dtype == original.dtype
        if original.dtype == torch.bool:
            # bool tensors need exact comparison; allclose is for numeric types
            assert torch.equal(original, restored)
        else:
            assert torch.allclose(original, restored)
|
||||
|
||||
|
||||
@require_package("grpc")
def test_bytes_to_state_dict_invalid_data():
    """Garbage bytes should raise UnpicklingError rather than return junk."""
    from lerobot.common.transport.utils import bytes_to_state_dict

    garbage = b"This is not a valid torch save file"
    with pytest.raises(UnpicklingError):
        bytes_to_state_dict(garbage)
|
||||
|
||||
|
||||
@require_cuda
@require_package("grpc")
def test_state_to_bytes_various_dtypes_cuda():
    from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes

    """Test converting state dict with various tensor dtypes."""
    # Mix of CUDA tensors (numeric dtypes) and CPU tensors (bool/uint8).
    state_dict = {
        "float32": torch.randn(5, 5).cuda(),
        "float64": torch.randn(3, 3).double().cuda(),
        "int32": torch.randint(0, 100, (4, 4), dtype=torch.int32).cuda(),
        "int64": torch.randint(0, 100, (2, 2), dtype=torch.int64).cuda(),
        "bool": torch.tensor([True, False, True]),
        "uint8": torch.randint(0, 255, (3, 3), dtype=torch.uint8),
    }

    data = state_to_bytes(state_dict)
    reconstructed = bytes_to_state_dict(data)

    # NOTE(review): the device of the reconstructed tensors is not asserted
    # here — only dtype and values; confirm intended device handling against
    # bytes_to_state_dict.
    for key in state_dict:
        assert reconstructed[key].dtype == state_dict[key].dtype
        if state_dict[key].dtype == torch.bool:
            assert torch.equal(state_dict[key], reconstructed[key])
        else:
            assert torch.allclose(state_dict[key], reconstructed[key])
|
||||
|
||||
|
||||
@require_package("grpc")
def test_python_object_to_bytes_none():
    """None must survive the serialize/deserialize round trip."""
    from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes

    assert bytes_to_python_object(python_object_to_bytes(None)) is None
|
||||
|
||||
|
||||
# Round-trip a representative sample of plain Python values.
@pytest.mark.parametrize(
    "obj",
    [
        42,
        -123,
        3.14159,
        -2.71828,
        "Hello, World!",
        "Unicode: 你好世界 🌍",
        True,
        False,
        b"byte string",
        [],
        [1, 2, 3],
        [1, "two", 3.0, True, None],
        {},
        {"key": "value", "number": 123, "nested": {"a": 1}},
        (),
        (1, 2, 3),
    ],
)
@require_package("grpc")
def test_python_object_to_bytes_simple_types(obj):
    from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes

    """Test converting simple Python types."""
    data = python_object_to_bytes(obj)
    reconstructed = bytes_to_python_object(data)
    assert reconstructed == obj
    # Type check guards against equal-but-different results (e.g. True == 1,
    # or a tuple coming back as a list).
    assert type(reconstructed) is type(obj)
|
||||
|
||||
|
||||
@require_package("grpc")
def test_python_object_to_bytes_with_tensors():
    from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes

    """Test converting objects containing PyTorch tensors."""
    # Tensors nested at several container depths must all survive the trip.
    obj = {
        "tensor": torch.randn(5, 5),
        "list_with_tensor": [1, 2, torch.randn(3, 3), "string"],
        "nested": {
            "tensor1": torch.randn(2, 2),
            "tensor2": torch.tensor([1, 2, 3]),
        },
    }

    data = python_object_to_bytes(obj)
    reconstructed = bytes_to_python_object(data)

    assert torch.allclose(obj["tensor"], reconstructed["tensor"])
    assert reconstructed["list_with_tensor"][0] == 1
    assert reconstructed["list_with_tensor"][3] == "string"
    assert torch.allclose(obj["list_with_tensor"][2], reconstructed["list_with_tensor"][2])
    assert torch.allclose(obj["nested"]["tensor1"], reconstructed["nested"]["tensor1"])
    # Integer tensor: exact comparison is appropriate here.
    assert torch.equal(obj["nested"]["tensor2"], reconstructed["nested"]["tensor2"])
|
||||
|
||||
|
||||
@require_package("grpc")
def test_transitions_to_bytes_empty_list():
    """An empty transition list must round-trip to an empty list."""
    from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes

    result = bytes_to_transitions(transitions_to_bytes([]))
    assert isinstance(result, list)
    assert result == []
|
||||
|
||||
|
||||
@require_package("grpc")
def test_transitions_to_bytes_single_transition():
    from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes

    """Test converting a single transition."""
    # Observation dicts carry both an image-like and a vector-like tensor.
    transition = Transition(
        state={"image": torch.randn(3, 64, 64), "state": torch.randn(10)},
        action=torch.randn(5),
        reward=torch.tensor(1.5),
        done=torch.tensor(False),
        next_state={"image": torch.randn(3, 64, 64), "state": torch.randn(10)},
    )

    transitions = [transition]
    data = transitions_to_bytes(transitions)
    reconstructed = bytes_to_transitions(data)

    assert len(reconstructed) == 1

    # Field-by-field equality check (see assert_transitions_equal below).
    assert_transitions_equal(transitions[0], reconstructed[0])
|
||||
|
||||
|
||||
def assert_transitions_equal(t1: Transition, t2: Transition):
    """Helper to assert two transitions are equal field-by-field.

    Not a test: the `@require_package("grpc")` decorator was removed because
    this helper never touches grpc (torch only) and skip-style decorators
    belong on test functions, not on plain helpers called inside them.
    """
    assert_observation_equal(t1["state"], t2["state"])
    assert torch.allclose(t1["action"], t2["action"])
    assert torch.allclose(t1["reward"], t2["reward"])
    # done is a bool tensor, so compare exactly rather than with allclose.
    assert torch.equal(t1["done"], t2["done"])
    assert_observation_equal(t1["next_state"], t2["next_state"])
|
||||
|
||||
|
||||
@require_package("grpc")
def assert_observation_equal(o1: dict, o2: dict):
    """Assert that two observation dicts share keys and hold close tensor values."""
    # Dict key views compare set-like, so this matches set(o1) == set(o2).
    assert o1.keys() == o2.keys()
    for key, value in o1.items():
        assert torch.allclose(value, o2[key])
|
||||
|
||||
|
||||
@require_package("grpc")
def test_transitions_to_bytes_multiple_transitions():
    """Several transitions must round-trip with order and contents preserved."""
    from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes

    # Five synthetic transitions; only the final one terminates the episode.
    originals = [
        Transition(
            state={"data": torch.randn(10)},
            action=torch.randn(3),
            reward=torch.tensor(float(step)),
            done=torch.tensor(step == 4),
            next_state={"data": torch.randn(10)},
        )
        for step in range(5)
    ]

    round_tripped = bytes_to_transitions(transitions_to_bytes(originals))

    assert len(round_tripped) == len(originals)
    for expected, got in zip(originals, round_tripped, strict=False):
        assert_transitions_equal(expected, got)
|
||||
|
||||
|
||||
@require_package("grpc")
def test_receive_bytes_in_chunks_unknown_state():
    """receive_bytes_in_chunks must reject a message with an invalid transfer state."""
    from lerobot.common.transport.utils import receive_bytes_in_chunks

    # Stand-in for the gRPC message object, which exposes `transfer_state` and `data`.
    class FakeChunk:
        def __init__(self, transfer_state, data):
            self.transfer_state = transfer_state
            self.data = data

    # 10 is not a valid TransferState enum value
    bad_iterator = [FakeChunk(transfer_state=10, data=b"bad_data")]
    output_queue = Queue()
    shutdown_event = Event()

    with pytest.raises(ValueError, match="Received unknown transfer state"):
        receive_bytes_in_chunks(bad_iterator, output_queue, shutdown_event)
|
||||
112
tests/utils/test_process.py
Normal file
112
tests/utils/test_process.py
Normal file
@@ -0,0 +1,112 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.utils.process import ProcessSignalHandler
|
||||
|
||||
|
||||
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
@pytest.fixture(autouse=True)
def reset_globals_and_handlers():
    """Snapshot signal handlers before each test and restore them afterwards.

    SIGHUP/SIGQUIT do not exist on every platform (e.g. Windows), so signals
    are resolved by name with getattr: referencing `signal.SIGHUP` directly
    would raise AttributeError before the availability check could run.
    """
    original_handlers = {}
    for sig_name in ("SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT"):
        sig = getattr(signal, sig_name, None)
        if sig is not None:
            original_handlers[sig] = signal.getsignal(sig)

    yield

    # Restore original signal handlers
    for sig, handler in original_handlers.items():
        signal.signal(sig, handler)
|
||||
|
||||
|
||||
def test_setup_process_handlers_event_with_threads():
    """Test that setup_process_handlers returns the correct event type."""
    shutdown_event = ProcessSignalHandler(use_threads=True).shutdown_event

    assert isinstance(shutdown_event, threading.Event), "Should be a threading.Event"
    assert not shutdown_event.is_set(), "Event should initially be unset"
|
||||
|
||||
|
||||
def test_setup_process_handlers_event_with_processes():
    """Test that setup_process_handlers returns the correct event type."""
    shutdown_event = ProcessSignalHandler(use_threads=False).shutdown_event

    # multiprocessing.Event() is a factory; compare against the type it produces.
    assert isinstance(shutdown_event, type(multiprocessing.Event())), "Should be a multiprocessing.Event"
    assert not shutdown_event.is_set(), "Event should initially be unset"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_threads", [True, False])
@pytest.mark.parametrize(
    "sig",
    [
        signal.SIGINT,
        signal.SIGTERM,
        # SIGHUP and SIGQUIT are not reliably available on all platforms
        # (e.g. Windows). Resolve them with getattr: referencing
        # signal.SIGHUP directly would raise AttributeError at collection
        # time on platforms where the attribute is missing, before the
        # skipif mark could take effect.
        pytest.param(
            getattr(signal, "SIGHUP", None),
            marks=pytest.mark.skipif(not hasattr(signal, "SIGHUP"), reason="SIGHUP not available"),
        ),
        pytest.param(
            getattr(signal, "SIGQUIT", None),
            marks=pytest.mark.skipif(not hasattr(signal, "SIGQUIT"), reason="SIGQUIT not available"),
        ),
    ],
)
def test_signal_handler_sets_event(use_threads, sig):
    """Test that the signal handler sets the event on receiving a signal."""
    handler = ProcessSignalHandler(use_threads=use_threads)
    shutdown_event = handler.shutdown_event

    assert handler.counter == 0

    os.kill(os.getpid(), sig)

    # In some environments, the signal might take a moment to be handled.
    shutdown_event.wait(timeout=1.0)

    assert shutdown_event.is_set(), f"Event should be set after receiving signal {sig}"

    # Ensure the internal counter was incremented
    assert handler.counter == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_threads", [True, False])
@patch("sys.exit")
def test_force_shutdown_on_second_signal(mock_sys_exit, use_threads):
    """Test that a second signal triggers a force shutdown (sys.exit(1))."""
    # Hoisted to the top of the function: the original buried this import
    # between statements, which hurts readability.
    import time

    handler = ProcessSignalHandler(use_threads=use_threads)

    os.kill(os.getpid(), signal.SIGINT)
    # Give a moment for the first signal to be processed
    time.sleep(0.1)
    os.kill(os.getpid(), signal.SIGINT)

    time.sleep(0.1)

    assert handler.counter == 2
    mock_sys_exit.assert_called_once_with(1)
|
||||
150
tests/utils/test_queue.py
Normal file
150
tests/utils/test_queue.py
Normal file
@@ -0,0 +1,150 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import threading
|
||||
import time
|
||||
from queue import Queue
|
||||
|
||||
from lerobot.common.utils.queue import get_last_item_from_queue
|
||||
|
||||
|
||||
def test_get_last_item_single_item():
    """A queue holding exactly one item yields that item and is left empty."""
    queue = Queue()
    queue.put("single_item")

    assert get_last_item_from_queue(queue) == "single_item"
    assert queue.empty()
|
||||
|
||||
|
||||
def test_get_last_item_multiple_items():
    """Only the most recently queued item is returned; the rest are drained."""
    queue = Queue()
    for entry in ("first", "second", "third", "fourth", "last"):
        queue.put(entry)

    assert get_last_item_from_queue(queue) == "last"
    assert queue.empty()
|
||||
|
||||
|
||||
def test_get_last_item_different_types():
    """Heterogeneous payloads (scalars and containers) are handled transparently."""
    queue = Queue()
    for entry in (1, 2.5, "string", {"key": "value"}, [1, 2, 3], ("tuple", "data")):
        queue.put(entry)

    assert get_last_item_from_queue(queue) == ("tuple", "data")
    assert queue.empty()
|
||||
|
||||
|
||||
def test_get_last_item_maxsize_queue():
    """Test with a queue that has a maximum size."""
    queue = Queue(maxsize=5)

    # Fill the queue to capacity. Queue.put is synchronous within a single
    # thread, so the items are visible immediately — the settling sleep the
    # original version used here was unnecessary.
    for i in range(5):
        queue.put(i)

    result = get_last_item_from_queue(queue)

    assert result == 4
    assert queue.empty()
|
||||
|
||||
|
||||
def test_get_last_item_with_none_values():
    """Test with None values in the queue: they are drained like any value."""
    queue = Queue()
    # Queue.put is synchronous within a single thread, so no settling sleep
    # is needed after filling (the original slept 0.1 s for no reason).
    for item in (1, None, 2, None, 3):
        queue.put(item)

    result = get_last_item_from_queue(queue)

    assert result == 3
    assert queue.empty()
|
||||
|
||||
|
||||
def test_get_last_item_blocking_timeout():
    """Blocking on an empty queue must give up after the timeout and return None."""
    assert get_last_item_from_queue(Queue(), block=True, timeout=0.1) is None
|
||||
|
||||
|
||||
def test_get_last_item_non_blocking_empty():
    """A non-blocking read of an empty queue returns None immediately."""
    assert get_last_item_from_queue(Queue(), block=False) is None
|
||||
|
||||
|
||||
def test_get_last_item_non_blocking_success():
    """Test get_last_item_from_queue with block=False on a non-empty queue."""
    queue = Queue()
    # Queue.put is synchronous within a single thread, so the items are all
    # visible immediately; the original's 0.1 s settling sleep was unnecessary.
    for item in ("first", "second", "last"):
        queue.put(item)

    result = get_last_item_from_queue(queue, block=False)
    assert result == "last"
    assert queue.empty()
|
||||
|
||||
|
||||
def test_get_last_item_blocking_waits_for_item():
    """Test that get_last_item_from_queue waits for an item if block=True."""
    queue = Queue()
    result = []

    def producer():
        queue.put("item1")
        queue.put("item2")

    def consumer():
        # This will block until the producer puts the first item
        item = get_last_item_from_queue(queue, block=True, timeout=0.2)
        result.append(item)

    producer_thread = threading.Thread(target=producer)
    consumer_thread = threading.Thread(target=consumer)

    producer_thread.start()
    consumer_thread.start()

    producer_thread.join()
    consumer_thread.join()

    # NOTE(review): if the consumer drains the queue between the two puts it
    # could observe "item1" only — the 0.2 s window presumably makes this
    # vanishingly rare, but the test is timing-dependent; revisit if flaky.
    assert result == ["item2"]
    assert queue.empty()
|
||||
682
tests/utils/test_replay_buffer.py
Normal file
682
tests/utils/test_replay_buffer.py
Normal file
@@ -0,0 +1,682 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.utils.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def state_dims() -> list[str]:
    """Observation keys used by every replay-buffer fixture in this module."""
    return ["observation.image", "observation.state"]
|
||||
|
||||
|
||||
@pytest.fixture
def replay_buffer() -> ReplayBuffer:
    """Fresh, empty buffer with the default test capacity for each test."""
    return create_empty_replay_buffer()
|
||||
|
||||
|
||||
def clone_state(state: dict) -> dict:
    """Deep-copy a state dict by cloning each tensor value."""
    return {key: tensor.clone() for key, tensor in state.items()}
|
||||
|
||||
|
||||
def create_empty_replay_buffer(
    optimize_memory: bool = False,
    use_drq: bool = False,
    image_augmentation_function: Callable | None = None,
) -> ReplayBuffer:
    """Build the small CPU-backed ReplayBuffer shared by the tests in this module."""
    return ReplayBuffer(
        10,  # buffer capacity
        "cpu",  # device
        state_dims(),
        optimize_memory=optimize_memory,
        use_drq=use_drq,
        image_augmentation_function=image_augmentation_function,
    )
|
||||
|
||||
|
||||
def create_random_image() -> torch.Tensor:
    """Uniform-random CHW image tensor of shape (3, 84, 84)."""
    return torch.rand(3, 84, 84)
|
||||
|
||||
|
||||
def create_dummy_transition() -> dict:
    """Single synthetic transition dict with random observation/action tensors."""
    # Key (and RNG-call) order intentionally mirrors the rest of the module.
    return {
        "observation.image": create_random_image(),
        "action": torch.randn(4),
        "reward": torch.tensor(1.0),
        "observation.state": torch.randn(10),
        "done": torch.tensor(False),
        "truncated": torch.tensor(False),
        "complementary_info": {},
    }
|
||||
|
||||
|
||||
def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayBuffer]:
    """Fill a fresh buffer with four transitions and export it as a dataset.

    The first two transitions are mid-episode; the last two terminate (and
    truncate) their episodes, yielding two episodes of two frames each.
    """
    replay_buffer = create_empty_replay_buffer()
    # Each transition reuses its state as next_state, matching the original
    # hand-unrolled version (state/action RNG calls stay interleaved).
    for done in (False, False, True, True):
        dummy_state = create_dummy_state()
        dummy_action = create_dummy_action()
        replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, done, done)

    root = tmp_path / "test"
    return (replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root), replay_buffer)
|
||||
|
||||
|
||||
def create_dummy_state() -> dict:
    """Random observation dict whose keys match state_dims()."""
    return {
        "observation.image": create_random_image(),
        "observation.state": torch.randn(10),
    }
|
||||
|
||||
|
||||
def get_tensor_memory_consumption(tensor):
    """Size of the tensor's payload in bytes (element count x bytes per element)."""
    return tensor.element_size() * tensor.nelement()
|
||||
|
||||
|
||||
def get_tensors_memory_consumption(obj, visited_addresses):
    """Recursively sum the byte size of every tensor reachable from ``obj``.

    ``visited_addresses`` tracks ``id()`` values already seen so shared
    references are counted once and reference cycles terminate.
    """
    address = id(obj)
    if address in visited_addresses:
        # Already counted: shared object or cycle.
        return 0

    visited_addresses.add(address)

    if isinstance(obj, torch.Tensor):
        return get_tensor_memory_consumption(obj)

    total_size = 0
    if isinstance(obj, (list, tuple)):
        for item in obj:
            total_size += get_tensors_memory_consumption(item, visited_addresses)
    elif isinstance(obj, dict):
        for value in obj.values():
            total_size += get_tensors_memory_consumption(value, visited_addresses)
    elif hasattr(obj, "__dict__"):
        # It's an object: recurse into its attribute values. The original
        # iterated `.items()` and discarded the keys (PERF102); `.values()`
        # expresses the intent directly.
        for attr in vars(obj).values():
            total_size += get_tensors_memory_consumption(attr, visited_addresses)

    return total_size
|
||||
|
||||
|
||||
def get_object_memory(obj):
    """Approximate memory of ``obj``: shallow size plus all reachable tensor payloads."""
    # Visited-address set guards against infinite loops and double-counting
    # when two properties point to the same object.
    seen_addresses: set[int] = set()
    return sys.getsizeof(obj) + get_tensors_memory_consumption(obj, seen_addresses)
|
||||
|
||||
|
||||
def create_dummy_action() -> torch.Tensor:
    """Random 4-dimensional action vector."""
    return torch.randn(4)
|
||||
|
||||
|
||||
def dict_properties() -> list:
    """BatchTransition keys whose values are observation dicts (not flat tensors)."""
    return ["state", "next_state"]
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_state() -> dict:
    """Random observation dict for a single transition."""
    return create_dummy_state()
|
||||
|
||||
|
||||
@pytest.fixture
def next_dummy_state() -> dict:
    """Independent random observation dict to use as a next_state."""
    return create_dummy_state()
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_action() -> torch.Tensor:
    """Random 4-dimensional action tensor."""
    return torch.randn(4)
|
||||
|
||||
|
||||
def test_empty_buffer_sample_raises_error(replay_buffer):
    """A fresh buffer reports length 0, full capacity, and refuses to sample."""
    assert len(replay_buffer) == 0, "Replay buffer should be empty."
    assert replay_buffer.capacity == 10, "Replay buffer capacity should be 10."
    with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"):
        replay_buffer.sample(1)
|
||||
|
||||
|
||||
def test_zero_capacity_buffer_raises_error():
    """Constructing a buffer with capacity 0 must be rejected with ValueError."""
    with pytest.raises(ValueError, match="Capacity must be greater than 0."):
        ReplayBuffer(0, "cpu", ["observation", "next_observation"])
|
||||
|
||||
|
||||
def test_add_transition(replay_buffer, dummy_state, dummy_action):
    """Adding one transition stores every field in slot 0 of the buffer."""
    replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
    assert len(replay_buffer) == 1, "Replay buffer should have one transition after adding."
    assert torch.equal(replay_buffer.actions[0], dummy_action), (
        "Action should be equal to the first transition."
    )
    assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition."
    assert not replay_buffer.dones[0], "Done should be False for the first transition."
    assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition."

    # The same dict was passed as both state and next_state, so both storages
    # must hold its tensors for every observation key.
    for dim in state_dims():
        assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), (
            "Observation should be equal to the first transition."
        )
        assert torch.equal(replay_buffer.next_states[dim][0], dummy_state[dim]), (
            "Next observation should be equal to the first transition."
        )
|
||||
|
||||
|
||||
def test_add_over_capacity():
    """Adding a third transition to a capacity-2 buffer wraps around and
    overwrites slot 0 with the newest transition's data."""
    replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"])
    dummy_state_1 = create_dummy_state()
    dummy_action_1 = create_dummy_action()

    dummy_state_2 = create_dummy_state()
    dummy_action_2 = create_dummy_action()

    dummy_state_3 = create_dummy_state()
    dummy_action_3 = create_dummy_action()

    replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
    replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
    replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)

    assert len(replay_buffer) == 2, "Replay buffer should have 2 transitions after adding 3."

    # Slot 0 now holds the data of the *last* transition (the original
    # assertion messages incorrectly said "first transition").
    for dim in state_dims():
        assert torch.equal(replay_buffer.states[dim][0], dummy_state_3[dim]), (
            "Observation should be equal to the last transition."
        )
        assert torch.equal(replay_buffer.next_states[dim][0], dummy_state_3[dim]), (
            "Next observation should be equal to the last transition."
        )

    assert torch.equal(replay_buffer.actions[0], dummy_action_3), (
        "Action should be equal to the last transition."
    )
    assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition."
    assert replay_buffer.dones[0], "Done should be True for the last transition."
    assert replay_buffer.truncateds[0], "Truncated should be True for the last transition."
|
||||
|
||||
|
||||
def test_sample_from_empty_buffer(replay_buffer):
    """Sampling from an empty buffer must raise RuntimeError."""
    # NOTE(review): overlaps with test_empty_buffer_sample_raises_error above.
    with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"):
        replay_buffer.sample(1)
|
||||
|
||||
|
||||
def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state, dummy_action):
    """Sampling a batch of 1 from a 1-element buffer returns exactly that transition."""
    replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False)
    got_batch_transition = replay_buffer.sample(1)

    # Clone the inputs so the expected values cannot alias buffer storage.
    expected_batch_transition = BatchTransition(
        state=clone_state(dummy_state),
        action=dummy_action.clone(),
        reward=1.0,
        next_state=clone_state(next_dummy_state),
        done=False,
        truncated=False,
    )

    # Dict-valued fields (state / next_state): compare tensor by tensor.
    for buffer_property in dict_properties():
        for k, v in expected_batch_transition[buffer_property].items():
            got_state = got_batch_transition[buffer_property][k]

            assert got_state.shape[0] == 1, f"{k} should have 1 transition."
            assert got_state.device.type == "cpu", f"{k} should be on cpu."

            assert torch.equal(got_state[0], v), f"{k} should be equal to the expected batch transition."

    # Flat fields: sampling batches them, so compare element 0; plain Python
    # scalars (reward/done/truncated) are lifted to tensors first.
    for key, _value in expected_batch_transition.items():
        if key in dict_properties():
            continue

        got_value = got_batch_transition[key]

        v_tensor = expected_batch_transition[key]
        if not isinstance(v_tensor, torch.Tensor):
            v_tensor = torch.tensor(v_tensor)

        assert got_value.shape[0] == 1, f"{key} should have 1 transition."
        assert got_value.device.type == "cpu", f"{key} should be on cpu."
        assert torch.equal(got_value[0], v_tensor), f"{key} should be equal to the expected batch transition."
|
||||
|
||||
|
||||
def test_sample_with_batch_bigger_than_buffer_size(
    replay_buffer, dummy_state, next_dummy_state, dummy_action
):
    """Requesting more samples than stored returns only the available transitions."""
    replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False)
    got_batch_transition = replay_buffer.sample(10)

    expected_batch_transition = BatchTransition(
        state=dummy_state,
        action=dummy_action,
        reward=1.0,
        next_state=next_dummy_state,
        done=False,
        truncated=False,
    )

    # Only the batch dimension is checked here; exact values are covered by
    # test_sample_with_1_transition.
    for buffer_property in dict_properties():
        for k in expected_batch_transition[buffer_property]:
            got_state = got_batch_transition[buffer_property][k]

            assert got_state.shape[0] == 1, f"{k} should have 1 transition."

    for key in expected_batch_transition:
        if key in dict_properties():
            continue

        got_value = got_batch_transition[key]
        assert got_value.shape[0] == 1, f"{key} should have 1 transition."
|
||||
|
||||
|
||||
def test_sample_batch(replay_buffer):
    """A batch of 3 sampled from 4 stored transitions has the right shapes and
    every sampled element matches *some* stored transition."""
    dummy_state_1 = create_dummy_state()
    dummy_action_1 = create_dummy_action()

    dummy_state_2 = create_dummy_state()
    dummy_action_2 = create_dummy_action()

    dummy_state_3 = create_dummy_state()
    dummy_action_3 = create_dummy_action()

    dummy_state_4 = create_dummy_state()
    dummy_action_4 = create_dummy_action()

    replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
    replay_buffer.add(dummy_state_2, dummy_action_2, 2.0, dummy_state_2, False, False)
    replay_buffer.add(dummy_state_3, dummy_action_3, 3.0, dummy_state_3, True, True)
    replay_buffer.add(dummy_state_4, dummy_action_4, 4.0, dummy_state_4, True, True)

    dummy_states = [dummy_state_1, dummy_state_2, dummy_state_3, dummy_state_4]
    dummy_actions = [dummy_action_1, dummy_action_2, dummy_action_3, dummy_action_4]

    got_batch_transition = replay_buffer.sample(3)

    # Sampling is random, so membership (not index equality) is asserted.
    for buffer_property in dict_properties():
        for k in got_batch_transition[buffer_property]:
            got_state = got_batch_transition[buffer_property][k]

            assert got_state.shape[0] == 3, f"{k} should have 3 transition."

            for got_state_item in got_state:
                assert any(torch.equal(got_state_item, dummy_state[k]) for dummy_state in dummy_states), (
                    f"{k} should be equal to one of the dummy states."
                )

    for got_action_item in got_batch_transition["action"]:
        assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), (
            "Actions should be equal to the dummy actions."
        )

    # Remaining flat fields only need the right batch dimension.
    for k in got_batch_transition:
        if k in dict_properties() or k == "complementary_info":
            continue

        got_value = got_batch_transition[k]
        assert got_value.shape[0] == 3, f"{k} should have 3 transition."
|
||||
|
||||
|
||||
def test_to_lerobot_dataset_with_empty_buffer(replay_buffer):
    """Converting an empty buffer to a dataset must be rejected with ValueError."""
    with pytest.raises(ValueError, match="The replay buffer is empty. Cannot convert to a dataset."):
        replay_buffer.to_lerobot_dataset("dummy_repo")
|
||||
|
||||
|
||||
def test_to_lerobot_dataset(tmp_path):
    """Exporting the buffer must preserve sizes, features, and per-frame data.

    The leftover debug ``print(torch.equal(...))`` loop from the original
    version was removed: it asserted nothing and only polluted test output.
    """
    ds, buffer = create_dataset_from_replay_buffer(tmp_path)

    assert len(ds) == len(buffer), "Dataset should have the same size as the Replay Buffer"
    assert ds.fps == 1, "FPS should be 1"
    assert ds.repo_id == "dummy/repo", "The dataset should have `dummy/repo` repo id"

    for dim in state_dims():
        assert dim in ds.features
        assert ds.features[dim]["shape"] == buffer.states[dim][0].shape

    assert ds.num_episodes == 2
    assert ds.num_frames == 4

    for i in range(len(ds)):
        for feature, value in ds[i].items():
            if feature == "action":
                assert torch.equal(value, buffer.actions[i])
            elif feature == "next.reward":
                assert torch.equal(value, buffer.rewards[i])
            elif feature == "next.done":
                assert torch.equal(value, buffer.dones[i])
            elif feature == "observation.image":
                # Tensor -> numpy is not precise, so we have some diff there
                # TODO: Check and fix it
                torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003)
            elif feature == "observation.state":
                assert torch.equal(value, buffer.states["observation.state"][i])
|
||||
|
||||
|
||||
def test_from_lerobot_dataset(tmp_path):
    """Round-trip buffer -> LeRobotDataset -> buffer preserves the stored data."""
    dummy_state_1 = create_dummy_state()
    dummy_action_1 = create_dummy_action()

    dummy_state_2 = create_dummy_state()
    dummy_action_2 = create_dummy_action()

    dummy_state_3 = create_dummy_state()
    dummy_action_3 = create_dummy_action()

    dummy_state_4 = create_dummy_state()
    dummy_action_4 = create_dummy_action()

    replay_buffer = create_empty_replay_buffer()
    replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
    replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
    replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
    replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)

    root = tmp_path / "test"
    ds = replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root)

    reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
        ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False
    )

    # Check only the part of the buffer that's actually filled with data
    assert torch.equal(
        reconverted_buffer.actions[: len(replay_buffer)],
        replay_buffer.actions[: len(replay_buffer)],
    ), "Actions from converted buffer should be equal to the original replay buffer."
    assert torch.equal(
        reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
    ), "Rewards from converted buffer should be equal to the original replay buffer."
    assert torch.equal(
        reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
    ), "Dones from converted buffer should be equal to the original replay buffer."

    # Lerobot DS haven't supported truncateds yet
    expected_truncateds = torch.zeros(len(replay_buffer)).bool()
    assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
        "Truncateds from converted buffer should be equal False"
    )

    assert torch.equal(
        replay_buffer.states["observation.state"][: len(replay_buffer)],
        reconverted_buffer.states["observation.state"][: len(replay_buffer)],
    ), "State should be the same after converting to dataset and return back"

    # Images go through a lossy encode/decode during dataset export, so the
    # comparison uses wide tolerances rather than exact equality.
    for i in range(4):
        torch.testing.assert_close(
            replay_buffer.states["observation.image"][i],
            reconverted_buffer.states["observation.image"][i],
            rtol=0.4,
            atol=0.004,
        )

    # The 2, 3 frames have done flag, so their values will be equal to the current state
    for i in range(2):
        # In the current implementation we take the next state from the `states` and ignore `next_states`
        next_index = (i + 1) % 4

        torch.testing.assert_close(
            replay_buffer.states["observation.image"][next_index],
            reconverted_buffer.next_states["observation.image"][i],
            rtol=0.4,
            atol=0.004,
        )

    for i in range(2, 4):
        assert torch.equal(
            replay_buffer.states["observation.state"][i],
            reconverted_buffer.next_states["observation.state"][i],
        )
|
||||
|
||||
|
||||
def test_buffer_sample_alignment():
    """Sampled fields must stay aligned: action/reward/next_state in a sampled
    element must all come from the same underlying transition as its state."""
    # Initialize buffer
    buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu")

    # Fill buffer with patterned data: action = 2*state and reward = 3*state,
    # in episodes of length 10 (every 10th transition is terminal).
    for i in range(100):
        signature = float(i) / 100.0
        state = {"state_value": torch.tensor([[signature]]).float()}
        action = torch.tensor([[2.0 * signature]]).float()
        reward = 3.0 * signature

        is_end = (i + 1) % 10 == 0
        if is_end:
            # Terminal step: next_state repeats the current state.
            next_state = {"state_value": torch.tensor([[signature]]).float()}
            done = True
        else:
            next_signature = float(i + 1) / 100.0
            next_state = {"state_value": torch.tensor([[next_signature]]).float()}
            done = False

        buffer.add(state, action, reward, next_state, done, False)

    # Sample and verify
    batch = buffer.sample(50)

    for i in range(50):
        state_sig = batch["state"]["state_value"][i].item()
        action_val = batch["action"][i].item()
        reward_val = batch["reward"][i].item()
        next_state_sig = batch["next_state"]["state_value"][i].item()
        is_done = batch["done"][i].item() > 0.5

        # Verify relationships
        assert abs(action_val - 2.0 * state_sig) < 1e-4, (
            f"Action {action_val} should be 2x state signature {state_sig}"
        )

        assert abs(reward_val - 3.0 * state_sig) < 1e-4, (
            f"Reward {reward_val} should be 3x state signature {state_sig}"
        )

        if is_done:
            assert abs(next_state_sig - state_sig) < 1e-4, (
                f"For done states, next_state {next_state_sig} should equal state {state_sig}"
            )
        else:
            # Either it's the next sequential state (+0.01) or same state (for episode boundaries)
            valid_next = (
                abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
            )
            assert valid_next, (
                f"Next state {next_state_sig} should be either state+0.01 or same as state {state_sig}"
            )
|
||||
|
||||
|
||||
def test_memory_optimization():
    """A memory-optimized buffer must occupy less space than a plain one.

    Both buffers receive the identical 4-step episode; the optimized variant is
    handed ``None`` as the terminal next_state instead of a duplicate tensor.
    """
    # Build the shared episode data, preserving the state/action creation order.
    states, actions = [], []
    for _ in range(4):
        states.append(create_dummy_state())
        actions.append(create_dummy_action())

    # Baseline buffer: terminal transition stores the last state twice.
    baseline = create_empty_replay_buffer()
    for idx in range(3):
        baseline.add(states[idx], actions[idx], 1.0, states[idx + 1], False, False)
    baseline.add(states[3], actions[3], 1.0, states[3], True, True)

    # Optimized buffer: terminal transition passes None so storage is deduplicated.
    optimized = create_empty_replay_buffer(True)
    for idx in range(3):
        optimized.add(states[idx], actions[idx], 1.0, states[idx + 1], False, False)
    optimized.add(states[3], actions[3], 1.0, None, True, True)

    assert get_object_memory(optimized) < get_object_memory(baseline), (
        "Optimized replay buffer should be smaller than the original replay buffer"
    )
||||
def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_function(dummy_state, dummy_action):
    """DrQ sampling must run the user-supplied augmentation on both the state
    and next_state images."""

    def fill_with_tens(x):
        # Maps every pixel to 10 so an augmented sample is trivially detectable.
        return torch.ones_like(x) * 10

    buffer = create_empty_replay_buffer(
        use_drq=True, image_augmentation_function=fill_with_tens
    )
    buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)

    batch = buffer.sample(1)
    for side in ("state", "next_state"):
        assert torch.all(batch[side]["observation.image"] == 10), (
            "Image augmentations should be applied"
        )
||||
def test_check_image_augmentations_with_drq_and_default_image_augmentation_function(
    dummy_state, dummy_action
):
    """Sampling with the default DrQ augmentation must not fail and must keep
    the image shapes intact."""
    buffer = create_empty_replay_buffer(use_drq=True)
    buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)

    batch = buffer.sample(1)
    expected_shape = (1, 3, 84, 84)
    assert batch["state"]["observation.image"].shape == expected_shape
    assert batch["next_state"]["observation.image"].shape == expected_shape
||||
def test_random_crop_vectorized_basic():
    """Cropping constant-valued images preserves their values and produces
    the requested output size."""
    batch_size, channels = 2, 3
    # Each batch element is filled with a distinct constant (1.0, 2.0, ...),
    # so any crop window must contain only that constant.
    images = torch.stack(
        [torch.full((channels, 10, 8), float(b + 1)) for b in range(batch_size)]
    )

    crop_size = (6, 4)  # strictly smaller than the 10x8 input
    cropped = random_crop_vectorized(images, crop_size)

    # Shape check: batch and channel dims untouched, spatial dims cropped.
    assert cropped.shape == (batch_size, channels, *crop_size)

    # Value check: the crop offset cannot change a constant image's content.
    assert torch.all(cropped[0] == 1)
    assert torch.all(cropped[1] == 2)
||||
def test_random_crop_vectorized_invalid_size():
    """A crop size exceeding the image in either dimension must raise ValueError."""
    images = torch.zeros((2, 3, 10, 8))

    # First size: height too large. Second: height fits, width too large.
    for bad_size in ((12, 8), (10, 10)):
        with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"):
            random_crop_vectorized(images, bad_size)
||||
def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer:
    """Create a small buffer with deterministic 3×128×128 images and 11-D state.

    Transition ``i`` stores an image filled with the constant ``i`` and a state
    vector ``[i, i+1, ..., i+10]``; next_state aliases state and no episode ends.
    """
    buffer = ReplayBuffer(
        capacity=capacity,
        device="cpu",
        state_keys=["observation.image", "observation.state"],
        storage_device="cpu",
    )

    for step in range(capacity):
        observation = {
            "observation.image": torch.full((3, 128, 128), float(step)),
            "observation.state": torch.arange(11).float() + step,
        }
        buffer.add(
            state=observation,
            action=torch.tensor([0.0]),
            reward=0.0,
            next_state=observation,
            done=False,
            truncated=False,
        )
    return buffer
||||
def test_async_iterator_shapes_basic():
    """One async-prefetched batch must carry correctly shaped state and
    next_state tensors."""
    batch_size = 2
    buffer = _populate_buffer_for_async_test()
    # Keep the iterator bound so the prefetch machinery lives for the whole test.
    iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=1)
    batch = next(iterator)

    for side in ("state", "next_state"):
        assert batch[side]["observation.image"].shape == (batch_size, 3, 128, 128)
        assert batch[side]["observation.state"].shape == (batch_size, 11)
||||
def test_async_iterator_multiple_iterations():
    """Five consecutive async batches keep correct shapes, and disposing the
    iterator afterwards must not hang."""
    batch_size = 2
    buffer = _populate_buffer_for_async_test()
    iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=2)

    for _ in range(5):
        batch = next(iterator)
        for side in ("state", "next_state"):
            assert batch[side]["observation.image"].shape == (batch_size, 3, 128, 128)
            assert batch[side]["observation.state"].shape == (batch_size, 11)

    # Ensure iterator can be disposed without blocking
    del iterator
||||
Reference in New Issue
Block a user