forked from tangger/lerobot
Fix online training (#94)
This commit is contained in:
@@ -41,7 +41,7 @@ import gymnasium as gym
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image as PILImage
|
||||
from tqdm import trange
|
||||
@@ -270,8 +270,34 @@ def eval_policy(
|
||||
data_dict[key].append(img)
|
||||
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
episode_data_index["from"] = torch.tensor(episode_data_index["from"])
|
||||
episode_data_index["to"] = torch.tensor(episode_data_index["to"])
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict)
|
||||
# TODO(rcadene): clean this
|
||||
features = {}
|
||||
for key in observations:
|
||||
if "image" in key:
|
||||
features[key] = Image()
|
||||
else:
|
||||
features[key] = Sequence(
|
||||
length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features.update(
|
||||
{
|
||||
"action": Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
}
|
||||
)
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
|
||||
Reference in New Issue
Block a user