Update ManiSkill configuration and replay buffer to support truncation and dataset handling

- Reduced image size in ManiSkill environment configuration from 128 to 64
- Added support for truncation in replay buffer and actor server
- Updated SAC policy configuration to use a specific dataset and modified vision encoder settings
- Improved dataset conversion process with progress tracking and task naming
- Added flexibility for joint action space masking in learner server
This commit is contained in:
AdilZouitine
2025-02-24 16:53:37 +00:00
committed by Michel Aractingi
parent d3b84ecd6f
commit 4c73891575
5 changed files with 78 additions and 27 deletions

View File

@@ -31,6 +31,7 @@ class Transition(TypedDict):
reward: float
next_state: dict[str, torch.Tensor]
done: bool
truncated: bool
complementary_info: dict[str, Any] = None
@@ -40,6 +41,7 @@ class BatchTransition(TypedDict):
reward: torch.Tensor
next_state: dict[str, torch.Tensor]
done: torch.Tensor
truncated: torch.Tensor
def move_transition_to_device(
@@ -70,6 +72,11 @@ def move_transition_to_device(
device, non_blocking=device.type == "cuda"
)
if isinstance(transition["truncated"], torch.Tensor):
transition["truncated"] = transition["truncated"].to(
device, non_blocking=device.type == "cuda"
)
# Move next_state tensors to CPU
transition["next_state"] = {
key: val.to(device, non_blocking=device.type == "cuda")
@@ -205,6 +212,7 @@ class ReplayBuffer:
reward: float,
next_state: dict[str, torch.Tensor],
done: bool,
truncated: bool,
complementary_info: Optional[dict[str, torch.Tensor]] = None,
):
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
@@ -229,6 +237,7 @@ class ReplayBuffer:
reward=reward,
next_state=next_state,
done=done,
truncated=truncated,
complementary_info=complementary_info,
)
self.position = (self.position + 1) % self.capacity
@@ -294,6 +303,7 @@ class ReplayBuffer:
reward=data["reward"],
next_state=data["next_state"],
done=data["done"],
truncated=False,
)
return replay_buffer
@@ -352,6 +362,8 @@ class ReplayBuffer:
# ----- 3) Reward and done -----
reward = float(current_sample["next.reward"].item()) # ensure float
done = bool(current_sample["next.done"].item()) # ensure bool
# TODO: (azouitine) Handle truncation properly
truncated = bool(current_sample["next.done"].item()) # ensure bool
# ----- 4) Next state -----
# If not done and the next sample is in the same episode, we pull the next sample's state.
@@ -374,6 +386,7 @@ class ReplayBuffer:
reward=reward,
next_state=next_state,
done=done,
truncated=truncated,
)
transitions.append(transition)
@@ -419,6 +432,11 @@ class ReplayBuffer:
[t["done"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# -- Build batched truncateds --
batch_truncateds = torch.tensor(
[t["truncated"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# Return a BatchTransition typed dict
return BatchTransition(
state=batch_state,
@@ -426,6 +444,7 @@ class ReplayBuffer:
reward=batch_rewards,
next_state=batch_next_state,
done=batch_dones,
truncated=batch_truncateds,
)
def to_lerobot_dataset(
@@ -501,7 +520,7 @@ class ReplayBuffer:
# Start writing images if needed. If you have no image features, this is harmless.
# Set num_processes or num_threads if you want concurrency.
lerobot_dataset.start_image_writer(num_processes=0, num_threads=2)
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
# --------------------------------------------------------------------------------------------
# Convert transitions into episodes and frames
@@ -513,7 +532,11 @@ class ReplayBuffer:
)
frame_idx_in_episode = 0
for global_frame_idx, transition in enumerate(self.memory):
for global_frame_idx, transition in tqdm(
enumerate(self.memory),
desc="Converting replay buffer to dataset",
total=len(self.memory),
):
frame_dict = {}
# Fill the data for state keys
@@ -546,14 +569,15 @@ class ReplayBuffer:
# Move to next frame
frame_idx_in_episode += 1
# If we reached an episode boundary, call save_episode, reset counters
if transition["done"]:
# TODO: (azouitine) Handle truncation properly
if transition["done"] or transition["truncated"]:
# Use some placeholder name for the task
lerobot_dataset.save_episode(task="from_replay_buffer")
lerobot_dataset.save_episode(task=task_name)
episode_index += 1
frame_idx_in_episode = 0
# Start a new buffer for the next episode
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
episode_index
episode_index=episode_index
)
# We are done adding frames
@@ -624,6 +648,10 @@ def concatenate_batch_transitions(
left_batch_transitions["done"] = torch.cat(
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
)
left_batch_transitions["truncated"] = torch.cat(
[left_batch_transitions["truncated"], right_batch_transition["truncated"]],
dim=0,
)
return left_batch_transitions