Add OpenPi, Pi0 and Pi0.5 (#1910)
* initial commit
* change device in test
* do detailed import
* adhere to Python 3.11 syntax
* fix autodocstring
* additionally
* do the same in other files
* add "model." prefix to all keys in state dict
* use dummy stats
* add pi05
* also shorten action_steps
* fix test
* all tests pass! and fix tokenizer max length between 0.5 and 0
* remove test
* fix transformers dependency
* fix test
* split pi0 and pi05 policy into separate files
* fix test
* fix push-to-hub test
* add some comments, license and readme
* remove warning in config
* add pi05 to factory
* remove check
* rename action_horizon to chunk_size
* clean up padding of state and action (more in line with LeRobot pi0)
* add openpi image transforms for training and make _preprocess_images more flexible, similar to LeRobot pi0
* fix key matching from the pytorch state dict (keys now similar to the openpi implementation)
* also for pi05
* update to Python 3.11
* revert to openpi transformers, replace Python 3.11
* fix(modeling pi0): nit warning message
* use safe auto_docstring
* fix: remove unused param
* fix from_pretrained
* add preprocess tests
* also compile forward method
* do not add "model" prefix to normalization
* use the same names for action and state dims as LeRobot pi0 and remove fixed image keys
* load from pretrained_path
* temp: hardcode base model
* fix self.pretrained_path = None overwrite
* rename to loss
* remove additional image augmentations; the LeRobot dataset already does this
* add docs
* put tests in test folder
* add test to instantiate all base models
* go back to Python 3.10
* update docs
* adapt pi05 docs
* change docs: finetune base model options
* minor docs fixes and dependencies
* remove todo
* cast float64 to float32 for MPS
* skip if transformers is unavailable
* fix tests
* add new models to model card
* add back init
* fix circular import
* feat: only run pi tests on GPU
* remove require_nightly_gpu
* replace decorator test_pi0_openpi
* rename action_dim, state_dim to max_action_dim, max_state_dim
* fix doc and constants
* clean up tests
* fix from_pretrained
* fix tests
* add comments to pi0/pi05 tests; add image features to pi0/pi05 hub tests
* fix: state is included in the language prefix, not in the flow head
* move tests to a specific folder
* end paligemma task with a newline
* remove add_special_tokens, not needed
* PR feedback
* remove previous pi0 and rename pi0_openpi and pi05_openpi
* Add Quantile stats to LeRobotDataset (#1985)
  - Add RunningQuantileStats class for efficient histogram-based quantile computation
  - Integrate quantile parameters (compute_quantiles, quantiles) into LeRobotDataset
  - Support quantile computation during episode collection and aggregation
  - Add comprehensive function-based test suite (24 tests) for quantile functionality
  - Maintain full backward compatibility with existing stats computation
  - Enable configurable quantiles (default: [0.01, 0.99]) for robust normalization
* style fixes; compute quantiles by default for new datasets
* fix tests
* add DEFAULT_QUANTILES=[0.01, 0.10, 0.50, 0.90, 0.99], computed for every feature instead of being chosen by the user; fortify tests
* add helper functions to reshape stats; add missing test for quantiles
* add a QUANTILE normalization mode (normalize with the 1st and 99th percentiles) and a QUANTILE10 mode (normalize with the 10th and 90th percentiles)
* style fixes
* add missing license
* simplify compute_stats
* add the `augment_dataset_quantile_stats.py` script so quantile stats can be added to existing v3 datasets that don't have quantiles; change the quantile computation to interpolate values within a bin instead of using the bin edge
* rename pi0/pi05 files
* remove the openpi patch and use a custom transformers branch for now
* renaming
* fix
* Revert "fix" (this reverts commit 1ea65730ac2cbca6e5869df734fbd4392561b3c6)
* fix naming
* feat(pi0/pi0.5): add pipeline (#2009)
  - feat(processor): convert the openpi model with a processor
  - TODO: make tests work
  - fix(modeling_pi0openpi): update attention mask value and time scaling; improve task handling in tests
    - changed the attention mask value from `self.config.attention_mask_value` to a fixed value of `-2.3819763e38`
    - updated time scaling in the `sample_noise` method to use a constant factor of `0.999` and an offset of `0.001`
    - enhanced task handling in tests to ensure proper formatting and batch size consistency
    - cleaned up commented-out test code for clarity
  - refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy
    - updated imports and references throughout the codebase to reflect the new naming convention
    - introduced a new processor file for PI0 to handle pre-processing and post-processing steps
    - adjusted tests to use the renamed classes
  - refactor(pi05): rename PI05OpenPIConfig and PI05OpenPIPolicy to PI05Config and PI05Policy, and update the configuration
    - added a new `tokenizer_max_length` attribute and changed the state normalization mode from `MEAN_STD` to `QUANTILES`
    - simplified model initialization in `PI05Policy` by removing the unused `dataset_stats` parameter
    - added a new `Pi05PrepareStateTokenizerProcessorStep` processor class with `@dataclass` for improved readability
    - introduced a test script comparing the PI0OpenPI policy integration against the original implementation (local testing only)
  - refactor(pi05): update imports, rename the configuration classes, and introduce a new processor file for PI05 implementing pre- and post-processing steps
  - update(pi05): increase tokenizer_max_length from 48 to 200 to handle longer sequences
  - add default for state (max_state_dim)
  - correct naming
  - fix import
  - clean up code
  - remove unused test
  - use quantiles for action
  - move to device
  - remove discrete state assert
  - fix pi05 test
  - move pi05 to device
  - use base models in comparison tests
  - small renames for tests
  - change number of tokens in pi05 test
  - fix openpi tokenization in test
  - fix hub test
  - fix test
  - assert lerobot vs openpi tests
  ---------
  Co-authored-by: Pepijn <pepijn@huggingface.co>
* add headers
* add back previously removed imports
* update the if statement that loads the processor with dataset stats
* remove to avoid a circular import
* inject dataset stats for pretrained models
* check normalization before applying it
* add a link to the quantile augment script
* fix(policies): transformers import for CI in PI0 & PI05 (#2039)
  - fix(policies): transformers import for CI in PI0
  - fix(policies): transformers import for CI in PI05
* test(processor): fix expected raise when normalization types are missing (#2040)
* switch the normalization order in the pi05 pipeline
* Fix/quantiles script (#2064)
  - refactor the augment-stats-with-quantiles script: add parallelization for faster processing; shift the quantile normalization to [-1, 1]
* fix replay buffer tests
* fix comment
* overwrite the pipeline normalization features with the policy features
* remove double normalization overwrite
* clean up from_pretrained
* remove typo
* also set norm_map
* fix(augment_quantiles): images were incorrectly divided by 255
* clamp quantiles
* link to the lerobot base models
* rename tests
* incorporate PR feedback
* update the docstring of RunningQuantileStats
* update doc links
* Revert "clamp quantiles" (this reverts commit 172207471c8f2cb62958e9a9e6a0535ba3ff67d4)
* fix self.paligemma
* fix tests related to quantiles that were scaled to [0, 1]; the new range is [-1, 1]
* fix the libero doc and use a different transformers branch
* use the fix branch instead of feat
* update libero results
* add new line
* fix formatting
* pre-commit
* update libero results
* update libero doc
* update title
* final changes
* add quantiles to test
* run pre-commit

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
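For context, the `RunningQuantileStats` class introduced in #1985 estimates quantiles from per-dimension histograms instead of storing every sample. A minimal sketch of the idea (illustrative names only; the real class also tracks mean/std, rebins when the observed min/max expand, and handles multi-dimensional features):

```python
import numpy as np

# Illustrative sketch of histogram-based running quantiles; NOT the actual
# RunningQuantileStats implementation.
class HistogramQuantiles:
    def __init__(self, low: float, high: float, num_bins: int = 5000):
        self.edges = np.linspace(low, high, num_bins + 1)
        self.counts = np.zeros(num_bins, dtype=np.int64)

    def update(self, values: np.ndarray) -> None:
        # Accumulate a histogram instead of storing all samples.
        hist, _ = np.histogram(values, bins=self.edges)
        self.counts += hist

    def quantile(self, q: float) -> float:
        # Find the bin where the cumulative count crosses q, then linearly
        # interpolate inside that bin rather than returning the bin edge
        # (matching the interpolation fix described in the commit message).
        cum = np.cumsum(self.counts)
        target = q * cum[-1]
        i = int(np.searchsorted(cum, target))
        prev = cum[i - 1] if i > 0 else 0
        frac = (target - prev) / max(self.counts[i], 1)
        return self.edges[i] + frac * (self.edges[i + 1] - self.edges[i])
```

With the in-bin interpolation, the returned quantile moves smoothly as data accumulates instead of jumping between bin edges.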
@@ -19,6 +19,7 @@ import numpy as np
import pytest

from lerobot.datasets.compute_stats import (
    RunningQuantileStats,
    _assert_type_and_shape,
    aggregate_feature_stats,
    aggregate_stats,
@@ -102,6 +103,9 @@ def test_get_feature_stats_axis_1(sample_array):
        "count": np.array([3]),
    }
    result = get_feature_stats(sample_array, axis=(1,), keepdims=False)

    # Check that basic stats are correct (quantiles are also included now)
    assert set(expected.keys()).issubset(set(result.keys()))
    for key in expected:
        np.testing.assert_allclose(result[key], expected[key])


@@ -115,6 +119,9 @@ def test_get_feature_stats_no_axis(sample_array):
        "count": np.array([3]),
    }
    result = get_feature_stats(sample_array, axis=None, keepdims=False)

    # Check that basic stats are correct (quantiles are also included now)
    assert set(expected.keys()).issubset(set(result.keys()))
    for key in expected:
        np.testing.assert_allclose(result[key], expected[key])
@@ -308,3 +315,520 @@ def test_aggregate_stats():
            results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
        )
        np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])


def test_running_quantile_stats_initialization():
    """Test proper initialization of RunningQuantileStats."""
    running_stats = RunningQuantileStats()
    assert running_stats._count == 0
    assert running_stats._mean is None
    assert running_stats._num_quantile_bins == 5000

    # Test custom bin size
    running_stats_custom = RunningQuantileStats(num_quantile_bins=1000)
    assert running_stats_custom._num_quantile_bins == 1000


def test_running_quantile_stats_single_batch_update():
    """Test updating with a single batch."""
    np.random.seed(42)
    data = np.random.normal(0, 1, (100, 3))

    running_stats = RunningQuantileStats()
    running_stats.update(data)

    assert running_stats._count == 100
    assert running_stats._mean.shape == (3,)
    assert len(running_stats._histograms) == 3
    assert len(running_stats._bin_edges) == 3

    # Verify basic statistics are reasonable
    np.testing.assert_allclose(running_stats._mean, np.mean(data, axis=0), atol=1e-10)


def test_running_quantile_stats_multiple_batch_updates():
    """Test updating with multiple batches."""
    np.random.seed(42)
    data1 = np.random.normal(0, 1, (100, 2))
    data2 = np.random.normal(1, 1, (150, 2))

    running_stats = RunningQuantileStats()
    running_stats.update(data1)
    running_stats.update(data2)

    assert running_stats._count == 250

    # Verify running mean is correct
    combined_data = np.vstack([data1, data2])
    expected_mean = np.mean(combined_data, axis=0)
    np.testing.assert_allclose(running_stats._mean, expected_mean, atol=1e-10)


def test_running_quantile_stats_get_statistics_basic():
    """Test getting basic statistics without quantiles."""
    np.random.seed(42)
    data = np.random.normal(0, 1, (100, 2))

    running_stats = RunningQuantileStats()
    running_stats.update(data)

    stats = running_stats.get_statistics()

    # Should have basic stats
    expected_keys = {"min", "max", "mean", "std", "count"}
    assert expected_keys.issubset(set(stats.keys()))

    # Verify values
    np.testing.assert_allclose(stats["mean"], np.mean(data, axis=0), atol=1e-10)
    np.testing.assert_allclose(stats["std"], np.std(data, axis=0), atol=1e-6)
    np.testing.assert_equal(stats["count"], np.array([100]))


def test_running_quantile_stats_get_statistics_with_quantiles():
    """Test getting statistics with quantiles."""
    np.random.seed(42)
    data = np.random.normal(0, 1, (1000, 2))

    running_stats = RunningQuantileStats()
    running_stats.update(data)

    stats = running_stats.get_statistics()

    # Should have basic stats plus quantiles
    expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
    assert expected_keys.issubset(set(stats.keys()))

    # Verify quantile values are reasonable
    from lerobot.datasets.compute_stats import DEFAULT_QUANTILES

    for i, q in enumerate(DEFAULT_QUANTILES):
        q_key = f"q{int(q * 100):02d}"
        assert q_key in stats
        assert stats[q_key].shape == (2,)

        # Check that quantiles are in reasonable order
        if i > 0:
            prev_q_key = f"q{int(DEFAULT_QUANTILES[i - 1] * 100):02d}"
            assert np.all(stats[prev_q_key] <= stats[q_key])


def test_running_quantile_stats_histogram_adjustment():
    """Test that histograms adjust when min/max change."""
    running_stats = RunningQuantileStats()

    # Initial data with small range
    data1 = np.array([[0.0, 1.0], [0.1, 1.1], [0.2, 1.2]])
    running_stats.update(data1)

    initial_edges_0 = running_stats._bin_edges[0].copy()
    initial_edges_1 = running_stats._bin_edges[1].copy()

    # Add data with much larger range
    data2 = np.array([[10.0, -10.0], [11.0, -11.0]])
    running_stats.update(data2)

    # Bin edges should have changed
    assert not np.array_equal(initial_edges_0, running_stats._bin_edges[0])
    assert not np.array_equal(initial_edges_1, running_stats._bin_edges[1])

    # New edges should cover the expanded range
    # First dimension: min should still be ~0.0, max should be ~11.0
    assert running_stats._bin_edges[0][0] <= 0.0
    assert running_stats._bin_edges[0][-1] >= 11.0

    # Second dimension: min should be ~-11.0, max should be ~1.2
    assert running_stats._bin_edges[1][0] <= -11.0
    assert running_stats._bin_edges[1][-1] >= 1.2


def test_running_quantile_stats_insufficient_data_error():
    """Test error when trying to get stats with insufficient data."""
    running_stats = RunningQuantileStats()

    with pytest.raises(ValueError, match="Cannot compute statistics for less than 2 vectors"):
        running_stats.get_statistics()

    # Single vector should also fail
    running_stats.update(np.array([[1.0]]))
    with pytest.raises(ValueError, match="Cannot compute statistics for less than 2 vectors"):
        running_stats.get_statistics()


def test_running_quantile_stats_vector_length_consistency():
    """Test error when vector lengths don't match."""
    running_stats = RunningQuantileStats()
    running_stats.update(np.array([[1.0, 2.0], [3.0, 4.0]]))

    with pytest.raises(ValueError, match="The length of new vectors does not match"):
        running_stats.update(np.array([[1.0, 2.0, 3.0]]))  # Different length


def test_running_quantile_stats_reshape_handling():
    """Test that various input shapes are handled correctly."""
    running_stats = RunningQuantileStats()

    # Test 3D input (e.g., images)
    data_3d = np.random.normal(0, 1, (10, 32, 32))
    running_stats.update(data_3d)

    assert running_stats._count == 10 * 32
    assert running_stats._mean.shape == (32,)

    # Test 1D input
    running_stats_1d = RunningQuantileStats()
    data_1d = np.array([1, 2, 3, 4, 5]).reshape(-1, 1)
    running_stats_1d.update(data_1d)

    assert running_stats_1d._count == 5
    assert running_stats_1d._mean.shape == (1,)


def test_get_feature_stats_quantiles_enabled_by_default():
    """Test that quantiles are computed by default."""
    data = np.random.normal(0, 1, (100, 5))
    stats = get_feature_stats(data, axis=0, keepdims=False)

    expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
    assert set(stats.keys()) == expected_keys


def test_get_feature_stats_quantiles_with_vector_data():
    """Test quantile computation with vector data."""
    np.random.seed(42)
    data = np.random.normal(0, 1, (100, 5))

    stats = get_feature_stats(data, axis=0, keepdims=False)

    expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
    assert set(stats.keys()) == expected_keys

    # Verify shapes
    assert stats["q01"].shape == (5,)
    assert stats["q99"].shape == (5,)

    # Verify quantiles are reasonable
    assert np.all(stats["q01"] < stats["q99"])


def test_get_feature_stats_quantiles_with_image_data():
    """Test quantile computation with image data."""
    np.random.seed(42)
    data = np.random.normal(0, 1, (50, 3, 32, 32))  # batch, channels, height, width

    stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)

    expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
    assert set(stats.keys()) == expected_keys

    # Verify shapes for images (should be (1, channels, 1, 1))
    assert stats["q01"].shape == (1, 3, 1, 1)
    assert stats["q50"].shape == (1, 3, 1, 1)
    assert stats["q99"].shape == (1, 3, 1, 1)


def test_get_feature_stats_fixed_quantiles():
    """Test that fixed quantiles are always computed."""
    data = np.random.normal(0, 1, (200, 3))

    stats = get_feature_stats(data, axis=0, keepdims=False)

    expected_quantile_keys = {"q01", "q10", "q50", "q90", "q99"}
    assert expected_quantile_keys.issubset(set(stats.keys()))


def test_get_feature_stats_unsupported_axis_error():
    """Test error for unsupported axis configuration."""
    data = np.random.normal(0, 1, (10, 5))

    with pytest.raises(ValueError, match="Unsupported axis configuration"):
        get_feature_stats(
            data,
            axis=(1, 2),  # Unsupported axis
            keepdims=False,
        )


def test_compute_episode_stats_backward_compatibility():
    """Test that existing functionality is preserved."""
    episode_data = {
        "action": np.random.normal(0, 1, (100, 7)),
        "observation.state": np.random.normal(0, 1, (100, 10)),
    }
    features = {
        "action": {"dtype": "float32", "shape": (7,)},
        "observation.state": {"dtype": "float32", "shape": (10,)},
    }

    stats = compute_episode_stats(episode_data, features)

    for key in ["action", "observation.state"]:
        expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
        assert set(stats[key].keys()) == expected_keys


def test_compute_episode_stats_with_custom_quantiles():
    """Test quantile computation with custom quantile values."""
    np.random.seed(42)
    episode_data = {
        "action": np.random.normal(0, 1, (100, 7)),
        "observation.state": np.random.normal(2, 1, (100, 10)),
    }
    features = {
        "action": {"dtype": "float32", "shape": (7,)},
        "observation.state": {"dtype": "float32", "shape": (10,)},
    }

    stats = compute_episode_stats(episode_data, features)

    # Should have quantiles
    for key in ["action", "observation.state"]:
        expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
        assert set(stats[key].keys()) == expected_keys

        # Verify shapes
        assert stats[key]["q01"].shape == (features[key]["shape"][0],)
        assert stats[key]["q99"].shape == (features[key]["shape"][0],)


def test_compute_episode_stats_with_image_data():
    """Test quantile computation with image features."""
    image_paths = [f"image_{i}.jpg" for i in range(50)]
    episode_data = {
        "observation.image": image_paths,
        "action": np.random.normal(0, 1, (50, 5)),
    }
    features = {
        "observation.image": {"dtype": "image"},
        "action": {"dtype": "float32", "shape": (5,)},
    }

    with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
        stats = compute_episode_stats(episode_data, features)

    # Image quantiles should be normalized and have correct shape
    assert "q01" in stats["observation.image"]
    assert "q50" in stats["observation.image"]
    assert "q99" in stats["observation.image"]
    assert stats["observation.image"]["q01"].shape == (3, 1, 1)
    assert stats["observation.image"]["q50"].shape == (3, 1, 1)
    assert stats["observation.image"]["q99"].shape == (3, 1, 1)

    # Action quantiles should have correct shape
    assert stats["action"]["q01"].shape == (5,)
    assert stats["action"]["q50"].shape == (5,)
    assert stats["action"]["q99"].shape == (5,)


def test_compute_episode_stats_string_features_skipped():
    """Test that string features are properly skipped."""
    episode_data = {
        "task": ["pick_apple"] * 100,  # String feature
        "action": np.random.normal(0, 1, (100, 5)),
    }
    features = {
        "task": {"dtype": "string"},
        "action": {"dtype": "float32", "shape": (5,)},
    }

    stats = compute_episode_stats(
        episode_data,
        features,
    )

    # String features should be skipped
    assert "task" not in stats
    assert "action" in stats
    assert "q01" in stats["action"]


def test_aggregate_feature_stats_with_quantiles():
    """Test aggregating feature stats that include quantiles."""
    stats_ft_list = [
        {
            "min": np.array([1.0]),
            "max": np.array([10.0]),
            "mean": np.array([5.0]),
            "std": np.array([2.0]),
            "count": np.array([100]),
            "q01": np.array([1.5]),
            "q99": np.array([9.5]),
        },
        {
            "min": np.array([2.0]),
            "max": np.array([12.0]),
            "mean": np.array([6.0]),
            "std": np.array([2.5]),
            "count": np.array([150]),
            "q01": np.array([2.5]),
            "q99": np.array([11.5]),
        },
    ]

    result = aggregate_feature_stats(stats_ft_list)

    # Should preserve quantiles
    assert "q01" in result
    assert "q99" in result

    # Verify quantile aggregation (weighted average)
    expected_q01 = (1.5 * 100 + 2.5 * 150) / 250  # ≈ 2.1
    expected_q99 = (9.5 * 100 + 11.5 * 150) / 250  # ≈ 10.7

    np.testing.assert_allclose(result["q01"], np.array([expected_q01]), atol=1e-6)
    np.testing.assert_allclose(result["q99"], np.array([expected_q99]), atol=1e-6)


def test_aggregate_stats_mixed_quantiles():
    """Test aggregating stats where some have quantiles and some don't."""
    stats_with_quantiles = {
        "feature1": {
            "min": np.array([1.0]),
            "max": np.array([10.0]),
            "mean": np.array([5.0]),
            "std": np.array([2.0]),
            "count": np.array([100]),
            "q01": np.array([1.5]),
            "q99": np.array([9.5]),
        }
    }

    stats_without_quantiles = {
        "feature2": {
            "min": np.array([0.0]),
            "max": np.array([5.0]),
            "mean": np.array([2.5]),
            "std": np.array([1.5]),
            "count": np.array([50]),
        }
    }

    all_stats = [stats_with_quantiles, stats_without_quantiles]
    result = aggregate_stats(all_stats)

    # Feature1 should keep its quantiles
    assert "q01" in result["feature1"]
    assert "q99" in result["feature1"]

    # Feature2 should not have quantiles
    assert "q01" not in result["feature2"]
    assert "q99" not in result["feature2"]


def test_assert_type_and_shape_with_quantiles():
    """Test validation works correctly with quantile keys."""
    # Valid stats with quantiles
    valid_stats = [
        {
            "observation.image": {
                "min": np.array([0.0, 0.0, 0.0]).reshape(3, 1, 1),
                "max": np.array([1.0, 1.0, 1.0]).reshape(3, 1, 1),
                "mean": np.array([0.5, 0.5, 0.5]).reshape(3, 1, 1),
                "std": np.array([0.2, 0.2, 0.2]).reshape(3, 1, 1),
                "count": np.array([100]),
                "q01": np.array([0.1, 0.1, 0.1]).reshape(3, 1, 1),
                "q99": np.array([0.9, 0.9, 0.9]).reshape(3, 1, 1),
            }
        }
    ]

    # Should not raise error
    _assert_type_and_shape(valid_stats)

    # Invalid shape for quantile
    invalid_stats = [
        {
            "observation.image": {
                "count": np.array([100]),
                "q01": np.array([0.1, 0.2]),  # Wrong shape for image quantile
            }
        }
    ]

    with pytest.raises(ValueError, match="Shape of quantile 'q01' must be \\(3,1,1\\)"):
        _assert_type_and_shape(invalid_stats)


def test_quantile_integration_single_value_quantiles():
    """Test quantile computation with single repeated value."""
    data = np.ones((100, 3))  # All ones

    running_stats = RunningQuantileStats()
    running_stats.update(data)

    stats = running_stats.get_statistics()

    # All quantiles should be approximately 1.0
    np.testing.assert_allclose(stats["q01"], np.array([1.0, 1.0, 1.0]), atol=1e-6)
    np.testing.assert_allclose(stats["q50"], np.array([1.0, 1.0, 1.0]), atol=1e-6)
    np.testing.assert_allclose(stats["q99"], np.array([1.0, 1.0, 1.0]), atol=1e-6)


def test_quantile_integration_fixed_quantiles():
    """Test that fixed quantiles are computed."""
    np.random.seed(42)
    data = np.random.normal(0, 1, (1000, 2))

    stats = get_feature_stats(data, axis=0, keepdims=False)

    # Check all fixed quantiles are present
    assert "q01" in stats
    assert "q10" in stats
    assert "q50" in stats
    assert "q90" in stats
    assert "q99" in stats


def test_quantile_integration_large_dataset_quantiles():
    """Test quantile computation efficiency with large datasets."""
    np.random.seed(42)
    large_data = np.random.normal(0, 1, (10000, 5))

    running_stats = RunningQuantileStats(num_quantile_bins=1000)  # Reduced bins for speed
    running_stats.update(large_data)

    stats = running_stats.get_statistics()

    # Should complete without issues and produce reasonable results
    assert stats["count"][0] == 10000
    assert len(stats["q01"]) == 5


def test_fixed_quantiles_always_computed():
    """Test that the fixed quantiles [0.01, 0.10, 0.50, 0.90, 0.99] are always computed."""
    np.random.seed(42)
    # Test with vector data
    vector_data = np.random.normal(0, 1, (100, 5))
    vector_stats = get_feature_stats(vector_data, axis=0, keepdims=False)

    # Check all fixed quantiles are present
    expected_quantiles = ["q01", "q10", "q50", "q90", "q99"]
    for q_key in expected_quantiles:
        assert q_key in vector_stats
        assert vector_stats[q_key].shape == (5,)

    # Test with image data
    image_data = np.random.randint(0, 256, (50, 3, 32, 32), dtype=np.uint8)
    image_stats = get_feature_stats(image_data, axis=(0, 2, 3), keepdims=True)

    # Check all fixed quantiles are present for images
    for q_key in expected_quantiles:
        assert q_key in image_stats
        assert image_stats[q_key].shape == (1, 3, 1, 1)

    # Test with episode data
    episode_data = {
        "action": np.random.normal(0, 1, (100, 7)),
        "observation.state": np.random.normal(0, 1, (100, 10)),
    }
    features = {
        "action": {"dtype": "float32", "shape": (7,)},
        "observation.state": {"dtype": "float32", "shape": (10,)},
    }

    episode_stats = compute_episode_stats(episode_data, features)

    # Check all fixed quantiles are present in episode stats
    for key in ["action", "observation.state"]:
        for q_key in expected_quantiles:
            assert q_key in episode_stats[key]
            assert episode_stats[key][q_key].shape == (features[key]["shape"][0],)
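Note that `aggregate_feature_stats` merges per-episode quantiles as a count-weighted average, as exercised by `test_aggregate_feature_stats_with_quantiles` above. This is an approximation: exact merged quantiles would require the underlying histograms. The rule reduces to the following sketch, which reproduces the test's arithmetic:

```python
import numpy as np

# Count-weighted average of per-episode quantiles (approximate merge rule).
def merge_quantile(values: list[np.ndarray], counts: list[np.ndarray]) -> np.ndarray:
    total = np.sum([c[0] for c in counts])
    return np.sum([v * c[0] for v, c in zip(values, counts)], axis=0) / total

q01 = merge_quantile([np.array([1.5]), np.array([2.5])], [np.array([100]), np.array([150])])
assert np.isclose(q01[0], 2.1)  # (1.5 * 100 + 2.5 * 150) / 250
```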
tests/datasets/test_quantiles_dataset_integration.py (new file, 212 lines)
@@ -0,0 +1,212 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Integration tests for quantile functionality in LeRobotDataset."""

import numpy as np
import pytest

from lerobot.datasets.lerobot_dataset import LeRobotDataset


def mock_load_image_as_numpy(path, dtype, channel_first):
    """Mock image loading for consistent test results."""
    return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)


@pytest.fixture
def simple_features():
    """Simple feature configuration for testing."""
    return {
        "action": {
            "dtype": "float32",
            "shape": (4,),
            "names": ["arm_x", "arm_y", "arm_z", "gripper"],
        },
        "observation.state": {
            "dtype": "float32",
            "shape": (10,),
            "names": [f"joint_{i}" for i in range(10)],
        },
    }


def test_create_dataset_with_fixed_quantiles(tmp_path, simple_features):
    """Test creating dataset with fixed quantiles."""
    dataset = LeRobotDataset.create(
        repo_id="test_dataset_fixed_quantiles",
        fps=30,
        features=simple_features,
        root=tmp_path / "create_fixed_quantiles",
    )

    # Dataset should be created successfully
    assert dataset is not None


def test_save_episode_computes_all_quantiles(tmp_path, simple_features):
    """Test that all fixed quantiles are computed when saving an episode."""
    dataset = LeRobotDataset.create(
        repo_id="test_dataset_save_episode",
        fps=30,
        features=simple_features,
        root=tmp_path / "save_episode_quantiles",
    )

    # Add some frames
    for _ in range(10):
        dataset.add_frame(
            {
                "action": np.random.randn(4).astype(np.float32),  # Correct shape for action
                "observation.state": np.random.randn(10).astype(np.float32),
                "task": "test_task",
            }
        )

    dataset.save_episode()

    # Check that all fixed quantiles were computed
    stats = dataset.meta.stats
    for key in ["action", "observation.state"]:
        assert "q01" in stats[key]
        assert "q10" in stats[key]
        assert "q50" in stats[key]
        assert "q90" in stats[key]
        assert "q99" in stats[key]


def test_quantile_values_ordering(tmp_path, simple_features):
    """Test that quantile values are properly ordered."""
    dataset = LeRobotDataset.create(
        repo_id="test_dataset_quantile_ordering",
        fps=30,
        features=simple_features,
        root=tmp_path / "quantile_ordering",
    )

    # Add data with known distribution
    np.random.seed(42)
    for _ in range(100):
        dataset.add_frame(
            {
                "action": np.random.randn(4).astype(np.float32),  # Correct shape for action
                "observation.state": np.random.randn(10).astype(np.float32),
                "task": "test_task",
            }
        )

    dataset.save_episode()
    stats = dataset.meta.stats

    # Verify quantile ordering
    for key in ["action", "observation.state"]:
        assert np.all(stats[key]["q01"] <= stats[key]["q10"])
        assert np.all(stats[key]["q10"] <= stats[key]["q50"])
        assert np.all(stats[key]["q50"] <= stats[key]["q90"])
        assert np.all(stats[key]["q90"] <= stats[key]["q99"])


def test_save_episode_with_fixed_quantiles(tmp_path, simple_features):
    """Test saving episode always computes fixed quantiles."""
    dataset = LeRobotDataset.create(
        repo_id="test_dataset_save_fixed",
        fps=30,
        features=simple_features,
        root=tmp_path / "save_fixed_quantiles",
    )

    # Add frames to episode
    np.random.seed(42)
    for _ in range(50):
        frame = {
            "action": np.random.normal(0, 1, (4,)).astype(np.float32),
            "observation.state": np.random.normal(0, 1, (10,)).astype(np.float32),
            "task": "test_task",
        }
        dataset.add_frame(frame)

    dataset.save_episode()

    # Check that all fixed quantiles are included
    stats = dataset.meta.stats
    for key in ["action", "observation.state"]:
        feature_stats = stats[key]
        expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
        assert set(feature_stats.keys()) == expected_keys


def test_quantile_aggregation_across_episodes(tmp_path, simple_features):
    """Test quantile aggregation across multiple episodes."""
    dataset = LeRobotDataset.create(
        repo_id="test_dataset_aggregation",
        fps=30,
        features=simple_features,
        root=tmp_path / "quantile_aggregation",
    )

    # Add frames to episode
    np.random.seed(42)
    for _ in range(100):
        frame = {
            "action": np.random.normal(0, 1, (4,)).astype(np.float32),
            "observation.state": np.random.normal(2, 1, (10,)).astype(np.float32),
            "task": "test_task",
        }
        dataset.add_frame(frame)

    dataset.save_episode()

    # Check stats include all fixed quantiles
    stats = dataset.meta.stats
    for key in ["action", "observation.state"]:
        feature_stats = stats[key]
        expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
        assert set(feature_stats.keys()) == expected_keys
        assert feature_stats["q01"].shape == (simple_features[key]["shape"][0],)
        assert feature_stats["q50"].shape == (simple_features[key]["shape"][0],)
        assert feature_stats["q99"].shape == (simple_features[key]["shape"][0],)
        assert np.all(feature_stats["q01"] <= feature_stats["q50"])
        assert np.all(feature_stats["q50"] <= feature_stats["q99"])


def test_save_multiple_episodes_with_quantiles(tmp_path, simple_features):
    """Test quantile aggregation across multiple episodes."""
    dataset = LeRobotDataset.create(
        repo_id="test_dataset_multiple_episodes",
        fps=30,
        features=simple_features,
        root=tmp_path / "multiple_episodes",
    )

    # Save multiple episodes
    np.random.seed(42)
    for episode_idx in range(3):
        for _ in range(50):
            frame = {
                "action": np.random.normal(episode_idx * 2.0, 1, (4,)).astype(np.float32),
                "observation.state": np.random.normal(-episode_idx * 1.5, 1, (10,)).astype(np.float32),
                "task": f"task_{episode_idx}",
            }
            dataset.add_frame(frame)

        dataset.save_episode()

    # Verify final stats include properly aggregated quantiles
    stats = dataset.meta.stats
    for key in ["action", "observation.state"]:
        feature_stats = stats[key]
        assert "q01" in feature_stats and "q99" in feature_stats
        assert feature_stats["count"][0] == 150  # 3 episodes * 50 frames
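These dataset stats feed the QUANTILE normalization mode added in this PR, which maps values between the 1st and 99th percentiles onto [-1, 1] (the commit history notes the range was shifted from [0, 1] to [-1, 1]). A hedged sketch of that mapping, assuming a linear rescale using q01/q99 from the stats; the actual implementation lives in LeRobot's normalization processor and also supports QUANTILE10 with q10/q90:

```python
import torch

# Assumed form of QUANTILE normalization to [-1, 1] using q01/q99.
def quantile_normalize(x: torch.Tensor, q01: torch.Tensor, q99: torch.Tensor) -> torch.Tensor:
    scale = (q99 - q01).clamp(min=1e-8)  # avoid division by zero for constant features
    return ((x - q01) / scale) * 2.0 - 1.0

x = torch.tensor([0.0, 0.5, 1.0])
print(quantile_normalize(x, q01=torch.zeros(3), q99=torch.ones(3)))  # tensor([-1., 0., 1.])
```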
tests/policies/pi0_pi05/test_pi0.py (new file, 117 lines)
@@ -0,0 +1,117 @@
#!/usr/bin/env python

"""Test script to verify PI0 policy integration with LeRobot, only meant to be run locally!"""

import os

import pytest
import torch

# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
    os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
    reason="This test requires local OpenPI installation and is not meant for CI",
)

from lerobot.policies.factory import make_policy_config  # noqa: E402
from lerobot.policies.pi0 import (  # noqa: E402
    PI0Config,
    PI0Policy,
    make_pi0_pre_post_processors,  # noqa: E402
)
from lerobot.utils.random_utils import set_seed  # noqa: E402
from tests.utils import require_cuda  # noqa: E402


@require_cuda
def test_policy_instantiation():
    # Create config
    set_seed(42)
    config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")

    # Set up input_features and output_features in the config
    from lerobot.configs.types import FeatureType, PolicyFeature

    config.input_features = {
        "observation.state": PolicyFeature(
            type=FeatureType.STATE,
            shape=(14,),
        ),
        "observation.images.base_0_rgb": PolicyFeature(
            type=FeatureType.VISUAL,
            shape=(3, 224, 224),
        ),
    }

    config.output_features = {
        "action": PolicyFeature(
            type=FeatureType.ACTION,
            shape=(7,),
        ),
    }

    # Create dummy dataset stats
    dataset_stats = {
        "observation.state": {
            "mean": torch.zeros(14),
            "std": torch.ones(14),
        },
        "action": {
            "mean": torch.zeros(7),
            "std": torch.ones(7),
        },
        "observation.images.base_0_rgb": {
            "mean": torch.zeros(3, 224, 224),
            "std": torch.ones(3, 224, 224),
        },
    }

    # Instantiate policy
    policy = PI0Policy(config)
    preprocessor, postprocessor = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
    # Test forward pass with dummy data
    batch_size = 1
    device = config.device
    batch = {
        "observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
        "action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
        "observation.images.base_0_rgb": torch.rand(
            batch_size, 3, 224, 224, dtype=torch.float32, device=device
        ),  # Use rand for [0,1] range
        "task": ["Pick up the object"] * batch_size,
    }
    batch = preprocessor(batch)
    try:
        loss, loss_dict = policy.forward(batch)
        print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
    except Exception as e:
        print(f"Forward pass failed: {e}")
        raise

    try:
        with torch.no_grad():
            action = policy.select_action(batch)
            action = postprocessor(action)
            print(f"Action: {action}")
            print(f"Action prediction successful. Action shape: {action.shape}")
    except Exception as e:
        print(f"Action prediction failed: {e}")
        raise


@require_cuda
def test_config_creation():
    """Test policy config creation through factory."""
    try:
        config = make_policy_config(
            policy_type="pi0",
            max_action_dim=7,
            max_state_dim=14,
        )
        print("Config created successfully through factory")
        print(f"  Config type: {type(config).__name__}")
        print(f"  PaliGemma variant: {config.paligemma_variant}")
        print(f"  Action expert variant: {config.action_expert_variant}")
    except Exception as e:
        print(f"Config creation failed: {e}")
        raise
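Both pi0 and pi05 pad state and action features up to `max_state_dim`/`max_action_dim` before feeding the model (the comparison test further below imports `pad_vector` from `modeling_pi05` for exactly this). A minimal sketch, assuming zero-padding of the last dimension:

```python
import torch

# Assumed behavior of pad_vector: zero-pad the last dimension up to new_dim,
# leaving the tensor unchanged when it is already wide enough.
def pad_vector(vector: torch.Tensor, new_dim: int) -> torch.Tensor:
    if vector.shape[-1] >= new_dim:
        return vector
    padding = new_dim - vector.shape[-1]
    return torch.nn.functional.pad(vector, (0, padding))

state = torch.randn(2, 14)
print(pad_vector(state, 32).shape)  # torch.Size([2, 32])
```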
tests/policies/pi0_pi05/test_pi05.py (new file, 154 lines)
@@ -0,0 +1,154 @@
#!/usr/bin/env python

"""Test script to verify PI0.5 (pi05) support in PI0 policy, only meant to be run locally!"""

import os

import pytest
import torch

from lerobot.utils.random_utils import set_seed

# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
    os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
    reason="This test requires local OpenPI installation and is not meant for CI",
)

from lerobot.policies.factory import make_policy_config  # noqa: E402
from lerobot.policies.pi05 import (  # noqa: E402
    PI05Config,
    PI05Policy,
    make_pi05_pre_post_processors,  # noqa: E402
)
from tests.utils import require_cuda  # noqa: E402


@require_cuda
def test_policy_instantiation():
    # Create config
    set_seed(42)
    config = PI05Config(max_action_dim=7, max_state_dim=14, dtype="float32")

    # Set up input_features and output_features in the config
    from lerobot.configs.types import FeatureType, PolicyFeature

    config.input_features = {
        "observation.state": PolicyFeature(
            type=FeatureType.STATE,
            shape=(14,),
        ),
        "observation.images.base_0_rgb": PolicyFeature(
            type=FeatureType.VISUAL,
            shape=(3, 224, 224),
        ),
    }

    config.output_features = {
        "action": PolicyFeature(
            type=FeatureType.ACTION,
            shape=(7,),
        ),
    }

    assert config.tokenizer_max_length == 200, (
        f"Expected tokenizer_max_length=200 for pi05, got {config.tokenizer_max_length}"
    )

    # Create dummy dataset stats
    dataset_stats = {
        "observation.state": {
            "mean": torch.zeros(14),
            "std": torch.ones(14),
            "min": torch.zeros(14),
            "max": torch.ones(14),
            "q01": torch.zeros(14),
            "q99": torch.ones(14),
        },
        "action": {
            "mean": torch.zeros(7),
            "std": torch.ones(7),
            "min": torch.zeros(7),
            "max": torch.ones(7),
            "q01": torch.zeros(7),
            "q99": torch.ones(7),
        },
        "observation.images.base_0_rgb": {
            "mean": torch.zeros(3, 224, 224),
            "std": torch.ones(3, 224, 224),
            "q01": torch.zeros(3, 224, 224),
            "q99": torch.ones(3, 224, 224),
        },
    }

    # Instantiate policy
    policy = PI05Policy(config)
    # Test forward pass with dummy data
    batch_size = 1
    preprocessor, postprocessor = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
    device = config.device
    batch = {
        "observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
        "action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
        "observation.images.base_0_rgb": torch.rand(
            batch_size, 3, 224, 224, dtype=torch.float32, device=device
        ),  # Use rand for [0,1] range
        "task": ["Pick up the object"] * batch_size,
    }
    batch = preprocessor(batch)
    try:
        loss, loss_dict = policy.forward(batch)
        print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
    except Exception as e:
        print(f"Forward pass failed: {e}")
        raise

    try:
        with torch.no_grad():
            action = policy.select_action(batch)
            action = postprocessor(action)
            print(f"Action: {action}")
            print(f"Action prediction successful. Action shape: {action.shape}")
    except Exception as e:
        print(f"Action prediction failed: {e}")
        raise

    # Verify pi05 model components exist
    # Check that time_mlp layers exist (for AdaRMS conditioning)
    assert hasattr(policy.model, "time_mlp_in"), "Missing time_mlp_in layer for pi05"
    assert hasattr(policy.model, "time_mlp_out"), "Missing time_mlp_out layer for pi05"

    # Check that action_time_mlp layers don't exist (pi0 only)
    assert not hasattr(policy.model, "action_time_mlp_in"), "action_time_mlp_in should not exist in pi05 mode"
    assert not hasattr(policy.model, "action_time_mlp_out"), (
        "action_time_mlp_out should not exist in pi05 mode"
    )

    # Check that state_proj doesn't exist in pi05 mode
    assert not hasattr(policy.model, "state_proj"), "state_proj should not exist in pi05 mode"

    # Check AdaRMS configuration in the underlying model
    adarms_config = policy.model.paligemma_with_expert.paligemma.config.text_config.use_adarms
    assert adarms_config == False, f"PaliGemma should not use AdaRMS, got {adarms_config}"  # noqa: E712

    adarms_expert_config = policy.model.paligemma_with_expert.gemma_expert.config.use_adarms
    assert adarms_expert_config == True, (  # noqa: E712
        f"Action expert should use AdaRMS in pi05, got {adarms_expert_config}"
    )


@require_cuda
def test_config_creation():
    """Test policy config creation through factory."""
    try:
        config = make_policy_config(
            policy_type="pi05",
            max_action_dim=7,
            max_state_dim=14,
        )
        print("Config created successfully through factory")
        print(f"  Config type: {type(config).__name__}")
        print(f"  PaliGemma variant: {config.paligemma_variant}")
        print(f"  Action expert variant: {config.action_expert_variant}")
    except Exception as e:
        print(f"Config creation failed: {e}")
        raise
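The asserts above probe for AdaRMS conditioning: in pi05, the flow-matching timestep embedding modulates the action expert through adaptive RMSNorm layers rather than pi0's `action_time_mlp` path. An illustrative sketch of the mechanism (assumed form; the real layers come from the custom transformers branch used by this PR):

```python
import torch
from torch import nn

# Sketch of adaptive RMSNorm ("AdaRMS"): a conditioning vector (e.g. the
# timestep embedding) produces per-channel gains applied after RMS
# normalization. Illustrative only, not the actual implementation.
class AdaRMSNorm(nn.Module):
    def __init__(self, dim: int, cond_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.to_scale = nn.Linear(cond_dim, dim)  # conditioning -> per-channel gain

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        gain = 1.0 + self.to_scale(cond).unsqueeze(1)  # broadcast over the sequence
        return x * rms * gain

x = torch.randn(2, 10, 64)  # (batch, seq, dim)
cond = torch.randn(2, 32)   # timestep embedding
print(AdaRMSNorm(64, 32)(x, cond).shape)  # torch.Size([2, 10, 64])
```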
tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py (new file, 419 lines)
@@ -0,0 +1,419 @@
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip if openpi or transformers is not available
|
||||
pytest.importorskip("openpi")
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
|
||||
|
||||
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 200
|
||||
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"q99": torch.ones(DUMMY_ACTION_DIM),
|
||||
},
|
||||
"images": {
|
||||
"base_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
},
|
||||
"left_wrist_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
},
|
||||
"right_wrist_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class PI05BaseOriginalConfig:
|
||||
action_dim: int = DUMMY_ACTION_DIM
|
||||
action_horizon: int = DUMMY_ACTION_HORIZON
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
precision: str = "float32"
|
||||
pi05: bool = True
|
||||
dtype: str = "float32"
|
||||
|
||||
|
||||
def instantiate_lerobot_pi05(
|
||||
from_pretrained: bool = False,
|
||||
) -> tuple[
|
||||
PI05Policy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI05Policy.from_pretrained(pretrained_name_or_path="lerobot/pi05_base", strict=True)
|
||||
else:
|
||||
config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI05Policy(config)
|
||||
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_pi05_pre_post_processors(
|
||||
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||
)
|
||||
return (policy, preprocessor, postprocessor)
|
||||
|
||||
|
||||
def instantiate_original_pi05(from_pretrained: bool = False, model_path: str | None = None):
|
||||
config = PI05BaseOriginalConfig()
|
||||
policy = PI0Pytorch(config)
|
||||
|
||||
if from_pretrained:
|
||||
try:
|
||||
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi05_base)...")
|
||||
|
||||
# Download the model from HuggingFace Hub
|
||||
import safetensors.torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the entire repository
|
||||
if model_path and os.path.exists(model_path):
|
||||
cache_dir = model_path
|
||||
print(f"Using cached model from: {cache_dir}")
|
||||
else:
|
||||
cache_dir = snapshot_download(repo_id="lerobot/pi05_base", repo_type="model")
|
||||
print(f"Downloaded model to: {cache_dir}")
|
||||
|
||||
# Try to load safetensors format first
|
||||
model_file = os.path.join(cache_dir, "model.safetensors")
|
||||
if os.path.exists(model_file):
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
print(f"Loaded {len(state_dict)} parameters from safetensors")
|
||||
else:
|
||||
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
|
||||
|
||||
# Load the state dict into the model
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"Missing keys: {len(missing_keys)}")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in missing_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"Unexpected keys: {len(unexpected_keys)}")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in unexpected_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("All pretrained weights loaded successfully!")
|
||||
else:
|
||||
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load pretrained weights: {e}")
|
||||
print(" Using randomly initialized weights...")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
policy.to(DEVICE)
|
||||
return policy
|
||||
|
||||
|
||||
def create_dummy_data():
|
||||
batch_size = 2 # Reduce batch size for testing
|
||||
device = DEVICE
|
||||
|
||||
# Use the exact same prompt for both implementations
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
|
||||
),
|
||||
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
"observation.images.left_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
"observation.images.right_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
    """Extract the exact same processed inputs that LeRobot uses internally."""
    # Get the tokenized language from LeRobot's internal method
    lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)

    # Get the preprocessed images from LeRobot's internal method
    images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)

    # Create dummy token_ar_mask and token_loss_mask for the original implementation
    token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
    token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)

    return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask

class PI05Observation:
    """Observation class that matches the original OpenPI format."""

    def __init__(
        self,
        state,
        images,
        image_masks,
        tokenized_prompt,
        tokenized_prompt_mask,
        token_ar_mask,
        token_loss_mask,
    ):
        self.state = state
        self.images = images
        self.image_masks = image_masks
        self.tokenized_prompt = tokenized_prompt
        self.tokenized_prompt_mask = tokenized_prompt_mask
        self.token_ar_mask = token_ar_mask
        self.token_loss_mask = token_loss_mask

def create_original_observation_with_openpi_preprocessing(batch):
    """Create observation object for OpenPI using OpenPI's own preprocessing with the pi05 state tokenizer."""
    batch_size = batch["observation.state"].shape[0]
    device = batch["observation.state"].device

    # Create tokenizer for OpenPI (same as LeRobot uses)
    tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")

    # Get task description (the pi05 processor handles all text formatting)
    tasks = batch.get("task", ["Pick up the object"] * batch_size)
    if isinstance(tasks, str):
        tasks = [tasks] * batch_size
    elif len(tasks) == 1:
        tasks = tasks * batch_size

    # Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
    state = batch["observation.state"]
    state = deepcopy(state)

    # Prepare state (pad to max_state_dim)
    from lerobot.policies.pi05.modeling_pi05 import pad_vector

    state = pad_vector(state, DUMMY_STATE_DIM)

    # Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
    # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
    state_np = state.cpu().numpy()
    discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1

    # Create pi05-formatted prompts that include state information
    full_prompts = []
    for i, task in enumerate(tasks):
        cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
        state_str = " ".join(map(str, discretized_states[i]))
        full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
        full_prompts.append(full_prompt)

    # Tokenize with max_length padding to match OpenPI's expected format
    tokenized = tokenizer(
        full_prompts,
        padding="max_length",
        padding_side="right",
        truncation=True,
        max_length=DUMMY_MAX_TOKEN_LEN,
        return_tensors="pt",
    )

    lang_tokens = tokenized["input_ids"].to(device)
    lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)

    # Create dummy token_ar_mask and token_loss_mask for OpenPI
    token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
    token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)

    # Convert LeRobot images format to OpenPI format (convert [0, 1] to [-1, 1] range)
    image_dict = {
        "base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
        "left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
        "right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
    }

    # Create image masks (all ones for real images)
    image_masks_dict = {}
    for key in image_dict:
        image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)

    # Create raw observation object (before preprocessing)
    raw_observation = PI05Observation(
        state=batch["observation.state"],
        images=image_dict,
        image_masks=image_masks_dict,
        tokenized_prompt=lang_tokens,
        tokenized_prompt_mask=lang_masks,
        token_ar_mask=token_ar_mask,
        token_loss_mask=token_loss_mask,
    )

    # Now use OpenPI's preprocessing
    processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)

    return processed_obs

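The `np.digitize` call above is what turns the continuous state into the token-friendly integers embedded in the prompt: 256 bin edges span [-1, 1), and the trailing `- 1` shifts NumPy's 1-based bin indices down to 0..255. A small self-contained sketch of that mapping, using the same bin construction as the code above:

import numpy as np

bins = np.linspace(-1, 1, 256 + 1)[:-1]  # the 256 left bin edges over [-1, 1)
state = np.array([-1.0, -0.5, 0.0, 0.99])
print(np.digitize(state, bins=bins) - 1)  # [  0  64 128 254]
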
def create_original_observation_from_lerobot(lerobot_pi0, batch):
    """Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
    _batch_size = batch["observation.state"].shape[0]
    _device = batch["observation.state"].device

    # Extract the exact same processed inputs that LeRobot uses
    images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
        extract_lerobot_processed_inputs(lerobot_pi0, batch)
    )

    # Convert the images list to a dict with original OpenPI keys
    image_dict = {
        "base_0_rgb": images[0],
        "left_wrist_0_rgb": images[1],
        "right_wrist_0_rgb": images[2],
    }

    # Convert the image masks list to a dict with original OpenPI keys
    image_masks_dict = {
        "base_0_rgb": img_masks[0],
        "left_wrist_0_rgb": img_masks[1],
        "right_wrist_0_rgb": img_masks[2],
    }

    return PI05Observation(
        state=batch["observation.state"],
        images=image_dict,
        image_masks=image_masks_dict,
        tokenized_prompt=lang_tokens,
        tokenized_prompt_mask=lang_masks,
        token_ar_mask=token_ar_mask,
        token_loss_mask=token_loss_mask,
    )

def test_pi05_original_vs_lerobot():
    """Test PI05 original implementation vs LeRobot implementation."""
    print("Initializing models...")
    lerobot_pi05, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi05(
        from_pretrained=True
    )  # Load pretrained LeRobot model
    original_pi0 = instantiate_original_pi05(
        from_pretrained=True
    )  # Load pretrained OpenPI model from HuggingFace Hub

    print("Creating dummy data...")
    batch = create_dummy_data()
    batch_lerobot = deepcopy(batch)

    # Test each model with its own preprocessing (more realistic end-to-end test)
    print("\nTest each model with its own preprocessing")
    print("Creating observation for OpenPI using OpenPI's own preprocessing...")
    pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)

    print(f"Task prompt: '{batch['task'][0]}'")
    print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
    print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
    print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")

    print("Testing OpenPI with own preprocessing...")
    original_pi0.eval()
    torch.manual_seed(42)  # Set seed for reproducibility
    batch_size = batch["observation.state"].shape[0]
    noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
    fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)

    with torch.no_grad():
        openpi_actions = original_pi0.sample_actions(
            device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
        )
    openpi_actions_unit = openpi_actions[:, 0, :]
    print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
    print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
    print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
    print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")

    print("Testing LeRobot with own preprocessing...")
    lerobot_pi05.eval()
    torch.manual_seed(42)  # Set the same seed

    batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
    with torch.no_grad():
        lerobot_actions_own = lerobot_pi05.predict_action_chunk(
            batch_lerobot_processed
        )  # batch_size, n_action_steps, action_dim
    lerobot_actions_unit = lerobot_actions_own[:, 0, :]
    print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
    print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
    print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
    print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")

    print("\nComparing end-to-end implementations:")
    print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
    print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
    print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")

    assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
    assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
    assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4

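Because these comparison scripts are skipped in CI (see the `pytestmark` in the file below), they have to be invoked locally, with OpenPI installed and on `PYTHONPATH`. One way to do that from Python; the pi05 test-file path is an assumption, mirrored from the pi0 file name in this PR:

import sys

import pytest

# Assumed path, by analogy with tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py
sys.exit(pytest.main(["-s", "tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py"]))
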
tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py (new file, 410 lines)
@@ -0,0 +1,410 @@
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip if openpi or transformers is not available
|
||||
pytest.importorskip("openpi")
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
|
||||
|
||||
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05)
|
||||
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||
|
||||
DUMMY_DATASET_STATS = {
    "observation.state": {
        "mean": torch.zeros(DUMMY_STATE_DIM),
        "std": torch.ones(DUMMY_STATE_DIM),
        "q01": torch.zeros(DUMMY_STATE_DIM),
        "q99": torch.ones(DUMMY_STATE_DIM),
    },
    "action": {
        "mean": torch.zeros(DUMMY_ACTION_DIM),
        "std": torch.ones(DUMMY_ACTION_DIM),
        "q01": torch.zeros(DUMMY_ACTION_DIM),
        "q99": torch.ones(DUMMY_ACTION_DIM),
    },
    "images": {
        "base_0_rgb": {
            "mean": torch.zeros(3, 224, 224),
            "std": torch.ones(3, 224, 224),
            "q01": torch.zeros(3, 224, 224),
            "q99": torch.ones(3, 224, 224),
        },
        "left_wrist_0_rgb": {
            "mean": torch.zeros(3, 224, 224),
            "std": torch.ones(3, 224, 224),
            "q01": torch.zeros(3, 224, 224),
            "q99": torch.ones(3, 224, 224),
        },
        "right_wrist_0_rgb": {
            "mean": torch.zeros(3, 224, 224),
            "std": torch.ones(3, 224, 224),
            "q01": torch.zeros(3, 224, 224),
            "q99": torch.ones(3, 224, 224),
        },
    },
}

class PI0BaseOriginalConfig:
    action_dim: int = DUMMY_ACTION_DIM
    action_horizon: int = DUMMY_ACTION_HORIZON
    paligemma_variant: str = "gemma_2b"
    action_expert_variant: str = "gemma_300m"
    precision: str = "float32"
    pi05: bool = False
    dtype: str = "float32"

def instantiate_lerobot_pi0(
    from_pretrained: bool = False,
) -> tuple[
    PI0Policy,
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    if from_pretrained:
        # Load the policy first
        policy = PI0Policy.from_pretrained(pretrained_name_or_path="lerobot/pi0_base", strict=True)
    else:
        config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
        policy = PI0Policy(config)

    policy.to(DEVICE)
    policy.config.device = DEVICE
    preprocessor, postprocessor = make_pi0_pre_post_processors(
        config=policy.config, dataset_stats=DUMMY_DATASET_STATS
    )
    return (policy, preprocessor, postprocessor)

def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None):
    config = PI0BaseOriginalConfig()
    policy = PI0Pytorch(config)

    if from_pretrained:
        try:
            print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi0_base)...")

            # Download the model from HuggingFace Hub
            import safetensors.torch
            from huggingface_hub import snapshot_download

            # Download the entire repository
            if model_path and os.path.exists(model_path):
                cache_dir = model_path
                print(f"Using cached model from: {cache_dir}")
            else:
                cache_dir = snapshot_download(repo_id="lerobot/pi0_base", repo_type="model")
                print(f"Downloaded model to: {cache_dir}")

            # Try to load safetensors format first
            model_file = os.path.join(cache_dir, "model.safetensors")
            if os.path.exists(model_file):
                state_dict = safetensors.torch.load_file(model_file)
                print(f"Loaded {len(state_dict)} parameters from safetensors")
            else:
                raise FileNotFoundError(f"No safetensors file found in {cache_dir}")

            # Load the state dict into the model
            missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)

            if missing_keys:
                print(f"Missing keys: {len(missing_keys)}")
                if len(missing_keys) <= 5:
                    for key in missing_keys:
                        print(f" - {key}")
                else:
                    for key in missing_keys[:5]:
                        print(f" - {key}")
                    print(f" ... and {len(missing_keys) - 5} more")

            if unexpected_keys:
                print(f"Unexpected keys: {len(unexpected_keys)}")
                if len(unexpected_keys) <= 5:
                    for key in unexpected_keys:
                        print(f" - {key}")
                else:
                    for key in unexpected_keys[:5]:
                        print(f" - {key}")
                    print(f" ... and {len(unexpected_keys) - 5} more")

            if not missing_keys and not unexpected_keys:
                print("All pretrained weights loaded successfully!")
            else:
                print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")

        except Exception as e:
            print(f"Failed to load pretrained weights: {e}")
            print(" Using randomly initialized weights...")
            import traceback

            traceback.print_exc()

    policy.to(DEVICE)
    return policy

def create_dummy_data():
    batch_size = 2  # Reduce batch size for testing
    device = DEVICE

    # Use the exact same prompt for both implementations
    prompt = "Pick up the red block and place it in the bin"

    batch = {
        "observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
        "action": torch.randn(
            batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
        ),
        # Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
        "observation.images.base_0_rgb": torch.rand(
            batch_size, 3, 224, 224, dtype=torch.float32, device=device
        ),
        "observation.images.left_wrist_0_rgb": torch.rand(
            batch_size, 3, 224, 224, dtype=torch.float32, device=device
        ),
        "observation.images.right_wrist_0_rgb": torch.rand(
            batch_size, 3, 224, 224, dtype=torch.float32, device=device
        ),
        # Add the task prompt for LeRobot - one entry per batch element
        "task": [prompt for _ in range(batch_size)],
    }
    return batch

def extract_lerobot_processed_inputs(lerobot_pi0, batch):
    """Extract the exact same processed inputs that LeRobot uses internally."""
    # Get the tokenized language from LeRobot's internal method
    lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)

    # Get the preprocessed images from LeRobot's internal method
    images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)

    # Create dummy token_ar_mask and token_loss_mask for the original implementation
    token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
    token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)

    return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask

class PI0Observation:
    """Observation class that matches the original OpenPI format."""

    def __init__(
        self,
        state,
        images,
        image_masks,
        tokenized_prompt,
        tokenized_prompt_mask,
        token_ar_mask,
        token_loss_mask,
    ):
        self.state = state
        self.images = images
        self.image_masks = image_masks
        self.tokenized_prompt = tokenized_prompt
        self.tokenized_prompt_mask = tokenized_prompt_mask
        self.token_ar_mask = token_ar_mask
        self.token_loss_mask = token_loss_mask

def create_original_observation_with_openpi_preprocessing(batch):
    """Create observation object for OpenPI using OpenPI's own preprocessing."""
    batch_size = batch["observation.state"].shape[0]
    device = batch["observation.state"].device

    # Create tokenizer for OpenPI (same as LeRobot uses)
    tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")

    # Get task description
    if "task" in batch:
        tasks = batch["task"]
        if isinstance(tasks, str):
            # Single string: add newline if not present, then expand to batch size
            if not tasks.endswith("\n"):
                tasks = f"{tasks}\n"
            tasks = [tasks] * batch_size
        elif isinstance(tasks, list) and all(isinstance(t, str) for t in tasks):
            # List of strings: add newline to each if not present
            tasks = [t if t.endswith("\n") else f"{t}\n" for t in tasks]
            if len(tasks) == 1:
                # Expand to batch size
                tasks = tasks * batch_size
            if len(tasks) != batch_size:
                raise ValueError(f"Expected batch size {batch_size}, got {len(tasks)}")
        # If task is neither a string nor a list of strings, leave it unchanged
    else:
        # Default task if not provided
        tasks = ["Pick up the object\n"] * batch_size

    # Tokenize with max_length padding to match OpenPI's expected format
    tokenized = tokenizer(
        tasks,
        padding="max_length",
        padding_side="right",
        truncation=True,
        max_length=DUMMY_MAX_TOKEN_LEN,
        return_tensors="pt",
    )

    lang_tokens = tokenized["input_ids"].to(device)
    lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)

    # Create dummy token_ar_mask and token_loss_mask for OpenPI
    token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
    token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)

    # Convert LeRobot images format to OpenPI format (convert [0, 1] to [-1, 1] range)
    image_dict = {
        "base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
        "left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
        "right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
    }

    # Create image masks (all ones for real images)
    image_masks_dict = {}
    for key in image_dict:
        image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)

    # Create raw observation object (before preprocessing)
    raw_observation = PI0Observation(
        state=batch["observation.state"],
        images=image_dict,
        image_masks=image_masks_dict,
        tokenized_prompt=lang_tokens,
        tokenized_prompt_mask=lang_masks,
        token_ar_mask=token_ar_mask,
        token_loss_mask=token_loss_mask,
    )

    # Now use OpenPI's preprocessing
    processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)

    return processed_obs

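Worth noting how the two prompt formats in this PR differ: PI0 feeds the tokenizer the newline-terminated task string, while PI0.5 (earlier in this diff) folds the discretized state into the text. Side by side, under the same task; `state_str` is a stand-in for the joined bin indices:

task = "Pick up the red block"

pi0_prompt = f"{task}\n"  # PI0: raw task, newline-terminated

state_str = "12 200 57"  # stand-in for " ".join(map(str, discretized_states[i]))
pi05_prompt = f"Task: {task}, State: {state_str};\nAction: "  # PI0.5: task plus state tokens
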
def create_original_observation_from_lerobot(lerobot_pi0, batch):
    """Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
    _batch_size = batch["observation.state"].shape[0]
    _device = batch["observation.state"].device

    # Extract the exact same processed inputs that LeRobot uses
    images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
        extract_lerobot_processed_inputs(lerobot_pi0, batch)
    )

    # Convert the images list to a dict with original OpenPI keys
    image_dict = {
        "base_0_rgb": images[0],
        "left_wrist_0_rgb": images[1],
        "right_wrist_0_rgb": images[2],
    }

    # Convert the image masks list to a dict with original OpenPI keys
    image_masks_dict = {
        "base_0_rgb": img_masks[0],
        "left_wrist_0_rgb": img_masks[1],
        "right_wrist_0_rgb": img_masks[2],
    }

    return PI0Observation(
        state=batch["observation.state"],
        images=image_dict,
        image_masks=image_masks_dict,
        tokenized_prompt=lang_tokens,
        tokenized_prompt_mask=lang_masks,
        token_ar_mask=token_ar_mask,
        token_loss_mask=token_loss_mask,
    )

def test_pi0_original_vs_lerobot():
    """Test PI0 original implementation vs LeRobot implementation."""
    print("Initializing models...")
    lerobot_pi0, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi0(
        from_pretrained=True
    )  # Load pretrained LeRobot model
    original_pi0 = instantiate_original_pi0(
        from_pretrained=True
    )  # Load pretrained OpenPI model from HuggingFace Hub

    print("Creating dummy data...")
    batch = create_dummy_data()
    batch_lerobot = deepcopy(batch)

    # Test each model with its own preprocessing (more realistic end-to-end test)
    print("\nTest each model with its own preprocessing")
    print("Creating observation for OpenPI using OpenPI's own preprocessing...")
    pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)

    print(f"Task prompt: '{batch['task'][0]}'")
    print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
    print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
    print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")

    print("Testing OpenPI with own preprocessing...")
    original_pi0.eval()
    torch.manual_seed(42)  # Set seed for reproducibility
    batch_size = batch["observation.state"].shape[0]
    noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
    fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)

    with torch.no_grad():
        openpi_actions = original_pi0.sample_actions(
            device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
        )
    openpi_actions_unit = openpi_actions[:, 0, :]
    print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
    print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
    print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
    print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")

    print("Testing LeRobot with own preprocessing...")
    lerobot_pi0.eval()
    torch.manual_seed(42)  # Set the same seed

    batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
    with torch.no_grad():
        lerobot_actions_own = lerobot_pi0.predict_action_chunk(
            batch_lerobot_processed
        )  # batch_size, n_action_steps, action_dim
    lerobot_actions_unit = lerobot_actions_own[:, 0, :]
    print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
    print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
    print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
    print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")

    print("\nComparing end-to-end implementations:")
    print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
    print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
    print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")

    assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
    assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
    assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
@@ -166,6 +166,226 @@ def test_min_max_normalization(observation_normalizer):
    assert torch.allclose(normalized_obs[OBS_STATE], expected_state, atol=1e-6)

def test_quantile_normalization():
    """Test QUANTILES mode using 1st-99th percentiles."""
    features = {
        "observation.state": PolicyFeature(FeatureType.STATE, (2,)),
    }
    norm_map = {
        FeatureType.STATE: NormalizationMode.QUANTILES,
    }
    stats = {
        "observation.state": {
            "q01": np.array([0.1, -0.8]),  # 1st percentile
            "q99": np.array([0.9, 0.8]),  # 99th percentile
        },
    }

    normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)

    observation = {
        "observation.state": torch.tensor([0.5, 0.0]),
    }
    transition = create_transition(observation=observation)

    normalized_transition = normalizer(transition)
    normalized_obs = normalized_transition[TransitionKey.OBSERVATION]

    # Check quantile normalization to [-1, 1]
    # For state[0]: 2 * (0.5 - 0.1) / (0.9 - 0.1) - 1 = 2 * 0.4 / 0.8 - 1 = 0.0
    # For state[1]: 2 * (0.0 - (-0.8)) / (0.8 - (-0.8)) - 1 = 2 * 0.8 / 1.6 - 1 = 0.0
    expected_state = torch.tensor([0.0, 0.0])
    assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)

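The arithmetic these quantile tests assert is the affine map that sends [q01, q99] onto [-1, 1]. A standalone sketch of that formula (an illustrative helper, not the library implementation):

import torch

def quantile_normalize(x: torch.Tensor, q_low: torch.Tensor, q_high: torch.Tensor) -> torch.Tensor:
    # Values at q_low map to -1, values at q_high map to +1
    return 2 * (x - q_low) / (q_high - q_low) - 1

x = torch.tensor([0.5, 0.0])
print(quantile_normalize(x, torch.tensor([0.1, -0.8]), torch.tensor([0.9, 0.8])))  # tensor([0., 0.])
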
def test_quantile10_normalization():
    """Test QUANTILE10 mode using 10th-90th percentiles."""
    features = {
        "observation.state": PolicyFeature(FeatureType.STATE, (2,)),
    }
    norm_map = {
        FeatureType.STATE: NormalizationMode.QUANTILE10,
    }
    stats = {
        "observation.state": {
            "q10": np.array([0.2, -0.6]),  # 10th percentile
            "q90": np.array([0.8, 0.6]),  # 90th percentile
        },
    }

    normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)

    observation = {
        "observation.state": torch.tensor([0.5, 0.0]),
    }
    transition = create_transition(observation=observation)

    normalized_transition = normalizer(transition)
    normalized_obs = normalized_transition[TransitionKey.OBSERVATION]

    # Check quantile normalization to [-1, 1]
    # For state[0]: 2 * (0.5 - 0.2) / (0.8 - 0.2) - 1 = 2 * 0.3 / 0.6 - 1 = 0.0
    # For state[1]: 2 * (0.0 - (-0.6)) / (0.6 - (-0.6)) - 1 = 2 * 0.6 / 1.2 - 1 = 0.0
    expected_state = torch.tensor([0.0, 0.0])
    assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)

def test_quantile_unnormalization():
    """Test that quantile normalization can be reversed properly."""
    features = {
        "action": PolicyFeature(FeatureType.ACTION, (2,)),
    }
    norm_map = {
        FeatureType.ACTION: NormalizationMode.QUANTILES,
    }
    stats = {
        "action": {
            "q01": np.array([0.1, -0.8]),
            "q99": np.array([0.9, 0.8]),
        },
    }

    normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
    unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)

    # Test round-trip normalization
    original_action = torch.tensor([0.5, 0.0])
    transition = create_transition(action=original_action)

    # Normalize, then unnormalize
    normalized = normalizer(transition)
    unnormalized = unnormalizer(normalized)

    # Should recover the original values
    recovered_action = unnormalized[TransitionKey.ACTION]
    assert torch.allclose(recovered_action, original_action, atol=1e-6)

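The round trip works because the map is affine and therefore exactly invertible (up to float precision). The inverse, again as an illustrative helper rather than the library code:

import torch

def quantile_unnormalize(y: torch.Tensor, q_low: torch.Tensor, q_high: torch.Tensor) -> torch.Tensor:
    # Inverse of 2 * (x - q_low) / (q_high - q_low) - 1
    return (y + 1) / 2 * (q_high - q_low) + q_low
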
def test_quantile_division_by_zero():
    """Test that quantile normalization handles the edge case where q01 == q99."""
    features = {
        "observation.state": PolicyFeature(FeatureType.STATE, (1,)),
    }
    norm_map = {
        FeatureType.STATE: NormalizationMode.QUANTILES,
    }
    stats = {
        "observation.state": {
            "q01": np.array([0.5]),  # Same value
            "q99": np.array([0.5]),  # Same value -> division-by-zero case
        },
    }

    normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)

    observation = {
        "observation.state": torch.tensor([0.5]),
    }
    transition = create_transition(observation=observation)

    # Should not crash and should handle the case gracefully
    normalized_transition = normalizer(transition)
    normalized_obs = normalized_transition[TransitionKey.OBSERVATION]

    # When quantiles are identical, the result should stay finite (thanks to epsilon handling)
    assert torch.isfinite(normalized_obs["observation.state"]).all()

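The test above only pins down that the output stays finite; one common way to get that behavior is to clamp the denominator with a small epsilon. A sketch under that assumption (the actual NormalizerProcessorStep internals may differ):

import torch

def safe_quantile_normalize(x, q_low, q_high, eps=1e-8):
    # Clamp the spread so q_low == q_high cannot divide by zero (assumed strategy)
    denom = torch.clamp(q_high - q_low, min=eps)
    return 2 * (x - q_low) / denom - 1

print(safe_quantile_normalize(torch.tensor([0.5]), torch.tensor([0.5]), torch.tensor([0.5])))  # finite
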
def test_quantile_partial_stats():
    """Test that quantile normalization handles missing quantile stats by raising."""
    features = {
        "observation.state": PolicyFeature(FeatureType.STATE, (2,)),
    }
    norm_map = {
        FeatureType.STATE: NormalizationMode.QUANTILES,
    }

    # Missing q99 - should raise
    stats_partial = {
        "observation.state": {
            "q01": np.array([0.1, -0.8]),  # Only q01, missing q99
        },
    }

    normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats_partial)

    observation = {
        "observation.state": torch.tensor([0.5, 0.0]),
    }
    transition = create_transition(observation=observation)

    with pytest.raises(ValueError, match="QUANTILES normalization mode requires q01 and q99 stats"):
        _ = normalizer(transition)

def test_quantile_mixed_with_other_modes():
    """Test quantile normalization mixed with other normalization modes."""
    features = {
        "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
        "observation.state": PolicyFeature(FeatureType.STATE, (2,)),
        "action": PolicyFeature(FeatureType.ACTION, (2,)),
    }
    norm_map = {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,  # Standard normalization
        FeatureType.STATE: NormalizationMode.QUANTILES,  # Quantile normalization
        FeatureType.ACTION: NormalizationMode.QUANTILE10,  # Different quantile mode
    }
    stats = {
        "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
        "observation.state": {"q01": [0.1, -0.8], "q99": [0.9, 0.8]},
        "action": {"q10": [0.2, -0.6], "q90": [0.8, 0.6]},
    }

    normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)

    observation = {
        "observation.image": torch.tensor([0.7, 0.5, 0.3]),
        "observation.state": torch.tensor([0.5, 0.0]),  # Should use QUANTILES
    }
    action = torch.tensor([0.5, 0.0])  # Should use QUANTILE10
    transition = create_transition(observation=observation, action=action)

    normalized_transition = normalizer(transition)
    normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
    normalized_action = normalized_transition[TransitionKey.ACTION]

    # Image should be mean/std normalized: (0.7 - 0.5) / 0.2 = 1.0, etc.
    expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
    assert torch.allclose(normalized_obs["observation.image"], expected_image)

    # State should be quantile normalized: 2 * (0.5 - 0.1) / (0.9 - 0.1) - 1 = 0.0, etc.
    expected_state = torch.tensor([0.0, 0.0])
    assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)

    # Action should be quantile10 normalized: 2 * (0.5 - 0.2) / (0.8 - 0.2) - 1 = 0.0, etc.
    expected_action = torch.tensor([0.0, 0.0])
    assert torch.allclose(normalized_action, expected_action, atol=1e-6)

def test_quantile_with_missing_stats():
    """Test that quantile normalization handles completely missing stats gracefully."""
    features = {
        "observation.state": PolicyFeature(FeatureType.STATE, (2,)),
    }
    norm_map = {
        FeatureType.STATE: NormalizationMode.QUANTILES,
    }
    stats = {}  # No stats provided

    normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)

    observation = {
        "observation.state": torch.tensor([0.5, 0.0]),
    }
    transition = create_transition(observation=observation)

    normalized_transition = normalizer(transition)
    normalized_obs = normalized_transition[TransitionKey.OBSERVATION]

    # Should pass through unchanged when no stats are available
    assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])

def test_selective_normalization(observation_stats):
    features = _create_observation_features()
    norm_map = _create_observation_norm_map()
@@ -547,7 +767,7 @@ def test_empty_stats():


def test_partial_stats():
-    """If statistics are incomplete, the value should pass through unchanged."""
+    """If statistics are incomplete, we should raise."""
    stats = {OBS_IMAGE: {"mean": [0.5]}}  # Missing std / (min,max)
    features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
    norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
@@ -555,8 +775,8 @@ def test_partial_stats():
    observation = {OBS_IMAGE: torch.tensor([0.7])}
    transition = create_transition(observation=observation)

-    processed = normalizer(transition)[TransitionKey.OBSERVATION]
-    assert torch.allclose(processed[OBS_IMAGE], observation[OBS_IMAGE])
+    with pytest.raises(ValueError, match="MEAN_STD normalization mode requires mean and std stats"):
+        _ = normalizer(transition)[TransitionKey.OBSERVATION]

def test_missing_action_stats_no_error():


(deleted file)
@@ -1,424 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for PI0 policy processor."""

from unittest.mock import patch

import pytest
import torch

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre_post_processors
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    EnvTransition,
    NormalizerProcessorStep,
    ProcessorStep,
    RenameObservationsProcessorStep,
    TransitionKey,
    UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition, transition_to_batch
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE

class MockTokenizerProcessorStep(ProcessorStep):
    """Mock tokenizer processor step for testing."""

    def __init__(self, *args, **kwargs):
        # Accept any arguments to mimic the real TokenizerProcessorStep interface
        pass

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        # Pass through the transition unchanged
        return transition

    def transform_features(self, features):
        # Pass through the features unchanged
        return features

def create_default_config():
    """Create a default PI0 configuration for testing."""
    config = PI0Config()
    config.input_features = {
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
        OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
    }
    config.normalization_mapping = {
        FeatureType.STATE: NormalizationMode.MEAN_STD,
        FeatureType.VISUAL: NormalizationMode.IDENTITY,
        FeatureType.ACTION: NormalizationMode.MIN_MAX,
    }
    config.device = "cpu"
    config.tokenizer_max_length = 128
    return config


def create_default_stats():
    """Create default dataset statistics for testing."""
    return {
        OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)},
        OBS_IMAGE: {},  # No normalization for images
        ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)},
    }

def test_make_pi0_processor_basic():
    """Test basic creation of the PI0 processor."""
    config = create_default_config()
    stats = create_default_stats()

    with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
        preprocessor, postprocessor = make_pi0_pre_post_processors(
            config,
            stats,
        )

    # Check processor names
    assert preprocessor.name == "policy_preprocessor"
    assert postprocessor.name == "policy_postprocessor"

    # Check steps in the preprocessor
    assert len(preprocessor.steps) == 6
    assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
    assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
    assert isinstance(preprocessor.steps[2], Pi0NewLineProcessor)
    # Step 3 would be TokenizerProcessorStep, but it's mocked
    assert isinstance(preprocessor.steps[4], DeviceProcessorStep)
    assert isinstance(preprocessor.steps[5], NormalizerProcessorStep)

    # Check steps in the postprocessor
    assert len(postprocessor.steps) == 2
    assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep)
    assert isinstance(postprocessor.steps[1], DeviceProcessorStep)

def test_pi0_newline_processor_single_task():
    """Test Pi0NewLineProcessor with a single task string."""
    processor = Pi0NewLineProcessor()

    # Test with a task that doesn't have a newline
    transition = create_transition(complementary_data={"task": "test task"})
    result = processor(transition)
    assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"

    # Test with a task that already has a newline
    transition = create_transition(complementary_data={"task": "test task\n"})
    result = processor(transition)
    assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"


def test_pi0_newline_processor_list_of_tasks():
    """Test Pi0NewLineProcessor with a list of task strings."""
    processor = Pi0NewLineProcessor()

    # Test with a list of tasks
    tasks = ["task1", "task2\n", "task3"]
    transition = create_transition(complementary_data={"task": tasks})
    result = processor(transition)
    expected = ["task1\n", "task2\n", "task3\n"]
    assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == expected


def test_pi0_newline_processor_empty_transition():
    """Test Pi0NewLineProcessor with an empty transition."""
    processor = Pi0NewLineProcessor()

    # Test with no complementary_data
    transition = create_transition()
    result = processor(transition)
    assert result == transition

    # Test with complementary_data but no task
    transition = create_transition(complementary_data={"other": "data"})
    result = processor(transition)
    assert result == transition

    # Test with a None task
    transition = create_transition(complementary_data={"task": None})
    result = processor(transition)
    assert result == transition

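Taken together, the three tests above fully pin down the newline behavior: append "\n" to a bare string, map over a list, and pass through None or absent tasks untouched. A compact sketch of that contract (not the actual Pi0NewLineProcessor code):

def append_newline(task):
    # Contract distilled from the tests above; illustrative only
    if task is None:
        return task
    if isinstance(task, str):
        return task if task.endswith("\n") else f"{task}\n"
    return [t if t.endswith("\n") else f"{t}\n" for t in task]
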
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pi0_processor_cuda():
    """Test the PI0 processor with a CUDA device."""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    # Mock the tokenizer processor to act as a pass-through
    class MockTokenizerProcessorStep(ProcessorStep):
        def __init__(self, *args, **kwargs):
            pass

        def __call__(self, transition):
            return transition

        def state_dict(self):
            return {}

        def load_state_dict(self, state):
            pass

        def reset(self):
            pass

        def get_config(self):
            return {"tokenizer_name": "google/paligemma-3b-pt-224"}

        def transform_features(self, features):
            return features

    with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
        preprocessor, postprocessor = make_pi0_pre_post_processors(
            config,
            stats,
        )

    # Create CPU data
    observation = {
        OBS_STATE: torch.randn(10),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
    action = torch.randn(6)
    transition = create_transition(observation, action, complementary_data={"task": "test task"})
    batch = transition_to_batch(transition)

    # Process through the preprocessor
    processed = preprocessor(batch)

    # Check that the data is on CUDA
    assert processed[OBS_STATE].device.type == "cuda"
    assert processed[OBS_IMAGE].device.type == "cuda"
    assert processed[TransitionKey.ACTION.value].device.type == "cuda"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pi0_processor_accelerate_scenario():
    """Test the PI0 processor in a simulated Accelerate scenario."""
    config = create_default_config()
    config.device = "cuda:0"
    stats = create_default_stats()

    # Mock the tokenizer processor to act as a pass-through
    class MockTokenizerProcessorStep(ProcessorStep):
        def __init__(self, *args, **kwargs):
            pass

        def __call__(self, transition):
            return transition

        def state_dict(self):
            return {}

        def load_state_dict(self, state):
            pass

        def reset(self):
            pass

        def get_config(self):
            return {"tokenizer_name": "google/paligemma-3b-pt-224"}

        def transform_features(self, features):
            return features

    with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
        preprocessor, postprocessor = make_pi0_pre_post_processors(
            config,
            stats,
        )

    # Simulate Accelerate: data already on GPU and batched
    device = torch.device("cuda:0")
    observation = {
        OBS_STATE: torch.randn(1, 10).to(device),
        OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
    }
    action = torch.randn(1, 6).to(device)
    transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
    batch = transition_to_batch(transition)

    # Process through the preprocessor
    processed = preprocessor(batch)

    # Check that the data stays on the same GPU
    assert processed[OBS_STATE].device == device
    assert processed[OBS_IMAGE].device == device
    assert processed[TransitionKey.ACTION.value].device == device

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_pi0_processor_multi_gpu():
    """Test the PI0 processor with a multi-GPU setup."""
    config = create_default_config()
    config.device = "cuda:0"
    stats = create_default_stats()

    # Mock the tokenizer processor to act as a pass-through
    class MockTokenizerProcessorStep(ProcessorStep):
        def __init__(self, *args, **kwargs):
            pass

        def __call__(self, transition):
            return transition

        def state_dict(self):
            return {}

        def load_state_dict(self, state):
            pass

        def reset(self):
            pass

        def get_config(self):
            return {"tokenizer_name": "google/paligemma-3b-pt-224"}

        def transform_features(self, features):
            return features

    with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
        preprocessor, postprocessor = make_pi0_pre_post_processors(
            config,
            stats,
        )

    # Simulate data on a different GPU
    device = torch.device("cuda:1")
    observation = {
        OBS_STATE: torch.randn(1, 10).to(device),
        OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
    }
    action = torch.randn(1, 6).to(device)
    transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
    batch = transition_to_batch(transition)

    # Process through the preprocessor
    processed = preprocessor(batch)

    # Check that the data stays on cuda:1
    assert processed[OBS_STATE].device == device
    assert processed[OBS_IMAGE].device == device
    assert processed[TransitionKey.ACTION.value].device == device

def test_pi0_processor_without_stats():
    """Test PI0 processor creation without dataset statistics."""
    config = create_default_config()

    # Mock the tokenizer processor
    with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
        preprocessor, postprocessor = make_pi0_pre_post_processors(
            config,
            dataset_stats=None,
        )

    # Should still create processors
    assert preprocessor is not None
    assert postprocessor is not None

def test_pi0_newline_processor_state_dict():
    """Test Pi0NewLineProcessor state dict methods."""
    processor = Pi0NewLineProcessor()

    # Test state_dict (should be empty)
    state = processor.state_dict()
    assert state == {}

    # Test load_state_dict (should do nothing)
    processor.load_state_dict({})

    # Test reset (should do nothing)
    processor.reset()

    # Test get_config
    config = processor.get_config()
    assert config == {}

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pi0_processor_bfloat16_device_float32_normalizer():
    """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) -> output bfloat16 via automatic adaptation."""
    config = create_default_config()
    stats = create_default_stats()
    config.device = "cuda"

    with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
        preprocessor, _ = make_pi0_pre_post_processors(
            config,
            stats,
        )

    # Modify the pipeline to use a bfloat16 device processor with a float32 normalizer
    modified_steps = []
    for step in preprocessor.steps:
        if isinstance(step, DeviceProcessorStep):
            # Device processor converts to bfloat16
            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
        elif isinstance(step, NormalizerProcessorStep):
            # Normalizer stays configured as float32 (will auto-adapt to bfloat16)
            norm_step = step  # Now the type checker knows this is a NormalizerProcessorStep
            modified_steps.append(
                NormalizerProcessorStep(
                    features=norm_step.features,
                    norm_map=norm_step.norm_map,
                    stats=norm_step.stats,
                    device=config.device,
                    dtype=torch.float32,  # Deliberately configured as float32
                )
            )
        else:
            modified_steps.append(step)
    preprocessor.steps = modified_steps

    # Verify the initial normalizer configuration (PI0 has NormalizerProcessorStep at index 5)
    normalizer_step = preprocessor.steps[5]  # NormalizerProcessorStep
    assert normalizer_step.dtype == torch.float32

    # Create test data with both state and visual observations
    observation = {
        OBS_STATE: torch.randn(10, dtype=torch.float32),  # PI0 expects size 10
        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
    }
    action = torch.randn(6, dtype=torch.float32)  # PI0 expects size 6
    transition = create_transition(
        observation, action, complementary_data={"task": "test bfloat16 adaptation"}
    )
    batch = transition_to_batch(transition)

    # Process through the full pipeline
    processed = preprocessor(batch)

    # Verify: DeviceProcessor -> bfloat16, NormalizerProcessor adapts -> final output is bfloat16
    assert processed[OBS_STATE].dtype == torch.bfloat16
    assert processed[OBS_IMAGE].dtype == torch.bfloat16  # IDENTITY normalization still gets dtype conversion
    assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16

    # Verify the normalizer automatically adapted its internal state
    assert normalizer_step.dtype == torch.bfloat16
    # Check state stats (has normalization)
    for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
        assert stat_tensor.dtype == torch.bfloat16
    # OBS_IMAGE uses IDENTITY normalization, so no stats to check