Fix diffusion (rm transpose), Add prefetch

This commit is contained in:
Cadene
2024-02-28 17:45:01 +00:00
parent cf5063e50e
commit ac90b9c3ee
6 changed files with 52 additions and 11 deletions

View File

@@ -1,9 +1,8 @@
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
# from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
from torchrl.data.replay_buffers import PrioritizedSliceSampler
# TODO(rcadene): implement
@@ -27,6 +26,17 @@ from torchrl.data.replay_buffers import PrioritizedSliceSampler
def make_offline_buffer(cfg, sampler=None):
if cfg.policy.balanced_sampling:
assert cfg.online_steps > 0
batch_size = None
pin_memory = False
prefetch = None
else:
assert cfg.online_steps == 0
num_slices = cfg.policy.batch_size
batch_size = cfg.policy.horizon * num_slices
pin_memory = cfg.device == "cuda"
prefetch = cfg.prefetch
overwrite_sampler = sampler is not None
@@ -52,6 +62,9 @@ def make_offline_buffer(cfg, sampler=None):
streaming=False,
root="data",
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch,
)
elif cfg.env.name == "pusht":
offline_buffer = PushtExperienceReplay(
@@ -61,6 +74,9 @@ def make_offline_buffer(cfg, sampler=None):
streaming=False,
root="data",
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch,
)
else:
raise ValueError(cfg.env.name)

View File

@@ -119,9 +119,9 @@ class DiffusionPolicy(nn.Module):
assert batch_size % num_slices == 0
def process_batch(batch, horizon, num_slices):
# trajectory t = 256, horizon h = 5
# (t h) ... -> h t ...
batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
# trajectory t = 64, horizon h = 16
# (t h) ... -> t h ...
batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous()
out = {
"obs": {
@@ -132,7 +132,10 @@ class DiffusionPolicy(nn.Module):
}
return out
batch = replay_buffer.sample(batch_size)
if self.cfg.balanced_sampling:
batch = replay_buffer.sample(batch_size)
else:
batch = replay_buffer.sample()
batch = process_batch(batch, self.cfg.horizon, num_slices)
loss = self.diffusion.compute_loss(batch)
@@ -149,4 +152,17 @@ class DiffusionPolicy(nn.Module):
"total_loss": loss.item(),
"lr": self.lr_scheduler.get_last_lr()[0],
}
# TODO(rcadene): remove the hardcoded 168 below
# 168 is len(dataloader) in diffusion_policy for a batch_size of 64
if step % 168 == 0:
self.global_step += 1
return metrics
def save(self, fp):
    """Write this module's ``state_dict()`` to ``fp`` using ``torch.save``.

    Args:
        fp: Destination handed straight to ``torch.save`` (a path or a
            writable binary file-like object).
    """
    weights = self.state_dict()
    torch.save(weights, fp)
def load(self, fp):
    """Restore this module's parameters from a checkpoint written by ``save``.

    Args:
        fp: Source handed straight to ``torch.load`` (a path or a readable
            binary file-like object) containing a ``state_dict``.
    """
    # Deserialize onto CPU so a checkpoint saved from a CUDA run can still be
    # read on a host without a GPU. ``load_state_dict`` copies the values into
    # the module's existing parameters, so their current device is unchanged.
    state = torch.load(fp, map_location="cpu")
    self.load_state_dict(state)

View File

@@ -359,7 +359,11 @@ class TDMPC(nn.Module):
weights = batch["_weight"][FIRST_FRAME, :, None]
return obs, action, next_obses, reward, mask, done, idxs, weights
batch = replay_buffer.sample(batch_size)
if self.cfg.balanced_sampling:
batch = replay_buffer.sample(batch_size)
else:
batch = replay_buffer.sample()
obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
batch, self.cfg.horizon, num_slices
)

View File

@@ -12,6 +12,7 @@ hydra:
seed: 1337
device: cuda
buffer_device: cuda
prefetch: 4
eval_freq: ???
save_freq: ???
eval_episodes: ???

View File

@@ -21,7 +21,12 @@ past_action_visible: False
keypoint_visible_rate: 1.0
obs_as_global_cond: True
offline_steps: 50000
eval_episodes: 50
eval_freq: 10000
save_freq: 100000
log_freq: 250
offline_steps: 1344000
online_steps: 0
policy:
@@ -48,8 +53,7 @@ policy:
per_alpha: 0.6
per_beta: 0.4
balanced_sampling: true
balanced_sampling: false
utd: 1
offline_steps: ${offline_steps}
use_ema: true

View File

@@ -6,7 +6,7 @@
#SBATCH --time=2-00:00:00
#SBATCH --output=/home/rcadene/slurm/%j.out
#SBATCH --error=/home/rcadene/slurm/%j.err
#SBATCH --qos=low
#SBATCH --qos=medium
#SBATCH --mail-user=re.cadene@gmail.com
#SBATCH --mail-type=ALL