[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
2abbd60a0d
commit
0ea27704f6
@@ -46,9 +46,7 @@ def train_evaluate_multiclass_classifier():
|
||||
logging.info(
|
||||
f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
|
||||
)
|
||||
multiclass_config = ClassifierConfig(
|
||||
model_name="microsoft/resnet-18", device=DEVICE, num_classes=10
|
||||
)
|
||||
multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10)
|
||||
multiclass_classifier = Classifier(multiclass_config)
|
||||
|
||||
trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
|
||||
@@ -119,18 +117,10 @@ def train_evaluate_multiclass_classifier():
|
||||
test_probs = torch.stack(test_probs)
|
||||
|
||||
accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
|
||||
precision = Precision(
|
||||
task="multiclass", average="weighted", num_classes=multiclass_num_classes
|
||||
)
|
||||
recall = Recall(
|
||||
task="multiclass", average="weighted", num_classes=multiclass_num_classes
|
||||
)
|
||||
f1 = F1Score(
|
||||
task="multiclass", average="weighted", num_classes=multiclass_num_classes
|
||||
)
|
||||
auroc = AUROC(
|
||||
task="multiclass", num_classes=multiclass_num_classes, average="weighted"
|
||||
)
|
||||
precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
|
||||
recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
|
||||
f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
|
||||
auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted")
|
||||
|
||||
# Calculate metrics
|
||||
acc = accuracy(test_predictions, test_labels)
|
||||
@@ -159,28 +149,18 @@ def train_evaluate_binary_classifier():
|
||||
new_label = float(1.0) if label == target_class else float(0.0)
|
||||
new_targets.append(new_label)
|
||||
|
||||
dataset.targets = (
|
||||
new_targets # Replace the original labels with the binary ones
|
||||
)
|
||||
dataset.targets = new_targets # Replace the original labels with the binary ones
|
||||
return dataset
|
||||
|
||||
binary_train_dataset = CIFAR10(
|
||||
root="data", train=True, download=True, transform=ToTensor()
|
||||
)
|
||||
binary_test_dataset = CIFAR10(
|
||||
root="data", train=False, download=True, transform=ToTensor()
|
||||
)
|
||||
binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
|
||||
binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
|
||||
|
||||
# Apply one-vs-rest labeling
|
||||
binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
|
||||
binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)
|
||||
|
||||
binary_trainloader = DataLoader(
|
||||
binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True
|
||||
)
|
||||
binary_testloader = DataLoader(
|
||||
binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False
|
||||
)
|
||||
binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
||||
|
||||
binary_epoch = 1
|
||||
|
||||
|
||||
@@ -196,13 +196,9 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
# Test updating the policy (and test that it does not mutate the batch)
|
||||
batch_ = deepcopy(batch)
|
||||
policy.forward(batch)
|
||||
assert set(batch) == set(batch_), (
|
||||
"Batch keys are not the same after a forward pass."
|
||||
)
|
||||
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
|
||||
assert all(
|
||||
torch.equal(batch[k], batch_[k])
|
||||
if isinstance(batch[k], torch.Tensor)
|
||||
else batch[k] == batch_[k]
|
||||
torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
|
||||
for k in batch
|
||||
), "Batch values are not the same after a forward pass."
|
||||
|
||||
@@ -214,9 +210,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
observation = preprocess_observation(observation)
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {
|
||||
key: observation[key].to(DEVICE, non_blocking=True) for key in observation
|
||||
}
|
||||
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
|
||||
|
||||
# get the next action for the environment (also check that the observation batch is not modified)
|
||||
observation_ = deepcopy(observation)
|
||||
@@ -241,12 +235,8 @@ def test_act_backbone_lr():
|
||||
|
||||
cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(
|
||||
repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]
|
||||
),
|
||||
policy=make_policy_config(
|
||||
"act", optimizer_lr=0.01, optimizer_lr_backbone=0.001
|
||||
),
|
||||
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
||||
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
|
||||
)
|
||||
cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
@@ -269,9 +259,7 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
|
||||
policy_cls = get_policy_class(policy_name)
|
||||
policy_cfg = make_policy_config(policy_name)
|
||||
features = dataset_to_policy_features(dummy_dataset_metadata.features)
|
||||
policy_cfg.output_features = {
|
||||
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
|
||||
}
|
||||
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
policy_cfg.input_features = {
|
||||
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
||||
}
|
||||
@@ -283,9 +271,7 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
||||
policy_cls = get_policy_class(policy_name)
|
||||
policy_cfg = make_policy_config(policy_name)
|
||||
features = dataset_to_policy_features(dummy_dataset_metadata.features)
|
||||
policy_cfg.output_features = {
|
||||
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
|
||||
}
|
||||
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
policy_cfg.input_features = {
|
||||
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
||||
}
|
||||
@@ -294,9 +280,7 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
||||
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||
policy.save_pretrained(save_dir)
|
||||
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
|
||||
torch.testing.assert_close(
|
||||
list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0
|
||||
)
|
||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
|
||||
@@ -436,9 +420,7 @@ def test_normalize(insert_temporal_dim):
|
||||
# pass if it's run on another platform due to floating point errors
|
||||
@require_x86_64_kernel
|
||||
@require_cpu
|
||||
def test_backward_compatibility(
|
||||
ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str
|
||||
):
|
||||
def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str):
|
||||
"""
|
||||
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
|
||||
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
|
||||
@@ -452,17 +434,13 @@ def test_backward_compatibility(
|
||||
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
|
||||
"""
|
||||
ds_name = ds_repo_id.split("/")[-1]
|
||||
artifact_dir = (
|
||||
Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
|
||||
)
|
||||
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
|
||||
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
|
||||
saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
|
||||
saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")
|
||||
saved_actions = load_file(artifact_dir / "actions.safetensors")
|
||||
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(
|
||||
ds_repo_id, policy_name, policy_kwargs
|
||||
)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
|
||||
|
||||
for key in saved_output_dict:
|
||||
torch.testing.assert_close(output_dict[key], saved_output_dict[key])
|
||||
@@ -471,12 +449,8 @@ def test_backward_compatibility(
|
||||
for key in saved_param_stats:
|
||||
torch.testing.assert_close(param_stats[key], saved_param_stats[key])
|
||||
for key in saved_actions:
|
||||
rtol, atol = (
|
||||
(2e-3, 5e-6) if policy_name == "diffusion" else (None, None)
|
||||
) # HACK
|
||||
torch.testing.assert_close(
|
||||
actions[key], saved_actions[key], rtol=rtol, atol=atol
|
||||
)
|
||||
rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) # HACK
|
||||
torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def test_act_temporal_ensembler():
|
||||
@@ -502,9 +476,7 @@ def test_act_temporal_ensembler():
|
||||
batch_size = batch_seq.shape[0]
|
||||
# Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
|
||||
# dimension of `batch_seq`.
|
||||
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(
|
||||
-1
|
||||
)
|
||||
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
|
||||
|
||||
# Simulate stepping through a rollout and computing a batch of actions with model on each step.
|
||||
for i in range(episode_length):
|
||||
@@ -527,8 +499,7 @@ def test_act_temporal_ensembler():
|
||||
episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
|
||||
seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
|
||||
offline_avg = (
|
||||
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum")
|
||||
/ weights[: i + 1].sum()
|
||||
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
|
||||
)
|
||||
# Sanity check. The average should be between the extrema.
|
||||
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
|
||||
|
||||
Reference in New Issue
Block a user