added names in record_dataset function of gym_manipulator
This commit is contained in:
committed by
AdilZouitine
parent
35743b72de
commit
f762e2758f
@@ -462,7 +462,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
|||||||
|
|
||||||
names = ft["names"]
|
names = ft["names"]
|
||||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||||
if names and names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||||
shape = (shape[2], shape[0], shape[1])
|
shape = (shape[2], shape[0], shape[1])
|
||||||
elif key == "observation.environment_state":
|
elif key == "observation.environment_state":
|
||||||
type = FeatureType.ENV
|
type = FeatureType.ENV
|
||||||
|
|||||||
@@ -1949,6 +1949,10 @@ def record_dataset(env, policy, cfg, success_collection_steps=15):
|
|||||||
# Setup initial action (zero action if using teleop)
|
# Setup initial action (zero action if using teleop)
|
||||||
action = env.action_space.sample() * 0.0
|
action = env.action_space.sample() * 0.0
|
||||||
|
|
||||||
|
action_names = ["delta_x_ee", "delta_y_ee", "delta_z_ee"]
|
||||||
|
if cfg.wrapper.use_gripper:
|
||||||
|
action_names.append("gripper_delta")
|
||||||
|
|
||||||
# Configure dataset features based on environment spaces
|
# Configure dataset features based on environment spaces
|
||||||
features = {
|
features = {
|
||||||
"observation.state": {
|
"observation.state": {
|
||||||
@@ -1958,8 +1962,8 @@ def record_dataset(env, policy, cfg, success_collection_steps=15):
|
|||||||
},
|
},
|
||||||
"action": {
|
"action": {
|
||||||
"dtype": "float32",
|
"dtype": "float32",
|
||||||
"shape": env.action_space.shape,
|
"shape": (len(action_names),),
|
||||||
"names": None,
|
"names": action_names,
|
||||||
},
|
},
|
||||||
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||||
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||||
@@ -1976,7 +1980,7 @@ def record_dataset(env, policy, cfg, success_collection_steps=15):
|
|||||||
features[key] = {
|
features[key] = {
|
||||||
"dtype": "video",
|
"dtype": "video",
|
||||||
"shape": env.observation_space[key].shape,
|
"shape": env.observation_space[key].shape,
|
||||||
"names": None,
|
"names": ["channels", "height", "width"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create dataset
|
# Create dataset
|
||||||
|
|||||||
Reference in New Issue
Block a user