forked from tangger/lerobot
Several fixes to move the actor_server and learner_server code from the maniskill environment to the real robot environment.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -56,10 +56,10 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
|
||||
}
|
||||
|
||||
# If complementary_info is present, move its tensors to CPU
|
||||
if transition["complementary_info"] is not None:
|
||||
transition["complementary_info"] = {
|
||||
key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
|
||||
}
|
||||
# if transition["complementary_info"] is not None:
|
||||
# transition["complementary_info"] = {
|
||||
# key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
|
||||
# }
|
||||
return transition
|
||||
|
||||
|
||||
@@ -309,6 +309,7 @@ class ReplayBuffer:
|
||||
|
||||
def sample(self, batch_size: int) -> BatchTransition:
|
||||
"""Sample a random batch of transitions and collate them into batched tensors."""
|
||||
batch_size = min(batch_size, len(self.memory))
|
||||
list_of_transitions = random.sample(self.memory, batch_size)
|
||||
|
||||
# -- Build batched states --
|
||||
@@ -341,9 +342,6 @@ class ReplayBuffer:
|
||||
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
# Return a BatchTransition typed dict
|
||||
return BatchTransition(
|
||||
@@ -531,30 +529,31 @@ def concatenate_batch_transitions(
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# dataset_name = "lerobot/pusht_image"
|
||||
# dataset = LeRobotDataset(repo_id=dataset_name, episodes=range(1, 3))
|
||||
# replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
# lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
|
||||
# )
|
||||
# replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
|
||||
# for i in range(len(replay_buffer_converted)):
|
||||
# replay_convert = replay_buffer_converted[i]
|
||||
# dataset_convert = dataset[i]
|
||||
# for key in replay_convert.keys():
|
||||
# if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
|
||||
# continue
|
||||
# if key in dataset_convert.keys():
|
||||
# assert torch.equal(replay_convert[key], dataset_convert[key])
|
||||
# print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
|
||||
# re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
|
||||
# replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
|
||||
# )
|
||||
# for _ in range(20):
|
||||
# batch = re_reconverted_dataset.sample(32)
|
||||
# dataset_name = "aractingi/push_green_cube_hf_cropped_resized"
|
||||
# dataset = LeRobotDataset(repo_id=dataset_name)
|
||||
|
||||
# for key in batch.keys():
|
||||
# if key in {"state", "next_state"}:
|
||||
# for key_state in batch[key].keys():
|
||||
# print(key_state, batch[key][key_state].size())
|
||||
# continue
|
||||
# print(key, batch[key].size())
|
||||
# replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
# lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
|
||||
# )
|
||||
# replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
|
||||
# for i in range(len(replay_buffer_converted)):
|
||||
# replay_convert = replay_buffer_converted[i]
|
||||
# dataset_convert = dataset[i]
|
||||
# for key in replay_convert.keys():
|
||||
# if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
|
||||
# continue
|
||||
# if key in dataset_convert.keys():
|
||||
# assert torch.equal(replay_convert[key], dataset_convert[key])
|
||||
# print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
|
||||
# re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
|
||||
# replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
|
||||
# )
|
||||
# for _ in range(20):
|
||||
# batch = re_reconverted_dataset.sample(32)
|
||||
|
||||
# for key in batch.keys():
|
||||
# if key in {"state", "next_state"}:
|
||||
# for key_state in batch[key].keys():
|
||||
# print(key_state, batch[key][key_state].size())
|
||||
# continue
|
||||
# print(key, batch[key].size())
|
||||
|
||||
Reference in New Issue
Block a user