Improve dataset examples (#82)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -19,7 +19,7 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
dataset_id: str = "xarm_lift_medium",
|
||||
version: str | None = "v1.0",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
@@ -34,31 +34,31 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
if self.root is not None:
|
||||
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.data_dict = load_dataset(
|
||||
self.hf_dataset = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
self.hf_dataset = self.hf_dataset.with_format("torch")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict)
|
||||
return len(self.hf_dataset)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_dict.unique("episode_id"))
|
||||
return len(self.hf_dataset.unique("episode_id"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data_dict[idx]
|
||||
item = self.hf_dataset[idx]
|
||||
|
||||
if self.delta_timestamps is not None:
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.data_dict,
|
||||
self.hf_dataset,
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user