[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
@@ -56,7 +56,9 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
     # Instantiate both ways
     robot = make_robot("koch", mock=True)
     root_create = tmp_path / "create"
-    dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
+    dataset_create = LeRobotDataset.create(
+        repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
+    )

     root_init = tmp_path / "init"
     dataset_init = lerobot_dataset_factory(root=root_init)
@@ -102,7 +104,16 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
 @pytest.mark.parametrize(
     "env_name, repo_id, policy_name",
     lerobot.env_dataset_policy_triplets
-    + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
+    + [
+        (
+            "aloha",
+            [
+                "lerobot/aloha_sim_insertion_human",
+                "lerobot/aloha_sim_transfer_cube_human",
+            ],
+            "act",
+        )
+    ],
 )
 def test_factory(env_name, repo_id, policy_name):
     """
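Note on the hunk above: the appended triplet's repo_id is a list of two repos, so this single parametrize entry drives the multi-dataset path of the factory. Roughly what pytest passes for that one extra case (the values are copied from the entry; the comment about the factory's behavior is an assumption, not confirmed by this diff):

    # The one extra invocation generated by the appended entry:
    env_name = "aloha"
    repo_id = ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"]
    policy_name = "act"
    # test_factory(env_name, repo_id, policy_name) -- the list-valued repo_id is
    # presumably what asks the dataset factory to span both repos at once.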
@@ -220,7 +231,9 @@ def test_compute_stats_on_xarm():
     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
     # dataset into even batches.
-    computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25), num_workers=0)
+    computed_stats = compute_stats(
+        dataset, batch_size=int(len(dataset) * 0.25), num_workers=0
+    )

     # get einops patterns to aggregate batches and compute statistics
     stats_patterns = get_stats_einops_patterns(dataset)
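The comment in this hunk is the interesting part: batch_size is len(dataset) * 0.25 rounded down, so the final batch is generally smaller than the others, and the statistics code has to weight batches by their true size. A minimal sketch of that idea, not lerobot's implementation:

    import torch

    def batched_mean(x, batch_size):
        # Accumulate a sum and an element count so an uneven final batch is weighted correctly.
        total, count = torch.tensor(0.0), 0
        for start in range(0, len(x), batch_size):
            batch = x[start:start + batch_size]  # the last batch may be shorter
            total = total + batch.sum()
            count += batch.numel()
        return total / count

    x = torch.arange(10, dtype=torch.float32)
    assert torch.isclose(batched_mean(x, batch_size=4), x.mean())  # batches of 4, 4, 2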
@@ -241,7 +254,9 @@ def test_compute_stats_on_xarm():
         expected_stats[k] = {}
         expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
         expected_stats[k]["std"] = torch.sqrt(
-            einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
+            einops.reduce(
+                (full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"
+            )
         )
         expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
         expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
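The expected std here is the population standard deviation, sqrt(mean((x - mean)^2)), with no Bessel correction; the same convention appears as correction=0 in test_aggregate_stats further down. A quick illustrative check:

    import torch

    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    pop_std = torch.sqrt(((x - x.mean()) ** 2).mean())  # what the einops reduction computes
    assert torch.isclose(pop_std, torch.std(x, correction=0))
    assert not torch.isclose(pop_std, torch.std(x))  # default correction=1 gives a larger value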
@@ -286,7 +301,9 @@ def test_flatten_unflatten_dict():
     d = unflatten_dict(flatten_dict(d))

     # test equality between nested dicts
-    assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
+    assert json.dumps(original_d, sort_keys=True) == json.dumps(
+        d, sort_keys=True
+    ), f"{original_d} != {d}"


 @pytest.mark.skip("TODO after v2 migration / removing hydra")
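For context on this hunk: flatten_dict collapses a nested dict into single separator-joined keys and unflatten_dict inverts it, so the round trip should be the identity, which is exactly what the assert checks. A minimal sketch of that contract, assuming a "/" separator (the real helpers live in lerobot's utils and may differ in detail):

    def flatten_dict(d, parent_key="", sep="/"):
        # {"a": {"b": 1}} -> {"a/b": 1}
        items = {}
        for k, v in d.items():
            key = f"{parent_key}{sep}{k}" if parent_key else k
            if isinstance(v, dict):
                items.update(flatten_dict(v, key, sep))
            else:
                items[key] = v
        return items

    def unflatten_dict(d, sep="/"):
        # {"a/b": 1} -> {"a": {"b": 1}}
        out = {}
        for k, v in d.items():
            node = out
            *parents, leaf = k.split(sep)
            for p in parents:
                node = node.setdefault(p, {})
            node[leaf] = v
        return out

    d = {"obs": {"image": 1, "state": 2}, "action": 3}
    assert unflatten_dict(flatten_dict(d)) == d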
@@ -333,7 +350,13 @@ def test_backward_compatibility(repo_id):
     load_and_compare(i + 1)

     # test 2 frames at the middle of first episode
-    i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
+    i = int(
+        (
+            dataset.episode_data_index["to"][0].item()
+            - dataset.episode_data_index["from"][0].item()
+        )
+        / 2
+    )
     load_and_compare(i)
     load_and_compare(i + 1)

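episode_data_index maps an episode number to its frame range in the flat dataset: from[e] marks the first frame of episode e and to[e] bounds it from above, so the reformatted expression computes an index half-way into episode 0. With toy numbers (an assumption for illustration, not values from a real dataset):

    import torch

    episode_data_index = {"from": torch.tensor([0, 50]), "to": torch.tensor([50, 120])}
    i = int((episode_data_index["to"][0].item() - episode_data_index["from"][0].item()) / 2)
    assert i == 25  # half-way through episode 0, which spans frames 0..49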
@@ -370,23 +393,40 @@ def test_aggregate_stats():
     data_c = torch.rand(20, dtype=torch.float32)

     hf_dataset_1 = Dataset.from_dict(
-        {"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)}
+        {
+            "a": data_a[:10],
+            "b": data_b[:10],
+            "c": data_c[:10],
+            "index": torch.arange(10),
+        }
     )
     hf_dataset_1.set_transform(hf_transform_to_torch)
-    hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)})
+    hf_dataset_2 = Dataset.from_dict(
+        {"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)}
+    )
     hf_dataset_2.set_transform(hf_transform_to_torch)
-    hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)})
+    hf_dataset_3 = Dataset.from_dict(
+        {"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)}
+    )
     hf_dataset_3.set_transform(hf_transform_to_torch)
     dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1)
-    dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0)
+    dataset_1.stats = compute_stats(
+        dataset_1, batch_size=len(hf_dataset_1), num_workers=0
+    )
     dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2)
-    dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0)
+    dataset_2.stats = compute_stats(
+        dataset_2, batch_size=len(hf_dataset_2), num_workers=0
+    )
     dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3)
-    dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0)
+    dataset_3.stats = compute_stats(
+        dataset_3, batch_size=len(hf_dataset_3), num_workers=0
+    )
     stats = aggregate_stats([dataset_1, dataset_2, dataset_3])
     for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
         for agg_fn in ["mean", "min", "max"]:
-            assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
+            assert torch.allclose(
+                stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn)
+            )
         assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
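The closing loop pins down the contract for aggregate_stats even though the keys only partially overlap ("a" lives in all three datasets, "b" in the first two, "c" in the first and third): min/max must be global, means must be weighted by sample count, and std must equal the population std (correction=0) of the pooled data. A minimal sketch of exact pooling from per-dataset summaries, not lerobot's implementation:

    import torch

    def pool_mean_std(parts):
        # parts: (count, mean, mean_of_squares) for each dataset holding the key.
        n = sum(c for c, _, _ in parts)
        mean = sum(c * m for c, m, _ in parts) / n      # count-weighted mean
        mean_sq = sum(c * s for c, _, s in parts) / n   # count-weighted E[x^2]
        std = torch.sqrt(mean_sq - mean**2)             # population std, i.e. correction=0
        return mean, std

    x1, x2 = torch.rand(10), torch.rand(20)
    parts = [(x.numel(), x.mean(), (x**2).mean()) for x in (x1, x2)]
    mean, std = pool_mean_std(parts)
    full = torch.cat([x1, x2])
    assert torch.allclose(mean, full.mean())
    assert torch.allclose(std, torch.std(full, correction=0))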