Fix diffusion (remove transpose), add prefetch
This commit is contained in:
@@ -1,9 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||||
|
|
||||||
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
||||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||||
# from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
|
|
||||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
|
||||||
|
|
||||||
# TODO(rcadene): implement
|
# TODO(rcadene): implement
|
||||||
|
|
||||||
@@ -27,6 +26,17 @@ from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
|||||||
|
|
||||||
|
|
||||||
def make_offline_buffer(cfg, sampler=None):
|
def make_offline_buffer(cfg, sampler=None):
|
||||||
|
if cfg.policy.balanced_sampling:
|
||||||
|
assert cfg.online_steps > 0
|
||||||
|
batch_size = None
|
||||||
|
pin_memory = False
|
||||||
|
prefetch = None
|
||||||
|
else:
|
||||||
|
assert cfg.online_steps == 0
|
||||||
|
num_slices = cfg.policy.batch_size
|
||||||
|
batch_size = cfg.policy.horizon * num_slices
|
||||||
|
pin_memory = cfg.device == "cuda"
|
||||||
|
prefetch = cfg.prefetch
|
||||||
|
|
||||||
overwrite_sampler = sampler is not None
|
overwrite_sampler = sampler is not None
|
||||||
|
|
||||||
@@ -52,6 +62,9 @@ def make_offline_buffer(cfg, sampler=None):
|
|||||||
streaming=False,
|
streaming=False,
|
||||||
root="data",
|
root="data",
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
|
batch_size=batch_size,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
prefetch=prefetch,
|
||||||
)
|
)
|
||||||
elif cfg.env.name == "pusht":
|
elif cfg.env.name == "pusht":
|
||||||
offline_buffer = PushtExperienceReplay(
|
offline_buffer = PushtExperienceReplay(
|
||||||
@@ -61,6 +74,9 @@ def make_offline_buffer(cfg, sampler=None):
|
|||||||
streaming=False,
|
streaming=False,
|
||||||
root="data",
|
root="data",
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
|
batch_size=batch_size,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
prefetch=prefetch,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(cfg.env.name)
|
raise ValueError(cfg.env.name)
|
||||||
|
|||||||
@@ -119,9 +119,9 @@ class DiffusionPolicy(nn.Module):
|
|||||||
assert batch_size % num_slices == 0
|
assert batch_size % num_slices == 0
|
||||||
|
|
||||||
def process_batch(batch, horizon, num_slices):
|
def process_batch(batch, horizon, num_slices):
|
||||||
# trajectory t = 256, horizon h = 5
|
# trajectory t = 64, horizon h = 16
|
||||||
# (t h) ... -> h t ...
|
# (t h) ... -> t h ...
|
||||||
batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
|
batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous()
|
||||||
|
|
||||||
out = {
|
out = {
|
||||||
"obs": {
|
"obs": {
|
||||||
@@ -132,7 +132,10 @@ class DiffusionPolicy(nn.Module):
|
|||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
|
|
||||||
batch = replay_buffer.sample(batch_size)
|
if self.cfg.balanced_sampling:
|
||||||
|
batch = replay_buffer.sample(batch_size)
|
||||||
|
else:
|
||||||
|
batch = replay_buffer.sample()
|
||||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||||
|
|
||||||
loss = self.diffusion.compute_loss(batch)
|
loss = self.diffusion.compute_loss(batch)
|
||||||
@@ -149,4 +152,17 @@ class DiffusionPolicy(nn.Module):
|
|||||||
"total_loss": loss.item(),
|
"total_loss": loss.item(),
|
||||||
"lr": self.lr_scheduler.get_last_lr()[0],
|
"lr": self.lr_scheduler.get_last_lr()[0],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# TODO(rcadene): remove hardcoding
|
||||||
|
# In diffusion_policy, len(dataloader) is 168 for a batch_size of 64, hence the hard-coded 168 in the check below.
|
||||||
|
if step % 168 == 0:
|
||||||
|
self.global_step += 1
|
||||||
|
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
|
def save(self, fp):
|
||||||
|
torch.save(self.state_dict(), fp)
|
||||||
|
|
||||||
|
def load(self, fp):
|
||||||
|
d = torch.load(fp)
|
||||||
|
self.load_state_dict(d)
|
||||||
|
|||||||
@@ -359,7 +359,11 @@ class TDMPC(nn.Module):
|
|||||||
weights = batch["_weight"][FIRST_FRAME, :, None]
|
weights = batch["_weight"][FIRST_FRAME, :, None]
|
||||||
return obs, action, next_obses, reward, mask, done, idxs, weights
|
return obs, action, next_obses, reward, mask, done, idxs, weights
|
||||||
|
|
||||||
batch = replay_buffer.sample(batch_size)
|
if self.cfg.balanced_sampling:
|
||||||
|
batch = replay_buffer.sample(batch_size)
|
||||||
|
else:
|
||||||
|
batch = replay_buffer.sample()
|
||||||
|
|
||||||
obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
|
obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
|
||||||
batch, self.cfg.horizon, num_slices
|
batch, self.cfg.horizon, num_slices
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ hydra:
|
|||||||
seed: 1337
|
seed: 1337
|
||||||
device: cuda
|
device: cuda
|
||||||
buffer_device: cuda
|
buffer_device: cuda
|
||||||
|
prefetch: 4
|
||||||
eval_freq: ???
|
eval_freq: ???
|
||||||
save_freq: ???
|
save_freq: ???
|
||||||
eval_episodes: ???
|
eval_episodes: ???
|
||||||
|
|||||||
@@ -21,7 +21,12 @@ past_action_visible: False
|
|||||||
keypoint_visible_rate: 1.0
|
keypoint_visible_rate: 1.0
|
||||||
obs_as_global_cond: True
|
obs_as_global_cond: True
|
||||||
|
|
||||||
offline_steps: 50000
|
eval_episodes: 50
|
||||||
|
eval_freq: 10000
|
||||||
|
save_freq: 100000
|
||||||
|
log_freq: 250
|
||||||
|
|
||||||
|
offline_steps: 1344000
|
||||||
online_steps: 0
|
online_steps: 0
|
||||||
|
|
||||||
policy:
|
policy:
|
||||||
@@ -48,8 +53,7 @@ policy:
|
|||||||
per_alpha: 0.6
|
per_alpha: 0.6
|
||||||
per_beta: 0.4
|
per_beta: 0.4
|
||||||
|
|
||||||
balanced_sampling: true
|
balanced_sampling: false
|
||||||
|
|
||||||
utd: 1
|
utd: 1
|
||||||
offline_steps: ${offline_steps}
|
offline_steps: ${offline_steps}
|
||||||
use_ema: true
|
use_ema: true
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
#SBATCH --time=2-00:00:00
|
#SBATCH --time=2-00:00:00
|
||||||
#SBATCH --output=/home/rcadene/slurm/%j.out
|
#SBATCH --output=/home/rcadene/slurm/%j.out
|
||||||
#SBATCH --error=/home/rcadene/slurm/%j.err
|
#SBATCH --error=/home/rcadene/slurm/%j.err
|
||||||
#SBATCH --qos=low
|
#SBATCH --qos=medium
|
||||||
#SBATCH --mail-user=re.cadene@gmail.com
|
#SBATCH --mail-user=re.cadene@gmail.com
|
||||||
#SBATCH --mail-type=ALL
|
#SBATCH --mail-type=ALL
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user