Fix online training (#94)

This commit is contained in:
Remi
2024-04-23 18:54:55 +02:00
committed by GitHub
parent 1030ea0070
commit c1bcf857c5
3 changed files with 46 additions and 13 deletions

View File

@@ -41,7 +41,7 @@ import gymnasium as gym
import imageio
import numpy as np
import torch
from datasets import Dataset
from datasets import Dataset, Features, Image, Sequence, Value
from huggingface_hub import snapshot_download
from PIL import Image as PILImage
from tqdm import trange
@@ -270,8 +270,34 @@ def eval_policy(
data_dict[key].append(img)
data_dict["index"] = torch.arange(0, total_frames, 1)
episode_data_index["from"] = torch.tensor(episode_data_index["from"])
episode_data_index["to"] = torch.tensor(episode_data_index["to"])
hf_dataset = Dataset.from_dict(data_dict)
# TODO(rcadene): clean this
features = {}
for key in observations:
if "image" in key:
features[key] = Image()
else:
features[key] = Sequence(
length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
)
features.update(
{
"action": Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
}
)
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset.set_transform(hf_transform_to_torch)
if max_episodes_rendered > 0: