diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index a765a7e1..542daf2a 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -462,7 +462,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea names = ft["names"] # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. - if names and names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) + if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) shape = (shape[2], shape[0], shape[1]) elif key == "observation.environment_state": type = FeatureType.ENV diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 3836d995..1f3f7820 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -1949,6 +1949,10 @@ def record_dataset(env, policy, cfg, success_collection_steps=15): # Setup initial action (zero action if using teleop) action = env.action_space.sample() * 0.0 + action_names = ["delta_x_ee", "delta_y_ee", "delta_z_ee"] + if cfg.wrapper.use_gripper: + action_names.append("gripper_delta") + # Configure dataset features based on environment spaces features = { "observation.state": { @@ -1958,8 +1962,8 @@ def record_dataset(env, policy, cfg, success_collection_steps=15): }, "action": { "dtype": "float32", - "shape": env.action_space.shape, - "names": None, + "shape": (len(action_names),), + "names": action_names, }, "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, "next.done": {"dtype": "bool", "shape": (1,), "names": None}, @@ -1976,7 +1980,7 @@ def record_dataset(env, policy, cfg, success_collection_steps=15): features[key] = { "dtype": "video", "shape": env.observation_space[key].shape, - "names": None, + "names": ["channels", "height", "width"], } # Create dataset