forked from tangger/lerobot
Fix unit tests
This commit is contained in:
@@ -296,7 +296,6 @@ def add_frame(dataset, observation, action):
|
||||
|
||||
ep_dict[key].append(frame_info)
|
||||
|
||||
dataset["image_keys"] = img_keys # used for video generation
|
||||
dataset["current_frame_index"] += 1
|
||||
|
||||
|
||||
@@ -389,9 +388,6 @@ def from_dataset_to_lerobot_dataset(dataset, play_sounds):
|
||||
image_keys = [key for key in data_dict if "image" in key]
|
||||
encode_videos(dataset, image_keys, play_sounds)
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
|
||||
|
||||
@@ -200,28 +200,6 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features["next.done"] = Value(dtype="bool", id=None)
|
||||
features["index"] = Value(dtype="int64", id=None)
|
||||
|
||||
for key in data_dict:
|
||||
if isinstance(data_dict[key], list):
|
||||
print(key, len(data_dict[key]))
|
||||
elif isinstance(data_dict[key], torch.Tensor):
|
||||
print(key, data_dict[key].shape)
|
||||
else:
|
||||
print(key, data_dict[key])
|
||||
|
||||
data_dict["episode_index"] = data_dict["episode_index"].tolist()
|
||||
data_dict["frame_index"] = data_dict["frame_index"].tolist()
|
||||
data_dict["timestamp"] = data_dict["timestamp"].tolist()
|
||||
data_dict["next.done"] = data_dict["next.done"].tolist()
|
||||
data_dict["index"] = data_dict["index"].tolist()
|
||||
|
||||
for key in data_dict:
|
||||
if isinstance(data_dict[key], list):
|
||||
print(key, len(data_dict[key]))
|
||||
elif isinstance(data_dict[key], torch.Tensor):
|
||||
print(key, data_dict[key].shape)
|
||||
else:
|
||||
print(key, data_dict[key])
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
@@ -7,6 +7,7 @@ import logging
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
from functools import cache
|
||||
|
||||
import cv2
|
||||
@@ -90,6 +91,7 @@ def has_method(_object: object, method_name: str):
|
||||
|
||||
|
||||
def predict_action(observation, policy, device, use_amp):
|
||||
observation = copy(observation)
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||
@@ -297,7 +299,9 @@ def stop_recording(robot, listener, display_cameras):
|
||||
|
||||
def sanity_check_dataset_name(repo_id, policy):
|
||||
_, dataset_name = repo_id.split("/")
|
||||
if dataset_name.startswith("eval_") and policy is None:
|
||||
# either repo_id doesnt start with "eval_" and there is no policy
|
||||
# or repo_id starts with "eval_" and there is a policy
|
||||
if dataset_name.startswith("eval_") == (policy is None):
|
||||
raise ValueError(
|
||||
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
|
||||
)
|
||||
|
||||
@@ -201,11 +201,11 @@ def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | Non
|
||||
@safe_disconnect
|
||||
def record(
|
||||
robot: Robot,
|
||||
root: str,
|
||||
repo_id: str,
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: List[str] | None = None,
|
||||
fps: int | None = None,
|
||||
root="data",
|
||||
repo_id="lerobot/debug",
|
||||
warmup_time_s=2,
|
||||
episode_time_s=10,
|
||||
reset_time_s=5,
|
||||
|
||||
Reference in New Issue
Block a user