optimize shard
This commit is contained in:
@@ -205,7 +205,16 @@ def create_lerobot_dataset(
|
||||
|
||||
builder = tfds.builder(dataset_name, data_dir=data_dir, version=version)
|
||||
features = generate_features_from_raw(dataset_name, builder, use_videos)
|
||||
raw_dataset = builder.as_dataset(split="train")
|
||||
|
||||
if num_shards is not None:
|
||||
if num_shards != builder.info.splits["train"].num_shards:
|
||||
raise ValueError()
|
||||
if shard_index >= builder.info.splits["train"].num_shards:
|
||||
raise ValueError()
|
||||
|
||||
raw_dataset = builder.as_dataset(split=f"train[{shard_index}shard]")
|
||||
else:
|
||||
raw_dataset = builder.as_dataset(split="train")
|
||||
|
||||
if fps is None:
|
||||
if dataset_name in OXE_DATASET_CONFIGS:
|
||||
@@ -234,8 +243,6 @@ def create_lerobot_dataset(
|
||||
dataset_name,
|
||||
lerobot_dataset,
|
||||
raw_dataset,
|
||||
num_shards=num_shards,
|
||||
shard_index=shard_index,
|
||||
)
|
||||
|
||||
if push_to_hub:
|
||||
|
||||
Reference in New Issue
Block a user