Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act
@@ -165,7 +165,7 @@ class AlohaDataset(torch.utils.data.Dataset):
     num_frames = ep["/action"].shape[0]

     # last step of demonstration is considered done
-    done = torch.zeros(num_frames, 1, dtype=torch.bool)
+    done = torch.zeros(num_frames, dtype=torch.bool)
     done[-1] = True

     state = torch.from_numpy(ep["/observations/qpos"][:])
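Note on the `done` change above: the trailing singleton dimension is dropped, so each episode stores a flat per-frame boolean vector rather than a one-column matrix. A minimal sketch of the before/after shapes (num_frames chosen arbitrarily):

    import torch

    num_frames = 50

    # before: one boolean per frame plus a channel dimension, shape (num_frames, 1)
    done_old = torch.zeros(num_frames, 1, dtype=torch.bool)

    # after: flat per-frame vector, shape (num_frames,)
    done_new = torch.zeros(num_frames, dtype=torch.bool)
    done_new[-1] = True  # last step of the demonstration is considered done

    assert done_old.shape == (num_frames, 1)
    assert done_new.shape == (num_frames,)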
@@ -193,8 +193,6 @@ class PushtDataset(torch.utils.data.Dataset):
     idx0 = 0
     for episode_id in tqdm.tqdm(range(num_episodes)):
         idx1 = dataset_dict.meta["episode_ends"][episode_id]
-        # to create test artifact
-        # idx1 = 51

         num_frames = idx1 - idx0

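For context on the loop above: `episode_ends` stores the cumulative end index of each episode within one flat buffer, so consecutive entries delimit a single episode. A sketch of that slicing logic, with hypothetical values standing in for `dataset_dict.meta["episode_ends"]`:

    # hypothetical stand-in for dataset_dict.meta["episode_ends"]
    episode_ends = [50, 110, 171]

    idx0 = 0
    for episode_id in range(len(episode_ends)):
        idx1 = episode_ends[episode_id]
        num_frames = idx1 - idx0  # frame count of this episode
        print(f"episode {episode_id}: frames [{idx0}, {idx1}) -> {num_frames}")
        idx0 = idx1  # the next episode starts where this one ends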
@@ -207,9 +205,9 @@ class PushtDataset(torch.utils.data.Dataset):
         block_pos = state[:, 2:4]
         block_angle = state[:, 4]

-        reward = torch.zeros(num_frames, 1)
-        success = torch.zeros(num_frames, 1, dtype=torch.bool)
-        done = torch.zeros(num_frames, 1, dtype=torch.bool)
+        reward = torch.zeros(num_frames)
+        success = torch.zeros(num_frames, dtype=torch.bool)
+        done = torch.zeros(num_frames, dtype=torch.bool)
         for i in range(num_frames):
             space = pymunk.Space()
             space.gravity = 0, 0
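The same shape convention applies here: `reward`, `success`, and `done` become flat `(num_frames,)` vectors that the per-frame loop fills by scalar indexing. A sketch with a placeholder where the real code rebuilds the pymunk scene to measure block coverage (the 0.95 threshold is an assumption for illustration):

    import torch

    num_frames = 100

    reward = torch.zeros(num_frames)
    success = torch.zeros(num_frames, dtype=torch.bool)
    done = torch.zeros(num_frames, dtype=torch.bool)

    for i in range(num_frames):
        coverage = i / (num_frames - 1)  # placeholder for the pymunk coverage computation
        reward[i] = min(coverage / 0.95, 1.0)
        success[i] = coverage > 0.95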
@@ -92,11 +92,11 @@ def load_data_with_delta_timestamps(

    # TODO(rcadene): synchronize timestamps + interpolation if needed

-   tol = 0.02
+   tol = 0.04
    is_pad = min_ > tol

    assert is_contiguously_true_or_false(is_pad), (
-       "One or several timestamps unexpectedly violate the tolerance."
+       f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=})."
        "This might be due to synchronization issues with timestamps during data collection."
    )

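On the assert above: the message now interpolates the offending distances and the tolerance via an f-string. `is_contiguously_true_or_false` is not shown in this diff; a plausible reading is that the padding mask may flip value at most once, i.e. padded frames sit only at the edge of the queried window. A sketch under that assumption:

    import torch

    def is_contiguously_true_or_false(bool_vector: torch.Tensor) -> bool:
        # assumed semantics: True and False values each form one contiguous run,
        # which is equivalent to the vector changing value at most once
        as_int = bool_vector.int()
        num_flips = (as_int[1:] != as_int[:-1]).sum().item()
        return num_flips <= 1

    min_ = torch.tensor([0.01, 0.03, 0.05])  # hypothetical distances to nearest timestamps
    tol = 0.04
    is_pad = min_ > tol  # frames farther than tol from any real timestamp count as padding
    assert is_contiguously_true_or_false(is_pad)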
@@ -429,7 +429,7 @@ class TDMPCPolicy(nn.Module):
    batch[key] = batch[key].transpose(1, 0)

    action = batch["action"]
-   reward = batch["next.reward"][:, :, None]  # add extra channel dimension
+   reward = batch["next.reward"]
    # idxs = batch["index"]  # TODO(rcadene): use idxs to update sampling weights
    done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
    mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
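With the channel dimension gone from the dataset, the policy now consumes `next.reward` as-is, and `done`/`mask` simply inherit its shape through `zeros_like`/`ones_like`. A shape-only sketch (horizon and batch size are assumptions for illustration):

    import torch

    t, b = 5, 8  # assumed horizon and batch size

    # stands in for batch["next.reward"] after the (1, 0) transpose above
    reward = torch.rand(t, b)

    # no [:, :, None] needed: done and mask follow reward's (t, b) shape
    done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
    mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)

    assert done.shape == (t, b) and mask.shape == (t, b)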
@@ -63,9 +63,9 @@ policy:
   grad_clip_norm: 10

   delta_timestamps:
-    observation.image: [-.1, 0]
-    observation.state: [-.1, 0]
-    action: [-.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4]
+    observation.image: [-0.1, 0]
+    observation.state: [-0.1, 0]
+    action: [-0.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4]

   noise_scheduler:
     _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
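The config change is less cosmetic than it looks: PyYAML's implicit float resolver accepts `.1` but not `-.1` (its signed pattern requires a digit before the dot), so `-.1` loads as a string while `-0.1` loads as a float, and the positive `.1` entries can stay as they are. A quick check:

    import yaml

    # PyYAML parses ".1" as a float but "-.1" as a *string*
    print(yaml.safe_load("x: [-.1, 0, .1]"))   # {'x': ['-.1', 0, 0.1]}
    print(yaml.safe_load("x: [-0.1, 0, .1]"))  # {'x': [-0.1, 0, 0.1]}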