forked from tangger/lerobot
id -> index, finish moving compute_stats before hf_dataset push_to_hub
This commit is contained in:
@@ -19,7 +19,7 @@ from huggingface_hub import HfApi
|
||||
from PIL import Image as PILImage
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.utils import compute_stats, flatten_dict
|
||||
from lerobot.common.datasets.utils import compute_stats, convert_images_to_channel_first_tensors, flatten_dict
|
||||
|
||||
|
||||
def download_and_upload(root, revision, dataset_id):
|
||||
@@ -75,28 +75,18 @@ def concatenate_episodes(ep_dicts):
|
||||
for x in ep_dict[key]:
|
||||
data_dict[key].append(x)
|
||||
|
||||
total_frames = data_dict["frame_id"].shape[0]
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
|
||||
|
||||
def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id):
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
|
||||
def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
|
||||
# push to main to indicate latest version
|
||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||
|
||||
# push to version branch
|
||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision)
|
||||
|
||||
# get stats
|
||||
stats_pth_path = root / dataset_id / "stats.pth"
|
||||
if stats_pth_path.exists():
|
||||
stats = torch.load(stats_pth_path)
|
||||
else:
|
||||
stats = compute_stats(hf_dataset)
|
||||
torch.save(stats, stats_pth_path)
|
||||
|
||||
# create and store meta_data
|
||||
meta_data_dir = root / dataset_id / "meta_data"
|
||||
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -237,8 +227,8 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
|
||||
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
||||
"observation.state": agent_pos,
|
||||
"action": actions[id_from:id_to],
|
||||
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
# "next.observation.image": image[1:],
|
||||
# "next.observation.state": agent_pos[1:],
|
||||
@@ -262,8 +252,8 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_id": Value(dtype="int64", id=None),
|
||||
"frame_id": Value(dtype="int64", id=None),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
@@ -272,11 +262,14 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
|
||||
}
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
|
||||
|
||||
info = {
|
||||
"fps": fps,
|
||||
}
|
||||
push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id)
|
||||
stats = compute_stats(hf_dataset)
|
||||
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
||||
|
||||
|
||||
def download_and_upload_xarm(root, revision, dataset_id, fps=15):
|
||||
@@ -334,8 +327,8 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
|
||||
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
# "next.observation.image": next_image,
|
||||
# "next.observation.state": next_state,
|
||||
@@ -358,8 +351,8 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_id": Value(dtype="int64", id=None),
|
||||
"frame_id": Value(dtype="int64", id=None),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
@@ -368,11 +361,14 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
|
||||
}
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
|
||||
|
||||
info = {
|
||||
"fps": fps,
|
||||
}
|
||||
push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id)
|
||||
stats = compute_stats(hf_dataset)
|
||||
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
||||
|
||||
|
||||
def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
||||
@@ -464,8 +460,8 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
||||
{
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode_id": torch.tensor([ep_id] * num_frames),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"episode_index": torch.tensor([ep_id] * num_frames),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
# "next.observation.state": state,
|
||||
# TODO(rcadene): compute reward and success
|
||||
@@ -493,8 +489,8 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_id": Value(dtype="int64", id=None),
|
||||
"frame_id": Value(dtype="int64", id=None),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
#'next.reward': Value(dtype='float32', id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
@@ -503,11 +499,14 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
||||
}
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
|
||||
|
||||
info = {
|
||||
"fps": fps,
|
||||
}
|
||||
push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id)
|
||||
stats = compute_stats(hf_dataset)
|
||||
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user