make unit tests pass
This commit is contained in:
@@ -44,7 +44,7 @@ from datasets import load_dataset
|
||||
# TODO(rcadene): list available datasets on lerobot page using `datasets`
|
||||
|
||||
# download/load hugging face dataset in pyarrow format
|
||||
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
|
||||
hf_dataset, fps = load_dataset("lerobot/pusht", split="train", revision="v1.1"), 10
|
||||
|
||||
# display name of dataset and its features
|
||||
# TODO(rcadene): update to make the print pretty
|
||||
|
||||
@@ -34,7 +34,7 @@ dataset = make_dataset(hydra_cfg)
|
||||
# If you're doing something different, you will likely need to change at least some of the defaults.
|
||||
cfg = DiffusionConfig()
|
||||
# TODO(alexander-soare): Remove LR scheduler from the policy.
|
||||
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
|
||||
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user