Improve visualize_dataset, Improve AbstractReplayBuffer, Small improvements

This commit is contained in:
Remi Cadene
2024-03-06 10:14:03 +00:00
parent 2f80d71c3e
commit f95ecd66fc
7 changed files with 195 additions and 150 deletions

View File

@@ -183,8 +183,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
# last step of demonstration is considered done
done[-1] = True
print("before " + """episode = TensorDict(""")
episode = TensorDict(
ep_td = TensorDict(
{
("observation", "image"): image[:-1],
("observation", "state"): agent_pos[:-1],
@@ -203,11 +202,11 @@ class PushtExperienceReplay(AbstractExperienceReplay):
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
td_data[idxtd : idxtd + len(episode)] = episode
td_data[idxtd : idxtd + len(ep_td)] = ep_td
idx0 = idx1
idxtd = idxtd + len(episode)
idxtd = idxtd + len(ep_td)
return TensorStorage(td_data.lock_())