diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 6b8037fe..3f436b74 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -1,9 +1,8 @@
 import torch
+from torchrl.data.replay_buffers import PrioritizedSliceSampler
 
 from lerobot.common.datasets.pusht import PushtExperienceReplay
 from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
 
-# from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
-from torchrl.data.replay_buffers import PrioritizedSliceSampler
 
 # TODO(rcadene): implement
@@ -27,6 +26,17 @@ from torchrl.data.replay_buffers import PrioritizedSliceSampler
 
 
 def make_offline_buffer(cfg, sampler=None):
+    if cfg.policy.balanced_sampling:
+        assert cfg.online_steps > 0
+        batch_size = None
+        pin_memory = False
+        prefetch = None
+    else:
+        assert cfg.online_steps == 0
+        num_slices = cfg.policy.batch_size
+        batch_size = cfg.policy.horizon * num_slices
+        pin_memory = cfg.device == "cuda"
+        prefetch = cfg.prefetch
 
     overwrite_sampler = sampler is not None
 
@@ -52,6 +62,9 @@ def make_offline_buffer(cfg, sampler=None):
             streaming=False,
             root="data",
             sampler=sampler,
+            batch_size=batch_size,
+            pin_memory=pin_memory,
+            prefetch=prefetch,
         )
     elif cfg.env.name == "pusht":
         offline_buffer = PushtExperienceReplay(
@@ -61,6 +74,9 @@ def make_offline_buffer(cfg, sampler=None):
             streaming=False,
             root="data",
             sampler=sampler,
+            batch_size=batch_size,
+            pin_memory=pin_memory,
+            prefetch=prefetch,
         )
     else:
         raise ValueError(cfg.env.name)
diff --git a/lerobot/common/policies/diffusion.py b/lerobot/common/policies/diffusion.py
index 65d7085c..3bd9f515 100644
--- a/lerobot/common/policies/diffusion.py
+++ b/lerobot/common/policies/diffusion.py
@@ -119,9 +119,9 @@ class DiffusionPolicy(nn.Module):
         assert batch_size % num_slices == 0
 
         def process_batch(batch, horizon, num_slices):
-            # trajectory t = 256, horizon h = 5
-            # (t h) ... -> h t ...
-            batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
+            # trajectory t = 64, horizon h = 16
+            # (t h) ... -> t h ...
+            batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous()
 
             out = {
                 "obs": {
@@ -132,7 +132,10 @@ class DiffusionPolicy(nn.Module):
             }
             return out
 
-        batch = replay_buffer.sample(batch_size)
+        if self.cfg.balanced_sampling:
+            batch = replay_buffer.sample(batch_size)
+        else:
+            batch = replay_buffer.sample()
         batch = process_batch(batch, self.cfg.horizon, num_slices)
 
         loss = self.diffusion.compute_loss(batch)
@@ -149,4 +152,17 @@ class DiffusionPolicy(nn.Module):
             "total_loss": loss.item(),
             "lr": self.lr_scheduler.get_last_lr()[0],
         }
+
+        # TODO(rcadene): remove hardcoding
+        # in diffusion_policy, len(dataloader) is 168 for a batch_size of 64
+        if step % 168 == 0:
+            self.global_step += 1
+
         return metrics
+
+    def save(self, fp):
+        torch.save(self.state_dict(), fp)
+
+    def load(self, fp):
+        d = torch.load(fp)
+        self.load_state_dict(d)
diff --git a/lerobot/common/policies/tdmpc.py b/lerobot/common/policies/tdmpc.py
index d3a3c19e..55c022df 100644
--- a/lerobot/common/policies/tdmpc.py
+++ b/lerobot/common/policies/tdmpc.py
@@ -359,7 +359,11 @@ class TDMPC(nn.Module):
             weights = batch["_weight"][FIRST_FRAME, :, None]
             return obs, action, next_obses, reward, mask, done, idxs, weights
 
-        batch = replay_buffer.sample(batch_size)
+        if self.cfg.balanced_sampling:
+            batch = replay_buffer.sample(batch_size)
+        else:
+            batch = replay_buffer.sample()
+
         obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
             batch, self.cfg.horizon, num_slices
         )
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index 83cbc5a3..97f560f5 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -12,6 +12,7 @@ hydra:
 seed: 1337
 device: cuda
 buffer_device: cuda
+prefetch: 4
 eval_freq: ???
 save_freq: ???
 eval_episodes: ???
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index c3b18298..8e88468d 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -21,7 +21,12 @@ past_action_visible: False
 keypoint_visible_rate: 1.0
 obs_as_global_cond: True
 
-offline_steps: 50000
+eval_episodes: 50
+eval_freq: 10000
+save_freq: 100000
+log_freq: 250
+
+offline_steps: 1344000
 online_steps: 0
 
 policy:
@@ -48,8 +53,7 @@ policy:
   per_alpha: 0.6
   per_beta: 0.4
 
-  balanced_sampling: true
-
+  balanced_sampling: false
   utd: 1
   offline_steps: ${offline_steps}
   use_ema: true
diff --git a/sbatch.sh b/sbatch.sh
index 732a902f..52a4df4b 100644
--- a/sbatch.sh
+++ b/sbatch.sh
@@ -6,7 +6,7 @@
 #SBATCH --time=2-00:00:00
 #SBATCH --output=/home/rcadene/slurm/%j.out
 #SBATCH --error=/home/rcadene/slurm/%j.err
-#SBATCH --qos=low
+#SBATCH --qos=medium
 #SBATCH --mail-user=re.cadene@gmail.com
 #SBATCH --mail-type=ALL