WIP Upgrading simxam from mujoco-py to mujoco python bindings

This commit is contained in:
Simon Alibert
2024-03-24 17:36:22 +01:00
parent e41c420a96
commit 1c24bbda3f
55 changed files with 1253 additions and 105 deletions

View File

@@ -32,6 +32,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
collate_fn: Callable = None,
writer: Writer = None,
transform: "torchrl.envs.Transform" = None,
# storage = None,
):
self.dataset_id = dataset_id
self.version = version
@@ -43,7 +44,12 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
)
storage = self._download_or_load_dataset()
# HACK
if dataset_id == "xarm_lift_medium":
self.data_dir = self.root / self.dataset_id
storage = self._download_and_preproc_obsolete()
else:
storage = self._download_or_load_dataset()
super().__init__(
storage=storage,

View File

@@ -67,11 +67,11 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
)
def _download_and_preproc_obsolete(self):
assert self.root is not None
# assert self.root is not None
# TODO(rcadene): finish download
download()
# download()
dataset_path = self.root / f"{self.dataset_id}_raw" / "buffer.pkl"
dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
@@ -105,8 +105,8 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
"frame_id": torch.arange(0, num_frames, 1),
("next", "observation", "image"): next_image,
("next", "observation", "state"): next_state,
("next", "observation", "reward"): next_reward,
("next", "observation", "done"): next_done,
("next", "reward"): next_reward,
("next", "done"): next_done,
},
batch_size=num_frames,
)