Several fixes to move the actor_server and learner_server code from the maniskill environment to the real robot environment.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-02-10 16:03:39 +01:00
parent af769abd8d
commit 9784d8a47f
10 changed files with 597 additions and 318 deletions

View File

@@ -56,10 +56,10 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
}
# If complementary_info is present, move its tensors to CPU
if transition["complementary_info"] is not None:
transition["complementary_info"] = {
key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
}
# if transition["complementary_info"] is not None:
# transition["complementary_info"] = {
# key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
# }
return transition
@@ -309,6 +309,7 @@ class ReplayBuffer:
def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors."""
batch_size = min(batch_size, len(self.memory))
list_of_transitions = random.sample(self.memory, batch_size)
# -- Build batched states --
@@ -341,9 +342,6 @@ class ReplayBuffer:
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
# Return a BatchTransition typed dict
return BatchTransition(
@@ -531,30 +529,31 @@ def concatenate_batch_transitions(
# if __name__ == "__main__":
# dataset_name = "lerobot/pusht_image"
# dataset = LeRobotDataset(repo_id=dataset_name, episodes=range(1, 3))
# replay_buffer = ReplayBuffer.from_lerobot_dataset(
# lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
# )
# replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
# for i in range(len(replay_buffer_converted)):
# replay_convert = replay_buffer_converted[i]
# dataset_convert = dataset[i]
# for key in replay_convert.keys():
# if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
# continue
# if key in dataset_convert.keys():
# assert torch.equal(replay_convert[key], dataset_convert[key])
# print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
# re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
# replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
# )
# for _ in range(20):
# batch = re_reconverted_dataset.sample(32)
# dataset_name = "aractingi/push_green_cube_hf_cropped_resized"
# dataset = LeRobotDataset(repo_id=dataset_name)
# for key in batch.keys():
# if key in {"state", "next_state"}:
# for key_state in batch[key].keys():
# print(key_state, batch[key][key_state].size())
# continue
# print(key, batch[key].size())
# replay_buffer = ReplayBuffer.from_lerobot_dataset(
# lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
# )
# replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
# for i in range(len(replay_buffer_converted)):
# replay_convert = replay_buffer_converted[i]
# dataset_convert = dataset[i]
# for key in replay_convert.keys():
# if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
# continue
# if key in dataset_convert.keys():
# assert torch.equal(replay_convert[key], dataset_convert[key])
# print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
# re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
# replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
# )
# for _ in range(20):
# batch = re_reconverted_dataset.sample(32)
# for key in batch.keys():
# if key in {"state", "next_state"}:
# for key_state in batch[key].keys():
# print(key_state, batch[key][key_state].size())
# continue
# print(key, batch[key].size())