forked from tangger/lerobot
let's go
This commit is contained in:
@@ -34,18 +34,20 @@ python examples/port_datasets/openx_rlds.py \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
import tqdm
|
||||
|
||||
from examples.port_datasets.openx_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
|
||||
from examples.port_datasets.openx_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
|
||||
|
||||
np.set_printoptions(precision=2)
|
||||
|
||||
@@ -138,13 +140,14 @@ def save_as_lerobot_dataset(
|
||||
num_shards: int | None = None,
|
||||
shard_index: int | None = None,
|
||||
):
|
||||
start_time = time.time()
|
||||
total_num_episodes = raw_dataset.cardinality().numpy().item()
|
||||
print(f"Total number of episodes {total_num_episodes}")
|
||||
logging.info(f"Total number of episodes {total_num_episodes}")
|
||||
|
||||
if num_shards is not None:
|
||||
sharded_dataset = raw_dataset.shard(num_shards=num_shards, index=shard_index)
|
||||
sharded_num_episodes = sharded_dataset.cardinality().numpy().item()
|
||||
print(f"{sharded_num_episodes=}")
|
||||
logging.info(f"{sharded_num_episodes=}")
|
||||
num_episodes = sharded_num_episodes
|
||||
iter_ = iter(sharded_dataset)
|
||||
else:
|
||||
@@ -155,13 +158,18 @@ def save_as_lerobot_dataset(
|
||||
raise ValueError(f"Number of episodes is {num_episodes}, but needs to be positive.")
|
||||
|
||||
for episode_index in range(num_episodes):
|
||||
print(f"{episode_index} / {num_episodes}")
|
||||
logging.info(f"{episode_index} / {num_episodes} episodes processed")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
|
||||
logging.info(f"It has been {d} days, {h} hours, {m} minutes, {s:.3f} seconds")
|
||||
|
||||
episode = next(iter_)
|
||||
print("\nnext\n")
|
||||
logging.info("next")
|
||||
episode = transform_raw_dataset(episode, dataset_name)
|
||||
|
||||
traj = episode["steps"]
|
||||
for i in tqdm.tqdm(range(traj["action"].shape[0])):
|
||||
for i in range(traj["action"].shape[0]):
|
||||
image_dict = {
|
||||
f"observation.images.{key}": value[i].numpy()
|
||||
for key, value in traj["observation"].items()
|
||||
@@ -176,9 +184,8 @@ def save_as_lerobot_dataset(
|
||||
}
|
||||
)
|
||||
|
||||
print()
|
||||
lerobot_dataset.save_episode()
|
||||
print("\nsave_episode\n")
|
||||
logging.info("save_episode")
|
||||
|
||||
|
||||
def create_lerobot_dataset(
|
||||
|
||||
Reference in New Issue
Block a user