@@ -91,9 +91,9 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
# TODO (aliberts): add 'episodes' arg from config after removing hydra
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.video_backend,
|
||||
@@ -101,7 +101,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
else:
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.video_backend,
|
||||
@@ -112,6 +111,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
for stats_type, listconfig in stats_dict.items():
|
||||
# example of stats_type: min, max, mean, std
|
||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
||||
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
||||
Reference in New Issue
Block a user