Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl

This commit is contained in:
Alexander Soare
2024-04-08 09:25:45 +01:00
19 changed files with 253 additions and 242 deletions

View File

@@ -164,19 +164,11 @@ def make_dataset(
]
)
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
# TODO(rcadene): implement delta_timestamps in config
delta_timestamps = {
"observation.image": [-0.1, 0],
"observation.state": [-0.1, 0],
"action": [-0.1] + [i / clsfunc.fps for i in range(15)],
}
else:
delta_timestamps = {
"observation.images.top": [0],
"observation.state": [0],
"action": [i / clsfunc.fps for i in range(cfg.policy.horizon)],
}
delta_timestamps = cfg.policy.get("delta_timestamps")
if delta_timestamps is not None:
for key in delta_timestamps:
if isinstance(delta_timestamps[key], str):
delta_timestamps[key] = eval(delta_timestamps[key])
dataset = clsfunc(
dataset_id=cfg.dataset_id,

View File

@@ -6,9 +6,9 @@ import pygame
import pymunk
import torch
import tqdm
from gym_pusht.envs.pusht import pymunk_to_shapely
from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
# as define in env