Merge pull request #58 from alexander-soare/update_diffusion_model
Update diffusion model
This commit is contained in:
@@ -145,7 +145,6 @@ Or you can achieve the same result by executing our script from the command line
|
|||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/eval.py \
|
python lerobot/scripts/eval.py \
|
||||||
--hub-id lerobot/diffusion_policy_pusht_image \
|
--hub-id lerobot/diffusion_policy_pusht_image \
|
||||||
--revision v1.0 \
|
|
||||||
eval_episodes=10 \
|
eval_episodes=10 \
|
||||||
hydra.run.dir=outputs/eval/example_hub
|
hydra.run.dir=outputs/eval/example_hub
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -81,13 +81,8 @@ def make_offline_buffer(
|
|||||||
else:
|
else:
|
||||||
raise ValueError(cfg.env.name)
|
raise ValueError(cfg.env.name)
|
||||||
|
|
||||||
# TODO(rcadene): backward compatiblity to load pretrained pusht policy
|
|
||||||
dataset_id = cfg.get("dataset_id")
|
|
||||||
if dataset_id is None and cfg.env.name == "pusht":
|
|
||||||
dataset_id = "pusht"
|
|
||||||
|
|
||||||
offline_buffer = clsfunc(
|
offline_buffer = clsfunc(
|
||||||
dataset_id=dataset_id,
|
dataset_id=cfg.dataset_id,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
root=DATA_DIR,
|
root=DATA_DIR,
|
||||||
|
|||||||
@@ -103,29 +103,3 @@ optimizer:
|
|||||||
betas: [0.95, 0.999]
|
betas: [0.95, 0.999]
|
||||||
eps: 1.0e-8
|
eps: 1.0e-8
|
||||||
weight_decay: 1.0e-6
|
weight_decay: 1.0e-6
|
||||||
|
|
||||||
training:
|
|
||||||
device: "cuda:0"
|
|
||||||
seed: 42
|
|
||||||
debug: False
|
|
||||||
resume: True
|
|
||||||
# optimization
|
|
||||||
# lr_scheduler: cosine
|
|
||||||
# lr_warmup_steps: 500
|
|
||||||
num_epochs: 8000
|
|
||||||
# gradient_accumulate_every: 1
|
|
||||||
# EMA destroys performance when used with BatchNorm
|
|
||||||
# replace BatchNorm with GroupNorm.
|
|
||||||
# use_ema: True
|
|
||||||
freeze_encoder: False
|
|
||||||
# training loop control
|
|
||||||
# in epochs
|
|
||||||
rollout_every: 50
|
|
||||||
checkpoint_every: 50
|
|
||||||
val_every: 1
|
|
||||||
sample_every: 5
|
|
||||||
# steps per epoch
|
|
||||||
max_train_steps: null
|
|
||||||
max_val_steps: null
|
|
||||||
# misc
|
|
||||||
tqdm_interval_sec: 1.0
|
|
||||||
|
|||||||
@@ -268,7 +268,7 @@ if __name__ == "__main__":
|
|||||||
# TODO(alexander-soare): Save and load stats in trained model directory.
|
# TODO(alexander-soare): Save and load stats in trained model directory.
|
||||||
stats_path = None
|
stats_path = None
|
||||||
elif args.hub_id is not None:
|
elif args.hub_id is not None:
|
||||||
folder = Path(snapshot_download(args.hub_id, revision="v1.0"))
|
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
|
||||||
cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
|
cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
|
||||||
cfg = hydra.compose("config", args.overrides)
|
cfg = hydra.compose("config", args.overrides)
|
||||||
cfg.policy.pretrained_model_path = folder / "model.pt"
|
cfg.policy.pretrained_model_path = folder / "model.pt"
|
||||||
|
|||||||
Reference in New Issue
Block a user