Include observation.environment_state with keypoints in PushT dataset (#303)

Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
Alexander Soare
2024-07-09 08:27:40 +01:00
committed by GitHub
parent 7bd5ab16d1
commit a4d77b99f0
27 changed files with 739 additions and 609 deletions

View File

@@ -70,6 +70,8 @@ available_datasets_per_env = {
"lerobot/aloha_sim_transfer_cube_human_image",
"lerobot/aloha_sim_transfer_cube_scripted_image",
],
# TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
# coupled with tests.
"pusht": ["lerobot/pusht", "lerobot/pusht_image"],
"xarm": [
"lerobot/xarm_lift_medium",

View File

@@ -36,7 +36,7 @@ from lerobot.common.datasets.utils import (
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
CODEBASE_VERSION = "v1.4"
CODEBASE_VERSION = "v1.5"
class LeRobotDataset(torch.utils.data.Dataset):

View File

@@ -54,7 +54,14 @@ def check_format(raw_dir):
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
def load_from_raw(
raw_dir: Path,
videos_dir: Path,
fps: int,
video: bool,
episodes: list[int] | None = None,
keypoints_instead_of_image: bool = False,
):
try:
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
@@ -105,10 +112,11 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
assert (episode_ids[from_idx:to_idx] == ep_idx).all()
# get image
image = imgs[from_idx:to_idx]
assert image.min() >= 0.0
assert image.max() <= 255.0
image = image.type(torch.uint8)
if not keypoints_instead_of_image:
image = imgs[from_idx:to_idx]
assert image.min() >= 0.0
assert image.max() <= 255.0
image = image.type(torch.uint8)
# get state
state = states[from_idx:to_idx]
@@ -116,9 +124,11 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
block_pos = state[:, 2:4]
block_angle = state[:, 4]
# get reward, success, done
# get reward, success, done, and (maybe) keypoints
reward = torch.zeros(num_frames)
success = torch.zeros(num_frames, dtype=torch.bool)
if keypoints_instead_of_image:
keypoints = torch.zeros(num_frames, 16) # 8 keypoints each with 2 coords
done = torch.zeros(num_frames, dtype=torch.bool)
for i in range(num_frames):
space = pymunk.Space()
@@ -134,7 +144,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
]
space.add(*walls)
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
@@ -142,33 +152,40 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
coverage = intersection_area / goal_area
reward[i] = np.clip(coverage / success_threshold, 0, 1)
success[i] = coverage > success_threshold
if keypoints_instead_of_image:
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
# last step of demonstration is considered done
done[-1] = True
ep_dict = {}
imgs_array = [x.numpy() for x in image]
img_key = "observation.image"
if video:
# save png images in temporary directory
tmp_imgs_dir = videos_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
if not keypoints_instead_of_image:
imgs_array = [x.numpy() for x in image]
img_key = "observation.image"
if video:
# save png images in temporary directory
tmp_imgs_dir = videos_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = videos_dir / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
# encode images to a mp4 video
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = videos_dir / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
# store the reference to the video frame
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = agent_pos
if keypoints_instead_of_image:
ep_dict["observation.environment_state"] = keypoints
ep_dict["action"] = actions[from_idx:to_idx]
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
@@ -180,7 +197,6 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
ep_dict["next.done"] = torch.cat([done[1:], done[[-1]]])
ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
ep_dicts.append(ep_dict)
data_dict = concatenate_episodes(ep_dicts)
total_frames = data_dict["frame_index"].shape[0]
@@ -188,17 +204,23 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
return data_dict
def to_hf_dataset(data_dict, video):
def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
features = {}
if video:
features["observation.image"] = VideoFrame()
else:
features["observation.image"] = Image()
if not keypoints_instead_of_image:
if video:
features["observation.image"] = VideoFrame()
else:
features["observation.image"] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
if keypoints_instead_of_image:
features["observation.environment_state"] = Sequence(
length=data_dict["observation.environment_state"].shape[1],
feature=Value(dtype="float32", id=None),
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
)
@@ -222,17 +244,21 @@ def from_raw_to_lerobot_format(
video: bool = True,
episodes: list[int] | None = None,
):
# Manually change this to True to use keypoints of the T instead of an image observation (but don't merge
# with True). Also make sure to use video = 0 in the `push_dataset_to_hub.py` script.
keypoints_instead_of_image = False
# sanity check
check_format(raw_dir)
if fps is None:
fps = 10
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
hf_dataset = to_hf_dataset(data_dict, video)
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image)
hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"fps": fps,
"video": video,
"video": video if not keypoints_instead_of_image else 0,
}
return hf_dataset, episode_data_index, info

View File

@@ -40,6 +40,60 @@ python lerobot/scripts/push_dataset_to_hub.py \
--raw-format umi_zarr \
--repo-id lerobot/umi_cup_in_the_wild
```
**WARNING: Updating an existing dataset**
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
codebase won't be affected by your change and backward compatibility is maintained.
For instance, Pusht has many versions to maintain backward compatibility between LeRobot codebase versions:
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
However, you will need to update the version of ALL the other datasets so that they have the new
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
```python
import os
from huggingface_hub import create_branch, hf_hub_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" # makes it easier to see the print-out below
NEW_CODEBASE_VERSION = "v1.5" # REPLACE THIS WITH YOUR DESIRED VERSION
for repo_id in available_datasets:
# First check if the newer version already exists.
try:
hf_hub_download(
repo_id=repo_id, repo_type="dataset", filename=".gitattributes", revision=NEW_CODEBASE_VERSION
)
print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
print("Exiting early")
break
except RepositoryNotFoundError:
# Now create a branch.
create_branch(repo_id, repo_type="dataset", branch=NEW_CODEBASE_VERSION, revision=CODEBASE_VERSION)
print(f"{repo_id} successfully updated")
```
On the other hand, if you are pushing a new dataset, you don't need to worry about any of the instructions
above, nor to be compatible with previous codebase versions.
"""
import argparse
@@ -317,7 +371,10 @@ def main():
parser.add_argument(
"--tests-data-dir",
type=Path,
help="When provided, save tests artifacts into the given directory for (e.g. `--tests-data-dir tests/data/lerobot/pusht`).",
help=(
"When provided, save tests artifacts into the given directory "
"(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
),
)
args = parser.parse_args()