Improve visualize_dataset, Improve AbstractReplayBuffer, Small improvements
This commit is contained in:
@@ -183,8 +183,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
print("before " + """episode = TensorDict(""")
|
||||
episode = TensorDict(
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "image"): image[:-1],
|
||||
("observation", "state"): agent_pos[:-1],
|
||||
@@ -203,11 +202,11 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
|
||||
td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
|
||||
|
||||
td_data[idxtd : idxtd + len(episode)] = episode
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
|
||||
idx0 = idx1
|
||||
idxtd = idxtd + len(episode)
|
||||
idxtd = idxtd + len(ep_td)
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
|
||||
Reference in New Issue
Block a user