Add download_and_upload_dataset.py in script, update all datasets, update online training
This commit is contained in:
@@ -1,25 +1,12 @@
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Dataset, load_dataset
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.common.datasets._diffusion_policy_replay_buffer import (
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
)
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip, load_previous_and_future_frames
|
||||
|
||||
# as define in env
|
||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||
|
||||
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
|
||||
|
||||
class PushtDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
https://huggingface.co/datasets/lerobot/pusht
|
||||
|
||||
Arguments
|
||||
----------
|
||||
@@ -35,32 +22,19 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
root: Path | None = None,
|
||||
version: str | None = "v1.0",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
# self.data_dir = self.root / f"{self.dataset_id}"
|
||||
# if (self.data_dir / "data_dict.pth").exists() and (
|
||||
# self.data_dir / "data_ids_per_episode.pth"
|
||||
# ).exists():
|
||||
# self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
||||
# self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
||||
# else:
|
||||
# self._download_and_preproc_obsolete()
|
||||
# self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
# torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
||||
# torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
||||
|
||||
self.data_dict = load_dataset("lerobot/pusht", split="train")
|
||||
# self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
|
||||
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", split="train")
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
self.data_dict.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
@@ -87,135 +61,3 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
try:
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
except ModuleNotFoundError as e:
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
download_and_extract_zip(PUSHT_URL, raw_dir)
|
||||
|
||||
# load
|
||||
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
|
||||
zarr_path
|
||||
) # , keys=['img', 'state', 'action'])
|
||||
|
||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
total_frames = dataset_dict["action"].shape[0]
|
||||
# to create test artifact
|
||||
# num_episodes = 1
|
||||
# total_frames = 50
|
||||
assert len(
|
||||
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
||||
), "Some data type dont have the same number of total frames."
|
||||
|
||||
# TODO: verify that goal pose is expected to be fixed
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
||||
|
||||
imgs = torch.from_numpy(dataset_dict["img"])
|
||||
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
|
||||
states = torch.from_numpy(dataset_dict["state"])
|
||||
actions = torch.from_numpy(dataset_dict["action"])
|
||||
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
|
||||
assert (episode_ids[idx0:idx1] == episode_id).all()
|
||||
|
||||
image = imgs[idx0:idx1]
|
||||
assert image.min() >= 0.0
|
||||
assert image.max() <= 255.0
|
||||
image = image.type(torch.uint8)
|
||||
|
||||
state = states[idx0:idx1]
|
||||
agent_pos = state[:, :2]
|
||||
block_pos = state[:, 2:4]
|
||||
block_angle = state[:, 4]
|
||||
|
||||
reward = torch.zeros(num_frames)
|
||||
success = torch.zeros(num_frames, dtype=torch.bool)
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
for i in range(num_frames):
|
||||
space = pymunk.Space()
|
||||
space.gravity = 0, 0
|
||||
space.damping = 0
|
||||
|
||||
# Add walls.
|
||||
walls = [
|
||||
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
|
||||
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
|
||||
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
|
||||
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage = intersection_area / goal_area
|
||||
reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
|
||||
success[i] = coverage > SUCCESS_THRESHOLD
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
ep_dict = {
|
||||
"observation.image": image,
|
||||
"observation.state": agent_pos,
|
||||
"action": actions[idx0:idx1],
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.image": image[1:],
|
||||
# "next.observation.state": agent_pos[1:],
|
||||
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
||||
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
||||
"next.done": torch.cat([done[1:], done[[-1]]]),
|
||||
"next.success": torch.cat([success[1:], success[[-1]]]),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
assert isinstance(episode_id, int)
|
||||
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||
assert len(self.data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
idx0 = idx1
|
||||
|
||||
self.data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
dataset = Dataset.from_dict(self.data_dict)
|
||||
dataset = dataset.with_format("torch")
|
||||
|
||||
def add_episode_data_id_from_to(frame):
|
||||
ep_id = frame["episode"].item()
|
||||
frame["episode_data_id_from"] = self.data_ids_per_episode[ep_id][0]
|
||||
frame["episode_data_id_to"] = self.data_ids_per_episode[ep_id][-1]
|
||||
return frame
|
||||
|
||||
dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
|
||||
dataset = dataset.rename_column("episode", "episode_id")
|
||||
dataset.push_to_hub("lerobot/pusht", token=True)
|
||||
|
||||
Reference in New Issue
Block a user