Adds split_by_episodes to LeRobotDataset (#158)

This commit is contained in:
Radek Osmulski
2024-05-20 22:04:04 +10:00
committed by GitHub
parent 01eae09ba6
commit 9b62c25f6c
5 changed files with 242 additions and 21 deletions

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(aliberts): Mute logging for these tests
import io
import subprocess
import sys
from pathlib import Path
@@ -32,6 +33,11 @@ def _run_script(path):
subprocess.run([sys.executable, path], check=True)
def _read_file(path):
    """Return the entire text content of the file at `path`.

    The file is decoded as UTF-8 explicitly instead of using the platform's
    locale encoding, so reading Python source files (UTF-8 by default) gives
    the same result on every platform.
    """
    with open(path, encoding="utf-8") as file:
        return file.read()
def test_example_1():
path = "examples/1_load_lerobot_dataset.py"
_run_script(path)
@@ -39,18 +45,17 @@ def test_example_1():
@require_package("gym_pusht")
def test_examples_3_and_2():
def test_examples_2_through_4():
"""
Train a model with example 3, check the outputs.
Evaluate the trained model with example 2, check the outputs.
Calculate the validation loss with example 4, check the outputs.
"""
path = "examples/3_train_policy.py"
### Test example 3
file_contents = _read_file("examples/3_train_policy.py")
with open(path) as file:
file_contents = file.read()
# Do less steps, use smaller batch, use CPU, and don't complicate things with dataloader workers.
# Do fewer steps, use smaller batch, use CPU, and don't complicate things with dataloader workers.
file_contents = _find_and_replace(
file_contents,
[
@@ -67,16 +72,17 @@ def test_examples_3_and_2():
for file_name in ["model.safetensors", "config.json"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
path = "examples/2_evaluate_pretrained_policy.py"
### Test example 2
file_contents = _read_file("examples/2_evaluate_pretrained_policy.py")
with open(path) as file:
file_contents = file.read()
# Do less evals, use CPU, and use the local model.
# Do fewer evals, use CPU, and use the local model.
file_contents = _find_and_replace(
file_contents,
[
('pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', ""),
(
'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))',
"",
),
(
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
@@ -89,3 +95,34 @@ def test_examples_3_and_2():
exec(file_contents, {})
assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists()
### Test example 4
file_contents = _read_file("examples/4_calculate_validation_loss.py")
# Run on a single example from the last episode, use CPU, and use the local model.
file_contents = _find_and_replace(
file_contents,
[
(
'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))',
"",
),
(
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
),
('split="train[24342:]"', 'split="train[-1:]"'),
("num_workers=4", "num_workers=0"),
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
("batch_size=64", "batch_size=1"),
],
)
# Capture everything the script prints. The previous stdout is restored in a
# `finally` so a failure inside the script cannot leave later code writing into
# the buffer, and we restore the stream that was active before the swap (rather
# than sys.__stdout__) so any redirection already installed by the caller or the
# test runner stays in place.
output_buffer = io.StringIO()
previous_stdout = sys.stdout
sys.stdout = output_buffer
try:
    exec(file_contents, {})
finally:
    sys.stdout = previous_stdout
printed_output = output_buffer.getvalue()
assert "Average loss on validation set" in printed_output