forked from tangger/lerobot
Merge pull request #58 from alexander-soare/update_diffusion_model
Update diffusion model
This commit is contained in:
@@ -145,7 +145,6 @@ Or you can achieve the same result by executing our script from the command line
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
--hub-id lerobot/diffusion_policy_pusht_image \
|
||||
--revision v1.0 \
|
||||
eval_episodes=10 \
|
||||
hydra.run.dir=outputs/eval/example_hub
|
||||
```
|
||||
|
||||
@@ -81,13 +81,8 @@ def make_offline_buffer(
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
# TODO(rcadene): backward compatiblity to load pretrained pusht policy
|
||||
dataset_id = cfg.get("dataset_id")
|
||||
if dataset_id is None and cfg.env.name == "pusht":
|
||||
dataset_id = "pusht"
|
||||
|
||||
offline_buffer = clsfunc(
|
||||
dataset_id=dataset_id,
|
||||
dataset_id=cfg.dataset_id,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
root=DATA_DIR,
|
||||
|
||||
@@ -103,29 +103,3 @@ optimizer:
|
||||
betas: [0.95, 0.999]
|
||||
eps: 1.0e-8
|
||||
weight_decay: 1.0e-6
|
||||
|
||||
training:
|
||||
device: "cuda:0"
|
||||
seed: 42
|
||||
debug: False
|
||||
resume: True
|
||||
# optimization
|
||||
# lr_scheduler: cosine
|
||||
# lr_warmup_steps: 500
|
||||
num_epochs: 8000
|
||||
# gradient_accumulate_every: 1
|
||||
# EMA destroys performance when used with BatchNorm
|
||||
# replace BatchNorm with GroupNorm.
|
||||
# use_ema: True
|
||||
freeze_encoder: False
|
||||
# training loop control
|
||||
# in epochs
|
||||
rollout_every: 50
|
||||
checkpoint_every: 50
|
||||
val_every: 1
|
||||
sample_every: 5
|
||||
# steps per epoch
|
||||
max_train_steps: null
|
||||
max_val_steps: null
|
||||
# misc
|
||||
tqdm_interval_sec: 1.0
|
||||
|
||||
@@ -268,7 +268,7 @@ if __name__ == "__main__":
|
||||
# TODO(alexander-soare): Save and load stats in trained model directory.
|
||||
stats_path = None
|
||||
elif args.hub_id is not None:
|
||||
folder = Path(snapshot_download(args.hub_id, revision="v1.0"))
|
||||
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
|
||||
cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
|
||||
cfg = hydra.compose("config", args.overrides)
|
||||
cfg.policy.pretrained_model_path = folder / "model.pt"
|
||||
|
||||
Reference in New Issue
Block a user