This commit is contained in:
Remi Cadene
2025-02-22 11:12:39 +00:00
parent 689c5efc72
commit 39ad2d16d4
4 changed files with 42 additions and 17 deletions

View File

@@ -34,18 +34,20 @@ python examples/port_datasets/openx_rlds.py \
"""
import argparse
import logging
import re
import shutil
import time
from pathlib import Path
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tqdm
from examples.port_datasets.openx_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
from examples.port_datasets.openx_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
np.set_printoptions(precision=2)
@@ -138,13 +140,14 @@ def save_as_lerobot_dataset(
num_shards: int | None = None,
shard_index: int | None = None,
):
start_time = time.time()
total_num_episodes = raw_dataset.cardinality().numpy().item()
print(f"Total number of episodes {total_num_episodes}")
logging.info(f"Total number of episodes {total_num_episodes}")
if num_shards is not None:
sharded_dataset = raw_dataset.shard(num_shards=num_shards, index=shard_index)
sharded_num_episodes = sharded_dataset.cardinality().numpy().item()
print(f"{sharded_num_episodes=}")
logging.info(f"{sharded_num_episodes=}")
num_episodes = sharded_num_episodes
iter_ = iter(sharded_dataset)
else:
@@ -155,13 +158,18 @@ def save_as_lerobot_dataset(
raise ValueError(f"Number of episodes is {num_episodes}, but needs to be positive.")
for episode_index in range(num_episodes):
print(f"{episode_index} / {num_episodes}")
logging.info(f"{episode_index} / {num_episodes} episodes processed")
elapsed_time = time.time() - start_time
d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
logging.info(f"It has been {d} days, {h} hours, {m} minutes, {s:.3f} seconds")
episode = next(iter_)
print("\nnext\n")
logging.info("next")
episode = transform_raw_dataset(episode, dataset_name)
traj = episode["steps"]
for i in tqdm.tqdm(range(traj["action"].shape[0])):
for i in range(traj["action"].shape[0]):
image_dict = {
f"observation.images.{key}": value[i].numpy()
for key, value in traj["observation"].items()
@@ -176,9 +184,8 @@ def save_as_lerobot_dataset(
}
)
print()
lerobot_dataset.save_episode()
print("\nsave_episode\n")
logging.info("save_episode")
def create_lerobot_dataset(