forked from tangger/lerobot
Update ManiSkill configuration and replay buffer to support truncation and dataset handling
- Reduced image size in ManiSkill environment configuration from 128 to 64 - Added support for truncation in replay buffer and actor server - Updated SAC policy configuration to use a specific dataset and modify vision encoder settings - Improved dataset conversion process with progress tracking and task naming - Added flexibility for joint action space masking in learner server
This commit is contained in:
committed by
Michel Aractingi
parent
d3b84ecd6f
commit
4c73891575
@@ -31,6 +31,7 @@ class Transition(TypedDict):
|
||||
reward: float
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: bool
|
||||
truncated: bool
|
||||
complementary_info: dict[str, Any] = None
|
||||
|
||||
|
||||
@@ -40,6 +41,7 @@ class BatchTransition(TypedDict):
|
||||
reward: torch.Tensor
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: torch.Tensor
|
||||
truncated: torch.Tensor
|
||||
|
||||
|
||||
def move_transition_to_device(
|
||||
@@ -70,6 +72,11 @@ def move_transition_to_device(
|
||||
device, non_blocking=device.type == "cuda"
|
||||
)
|
||||
|
||||
if isinstance(transition["truncated"], torch.Tensor):
|
||||
transition["truncated"] = transition["truncated"].to(
|
||||
device, non_blocking=device.type == "cuda"
|
||||
)
|
||||
|
||||
# Move next_state tensors to CPU
|
||||
transition["next_state"] = {
|
||||
key: val.to(device, non_blocking=device.type == "cuda")
|
||||
@@ -205,6 +212,7 @@ class ReplayBuffer:
|
||||
reward: float,
|
||||
next_state: dict[str, torch.Tensor],
|
||||
done: bool,
|
||||
truncated: bool,
|
||||
complementary_info: Optional[dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
||||
@@ -229,6 +237,7 @@ class ReplayBuffer:
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
truncated=truncated,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
@@ -294,6 +303,7 @@ class ReplayBuffer:
|
||||
reward=data["reward"],
|
||||
next_state=data["next_state"],
|
||||
done=data["done"],
|
||||
truncated=False,
|
||||
)
|
||||
return replay_buffer
|
||||
|
||||
@@ -352,6 +362,8 @@ class ReplayBuffer:
|
||||
# ----- 3) Reward and done -----
|
||||
reward = float(current_sample["next.reward"].item()) # ensure float
|
||||
done = bool(current_sample["next.done"].item()) # ensure bool
|
||||
# TODO: (azouitine) Handle truncation properly
|
||||
truncated = bool(current_sample["next.done"].item()) # ensure bool
|
||||
|
||||
# ----- 4) Next state -----
|
||||
# If not done and the next sample is in the same episode, we pull the next sample's state.
|
||||
@@ -374,6 +386,7 @@ class ReplayBuffer:
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
truncated=truncated,
|
||||
)
|
||||
transitions.append(transition)
|
||||
|
||||
@@ -419,6 +432,11 @@ class ReplayBuffer:
|
||||
[t["done"] for t in list_of_transitions], dtype=torch.float32
|
||||
).to(self.device)
|
||||
|
||||
# -- Build batched truncateds --
|
||||
batch_truncateds = torch.tensor(
|
||||
[t["truncated"] for t in list_of_transitions], dtype=torch.float32
|
||||
).to(self.device)
|
||||
|
||||
# Return a BatchTransition typed dict
|
||||
return BatchTransition(
|
||||
state=batch_state,
|
||||
@@ -426,6 +444,7 @@ class ReplayBuffer:
|
||||
reward=batch_rewards,
|
||||
next_state=batch_next_state,
|
||||
done=batch_dones,
|
||||
truncated=batch_truncateds,
|
||||
)
|
||||
|
||||
def to_lerobot_dataset(
|
||||
@@ -501,7 +520,7 @@ class ReplayBuffer:
|
||||
|
||||
# Start writing images if needed. If you have no image features, this is harmless.
|
||||
# Set num_processes or num_threads if you want concurrency.
|
||||
lerobot_dataset.start_image_writer(num_processes=0, num_threads=2)
|
||||
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Convert transitions into episodes and frames
|
||||
@@ -513,7 +532,11 @@ class ReplayBuffer:
|
||||
)
|
||||
|
||||
frame_idx_in_episode = 0
|
||||
for global_frame_idx, transition in enumerate(self.memory):
|
||||
for global_frame_idx, transition in tqdm(
|
||||
enumerate(self.memory),
|
||||
desc="Converting replay buffer to dataset",
|
||||
total=len(self.memory),
|
||||
):
|
||||
frame_dict = {}
|
||||
|
||||
# Fill the data for state keys
|
||||
@@ -546,14 +569,15 @@ class ReplayBuffer:
|
||||
# Move to next frame
|
||||
frame_idx_in_episode += 1
|
||||
# If we reached an episode boundary, call save_episode, reset counters
|
||||
if transition["done"]:
|
||||
# TODO: (azouitine) Handle truncation properly
|
||||
if transition["done"] or transition["truncated"]:
|
||||
# Use some placeholder name for the task
|
||||
lerobot_dataset.save_episode(task="from_replay_buffer")
|
||||
lerobot_dataset.save_episode(task=task_name)
|
||||
episode_index += 1
|
||||
frame_idx_in_episode = 0
|
||||
# Start a new buffer for the next episode
|
||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
|
||||
episode_index
|
||||
episode_index=episode_index
|
||||
)
|
||||
|
||||
# We are done adding frames
|
||||
@@ -624,6 +648,10 @@ def concatenate_batch_transitions(
|
||||
left_batch_transitions["done"] = torch.cat(
|
||||
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
|
||||
)
|
||||
left_batch_transitions["truncated"] = torch.cat(
|
||||
[left_batch_transitions["truncated"], right_batch_transition["truncated"]],
|
||||
dim=0,
|
||||
)
|
||||
return left_batch_transitions
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user