fix train.py, stats, eval.py (training is running)
This commit is contained in:
@@ -18,28 +18,26 @@ Example:
|
||||
import argparse
|
||||
import shutil
|
||||
|
||||
from tensordict import TensorDict
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def mock_dataset(in_data_dir, out_data_dir, num_frames):
|
||||
in_data_dir = Path(in_data_dir)
|
||||
out_data_dir = Path(out_data_dir)
|
||||
|
||||
# load full dataset as a tensor dict
|
||||
in_td_data = TensorDict.load_memmap(in_data_dir / "replay_buffer")
|
||||
# copy the first `n` frames for each data key so that we have real data
|
||||
in_data_dict = torch.load(in_data_dir / "data_dict.pth")
|
||||
out_data_dict = {key: in_data_dict[key][:num_frames].clone() for key in in_data_dict}
|
||||
torch.save(out_data_dict, out_data_dir / "data_dict.pth")
|
||||
|
||||
# use 1 frame to know the specification of the dataset
|
||||
# and copy it over `n` frames in the test artifact directory
|
||||
out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir / "replay_buffer")
|
||||
# copy the full mapping between data_id and episode since it's small
|
||||
in_ids_per_ep_path = in_data_dir / "data_ids_per_episode.pth"
|
||||
out_ids_per_ep_path = out_data_dir / "data_ids_per_episode.pth"
|
||||
shutil.copy(in_ids_per_ep_path, out_ids_per_ep_path)
|
||||
|
||||
# copy the first `n` frames so that we have real data
|
||||
out_td_data[:num_frames] = in_td_data[:num_frames].clone()
|
||||
|
||||
# make sure everything has been properly written
|
||||
out_td_data.lock_()
|
||||
|
||||
# copy the full statistics of dataset since it's pretty small
|
||||
# copy the full statistics of dataset since it's small
|
||||
in_stats_path = in_data_dir / "stats.pth"
|
||||
out_stats_path = out_data_dir / "stats.pth"
|
||||
shutil.copy(in_stats_path, out_stats_path)
|
||||
|
||||
@@ -59,11 +59,7 @@ def test_factory(env_name, dataset_id):
|
||||
# )
|
||||
# dataset = make_dataset(cfg)
|
||||
# # Get all of the data.
|
||||
# all_data = TensorDictReplayBuffer(
|
||||
# storage=buffer._storage,
|
||||
# batch_size=len(buffer),
|
||||
# sampler=SamplerWithoutReplacement(),
|
||||
# ).sample().float()
|
||||
# all_data = dataset.data_dict
|
||||
# # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
|
||||
# # computation of the statistics. While doing this, we also make sure it works when we don't divide the
|
||||
# # dataset into even batches.
|
||||
|
||||
Reference in New Issue
Block a user