# Dataset loading for the simulated xarm manipulation environment.
# Status (from WIP history): train/eval reproduce the torchrl pretrained PushT results.
# TODO: re-enable NormalizeTransform.
import pickle
import zipfile
from pathlib import Path

import torch
import tqdm

from lerobot.common.datasets.utils import load_data_with_delta_timestamps

def download(raw_dir):
    """Fetch the raw xarm archive from Google Drive and unpack its .pkl files.

    NOTE(review): members are extracted relative to the current working
    directory (``ZipFile.extract`` default), not into ``raw_dir`` — only the
    downloaded archive itself lives in ``raw_dir`` and is deleted afterwards.
    Confirm this matches the caller's expectations.
    """
    # gdown is only needed for this one-off download, so import it lazily.
    import gdown

    raw_dir.mkdir(parents=True, exist_ok=True)

    url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
    archive_path = raw_dir / "data.zip"
    gdown.download(url, str(archive_path), quiet=False)

    print("Extracting...")
    with zipfile.ZipFile(str(archive_path), "r") as archive:
        for name in archive.namelist():
            # Only the pickled xarm episodes are of interest.
            if not (name.startswith("data/xarm") and name.endswith(".pkl")):
                continue
            print(name)
            archive.extract(member=name)
    archive_path.unlink()


class SimxarmDataset(torch.utils.data.Dataset):
|
|
available_datasets = [
|
|
"xarm_lift_medium",
|
|
]
|
|
fps = 15
|
|
image_keys = ["observation.image"]
|
|
|
|
def __init__(
|
|
self,
|
|
dataset_id: str,
|
|
version: str | None = "v1.1",
|
|
root: Path | None = None,
|
|
transform: callable = None,
|
|
delta_timestamps: dict[list[float]] | None = None,
|
|
):
|
|
super().__init__()
|
|
self.dataset_id = dataset_id
|
|
self.version = version
|
|
self.root = root
|
|
self.transform = transform
|
|
self.delta_timestamps = delta_timestamps
|
|
|
|
data_dir = self.root / f"{self.dataset_id}"
|
|
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
|
|
self.data_dict = torch.load(data_dir / "data_dict.pth")
|
|
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
|
|
else:
|
|
self._download_and_preproc_obsolete()
|
|
data_dir.mkdir(parents=True, exist_ok=True)
|
|
torch.save(self.data_dict, data_dir / "data_dict.pth")
|
|
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
|
|
|
|
@property
|
|
def num_samples(self) -> int:
|
|
return len(self.data_dict["index"])
|
|
|
|
@property
|
|
def num_episodes(self) -> int:
|
|
return len(self.data_ids_per_episode)
|
|
|
|
def __len__(self):
|
|
return self.num_samples
|
|
|
|
def __getitem__(self, idx):
|
|
item = {}
|
|
|
|
# get episode id and timestamp of the sampled frame
|
|
current_ts = self.data_dict["timestamp"][idx].item()
|
|
episode = self.data_dict["episode"][idx].item()
|
|
|
|
for key in self.data_dict:
|
|
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
|
data, is_pad = load_data_with_delta_timestamps(
|
|
self.data_dict,
|
|
self.data_ids_per_episode,
|
|
self.delta_timestamps,
|
|
key,
|
|
current_ts,
|
|
episode,
|
|
)
|
|
item[key] = data
|
|
item[f"{key}_is_pad"] = is_pad
|
|
else:
|
|
item[key] = self.data_dict[key][idx]
|
|
|
|
if self.transform is not None:
|
|
item = self.transform(item)
|
|
|
|
return item
|
|
|
|
def _download_and_preproc_obsolete(self):
|
|
assert self.root is not None
|
|
raw_dir = self.root / f"{self.dataset_id}_raw"
|
|
if not raw_dir.exists():
|
|
download(raw_dir)
|
|
|
|
dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
|
|
print(f"Using offline dataset '{dataset_path}'")
|
|
with open(dataset_path, "rb") as f:
|
|
dataset_dict = pickle.load(f)
|
|
|
|
total_frames = dataset_dict["actions"].shape[0]
|
|
|
|
self.data_ids_per_episode = {}
|
|
ep_dicts = []
|
|
|
|
idx0 = 0
|
|
idx1 = 0
|
|
episode_id = 0
|
|
for i in tqdm.tqdm(range(total_frames)):
|
|
idx1 += 1
|
|
|
|
if not dataset_dict["dones"][i]:
|
|
continue
|
|
|
|
num_frames = idx1 - idx0
|
|
|
|
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
|
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
|
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
|
|
# TODO(rcadene): concat the last "next_observations" to "observations"
|
|
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
|
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
|
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
|
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
|
|
|
|
ep_dict = {
|
|
"observation.image": image,
|
|
"observation.state": state,
|
|
"action": action,
|
|
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
|
"frame_id": torch.arange(0, num_frames, 1),
|
|
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
|
# "next.observation.image": next_image,
|
|
# "next.observation.state": next_state,
|
|
"next.reward": next_reward,
|
|
"next.done": next_done,
|
|
}
|
|
ep_dicts.append(ep_dict)
|
|
|
|
assert isinstance(episode_id, int)
|
|
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
|
assert len(self.data_ids_per_episode[episode_id]) == num_frames
|
|
|
|
idx0 = idx1
|
|
episode_id += 1
|
|
|
|
self.data_dict = {}
|
|
|
|
keys = ep_dicts[0].keys()
|
|
for key in keys:
|
|
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
|
|
|
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|