diff --git a/tests/data/pusht/next/done.memmap b/tests/data/pusht/next/done.memmap
index 3c77d33bb..44fd709f9 100644
--- a/tests/data/pusht/next/done.memmap
+++ b/tests/data/pusht/next/done.memmap
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d1a8f20ab8c1dead0f61b7b38de300b0ebd0df1d870babfbbe03ce9d2b81e36a
+oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
 size 50
diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth
index 329cb35f7..334a01bef 100644
Binary files a/tests/data/pusht/stats.pth and b/tests/data/pusht/stats.pth differ
diff --git a/tests/scripts/create_dataset.py b/tests/scripts/create_dataset.py
new file mode 100644
index 000000000..c58280d79
--- /dev/null
+++ b/tests/scripts/create_dataset.py
@@ -0,0 +1,41 @@
+"""
+    usage: `python tests/scripts/create_dataset.py --in-data-dir data/pusht --out-data-dir tests/data/pusht`
+"""
+
+import argparse
+import shutil
+
+from tensordict import TensorDict
+from pathlib import Path
+
+
+def mock_dataset(in_data_dir, out_data_dir, num_frames=50):
+    # load full dataset as a tensor dict
+    in_td_data = TensorDict.load_memmap(in_data_dir)
+
+    # use 1 frame to know the specification of the dataset
+    # and copy it over `n` frames in the test artifact directory
+    out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir)
+
+    # copy the first `n` frames so that we have real data
+    out_td_data[:num_frames] = in_td_data[:num_frames].clone()
+
+    # make sure everything has been properly written
+    out_td_data.lock_()
+
+    # copy the full statistics of dataset since it's pretty small
+    in_stats_path = Path(in_data_dir) / "stats.pth"
+    out_stats_path = Path(out_data_dir) / "stats.pth"
+    shutil.copy(in_stats_path, out_stats_path)
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description="Create dataset")
+
+    parser.add_argument("--in-data-dir", type=str, help="Path to input data")
+    parser.add_argument("--out-data-dir", type=str, help="Path to save the output data")
+
+    args = parser.parse_args()
+
+    mock_dataset(args.in_data_dir, args.out_data_dir)
\ No newline at end of file
diff --git a/tests/scripts/mock_dataset.py b/tests/scripts/mock_dataset.py
new file mode 100644
index 000000000..c58280d79
--- /dev/null
+++ b/tests/scripts/mock_dataset.py
@@ -0,0 +1,41 @@
+"""
+    usage: `python tests/scripts/mock_dataset.py --in-data-dir data/pusht --out-data-dir tests/data/pusht`
+"""
+
+import argparse
+import shutil
+
+from tensordict import TensorDict
+from pathlib import Path
+
+
+def mock_dataset(in_data_dir, out_data_dir, num_frames=50):
+    # load full dataset as a tensor dict
+    in_td_data = TensorDict.load_memmap(in_data_dir)
+
+    # use 1 frame to know the specification of the dataset
+    # and copy it over `n` frames in the test artifact directory
+    out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir)
+
+    # copy the first `n` frames so that we have real data
+    out_td_data[:num_frames] = in_td_data[:num_frames].clone()
+
+    # make sure everything has been properly written
+    out_td_data.lock_()
+
+    # copy the full statistics of dataset since it's pretty small
+    in_stats_path = Path(in_data_dir) / "stats.pth"
+    out_stats_path = Path(out_data_dir) / "stats.pth"
+    shutil.copy(in_stats_path, out_stats_path)
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description="Create dataset")
+
+    parser.add_argument("--in-data-dir", type=str, help="Path to input data")
+    parser.add_argument("--out-data-dir", type=str, help="Path to save the output data")
+
+    args = parser.parse_args()
+
+    mock_dataset(args.in_data_dir, args.out_data_dir)
\ No newline at end of file