Refactor env to add key word arguments from config yaml (#223)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
@@ -69,12 +69,10 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
|
|||||||
# "nearest" is the best option over "backward", since the latter can desynchronize camera timestamps by
|
# "nearest" is the best option over "backward", since the latter can desynchronize camera timestamps by
|
||||||
# matching timestamps that are too far apart, in order to fit the backward constraints. It's not the case for "nearest".
|
# matching timestamps that are too far apart, in order to fit the backward constraints. It's not the case for "nearest".
|
||||||
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
|
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
|
||||||
# This is not a problem when the tolerance is set to be low enough to avoid matching timestamps that
|
|
||||||
# are too far apart.
|
# are too far apart.
|
||||||
direction="nearest",
|
direction="nearest",
|
||||||
tolerance=pd.Timedelta(f"{1/fps} seconds"),
|
tolerance=pd.Timedelta(f"{1/fps} seconds"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
|
# Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
|
||||||
df = df[df["episode_index"] != -1]
|
df = df[df["episode_index"] != -1]
|
||||||
|
|
||||||
@@ -89,9 +87,10 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
|
|||||||
raise ValueError(path)
|
raise ValueError(path)
|
||||||
episode_index = int(match.group(1))
|
episode_index = int(match.group(1))
|
||||||
episode_index_per_cam[key] = episode_index
|
episode_index_per_cam[key] = episode_index
|
||||||
assert (
|
if len(set(episode_index_per_cam.values())) != 1:
|
||||||
len(set(episode_index_per_cam.values())) == 1
|
raise ValueError(
|
||||||
), f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
|
f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
|
||||||
|
)
|
||||||
return episode_index
|
return episode_index
|
||||||
|
|
||||||
df["episode_index"] = df.apply(get_episode_index, axis=1)
|
df["episode_index"] = df.apply(get_episode_index, axis=1)
|
||||||
@@ -119,7 +118,8 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
|
|||||||
# sanity check episode indices go from 0 to n-1
|
# sanity check episode indices go from 0 to n-1
|
||||||
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
|
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
|
||||||
expected_ep_ids = list(range(df["episode_index"].max() + 1))
|
expected_ep_ids = list(range(df["episode_index"].max() + 1))
|
||||||
assert ep_ids == expected_ep_ids, f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}"
|
if ep_ids != expected_ep_ids:
|
||||||
|
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
|
||||||
|
|
||||||
# Create symlink to raw videos directory (that needs to be absolute not relative)
|
# Create symlink to raw videos directory (that needs to be absolute not relative)
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -132,7 +132,8 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
|
|||||||
continue
|
continue
|
||||||
for ep_idx in ep_ids:
|
for ep_idx in ep_ids:
|
||||||
video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
|
video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
|
||||||
assert video_path.exists(), f"Video file not found in {video_path}"
|
if not video_path.exists():
|
||||||
|
raise ValueError(f"Video file not found in {video_path}")
|
||||||
|
|
||||||
data_dict = {}
|
data_dict = {}
|
||||||
for key in df:
|
for key in df:
|
||||||
@@ -144,7 +145,8 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
|
|||||||
|
|
||||||
# sanity check the video path is well formatted
|
# sanity check the video path is well formatted
|
||||||
video_path = videos_dir.parent / data_dict[key][0]["path"]
|
video_path = videos_dir.parent / data_dict[key][0]["path"]
|
||||||
assert video_path.exists(), f"Video file not found in {video_path}"
|
if not video_path.exists():
|
||||||
|
raise ValueError(f"Video file not found in {video_path}")
|
||||||
# is number
|
# is number
|
||||||
elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
|
elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
|
||||||
data_dict[key] = torch.from_numpy(df[key].values)
|
data_dict[key] = torch.from_numpy(df[key].values)
|
||||||
|
|||||||
@@ -27,14 +27,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
|||||||
if n_envs is not None and n_envs < 1:
|
if n_envs is not None and n_envs < 1:
|
||||||
raise ValueError("`n_envs must be at least 1")
|
raise ValueError("`n_envs must be at least 1")
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
"obs_type": "pixels_agent_pos",
|
|
||||||
"render_mode": "rgb_array",
|
|
||||||
"max_episode_steps": cfg.env.episode_length,
|
|
||||||
"visualization_width": 384,
|
|
||||||
"visualization_height": 384,
|
|
||||||
}
|
|
||||||
|
|
||||||
package_name = f"gym_{cfg.env.name}"
|
package_name = f"gym_{cfg.env.name}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -46,12 +38,16 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
gym_handle = f"{package_name}/{cfg.env.task}"
|
gym_handle = f"{package_name}/{cfg.env.task}"
|
||||||
|
gym_kwgs = dict(cfg.env.get("gym", {}))
|
||||||
|
|
||||||
|
if cfg.env.get("episode_length"):
|
||||||
|
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
|
||||||
|
|
||||||
# batched version of the env that returns an observation of shape (b, c)
|
# batched version of the env that returns an observation of shape (b, c)
|
||||||
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
||||||
env = env_cls(
|
env = env_cls(
|
||||||
[
|
[
|
||||||
lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
|
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
|
||||||
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
|
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -37,6 +37,8 @@ training:
|
|||||||
save_freq: ???
|
save_freq: ???
|
||||||
log_freq: 250
|
log_freq: 250
|
||||||
save_checkpoint: true
|
save_checkpoint: true
|
||||||
|
num_workers: 4
|
||||||
|
batch_size: ???
|
||||||
|
|
||||||
eval:
|
eval:
|
||||||
n_episodes: 1
|
n_episodes: 1
|
||||||
|
|||||||
10
lerobot/configs/env/aloha.yaml
vendored
10
lerobot/configs/env/aloha.yaml
vendored
@@ -5,10 +5,10 @@ fps: 50
|
|||||||
env:
|
env:
|
||||||
name: aloha
|
name: aloha
|
||||||
task: AlohaInsertion-v0
|
task: AlohaInsertion-v0
|
||||||
from_pixels: True
|
|
||||||
pixels_only: False
|
|
||||||
image_size: [3, 480, 640]
|
|
||||||
episode_length: 400
|
|
||||||
fps: ${fps}
|
|
||||||
state_dim: 14
|
state_dim: 14
|
||||||
action_dim: 14
|
action_dim: 14
|
||||||
|
fps: ${fps}
|
||||||
|
episode_length: 400
|
||||||
|
gym:
|
||||||
|
obs_type: pixels_agent_pos
|
||||||
|
render_mode: rgb_array
|
||||||
|
|||||||
11
lerobot/configs/env/pusht.yaml
vendored
11
lerobot/configs/env/pusht.yaml
vendored
@@ -5,10 +5,13 @@ fps: 10
|
|||||||
env:
|
env:
|
||||||
name: pusht
|
name: pusht
|
||||||
task: PushT-v0
|
task: PushT-v0
|
||||||
from_pixels: True
|
|
||||||
pixels_only: False
|
|
||||||
image_size: 96
|
image_size: 96
|
||||||
episode_length: 300
|
|
||||||
fps: ${fps}
|
|
||||||
state_dim: 2
|
state_dim: 2
|
||||||
action_dim: 2
|
action_dim: 2
|
||||||
|
fps: ${fps}
|
||||||
|
episode_length: 300
|
||||||
|
gym:
|
||||||
|
obs_type: pixels_agent_pos
|
||||||
|
render_mode: rgb_array
|
||||||
|
visualization_width: 384
|
||||||
|
visualization_height: 384
|
||||||
|
|||||||
11
lerobot/configs/env/xarm.yaml
vendored
11
lerobot/configs/env/xarm.yaml
vendored
@@ -5,10 +5,13 @@ fps: 15
|
|||||||
env:
|
env:
|
||||||
name: xarm
|
name: xarm
|
||||||
task: XarmLift-v0
|
task: XarmLift-v0
|
||||||
from_pixels: True
|
|
||||||
pixels_only: False
|
|
||||||
image_size: 84
|
image_size: 84
|
||||||
episode_length: 25
|
|
||||||
fps: ${fps}
|
|
||||||
state_dim: 4
|
state_dim: 4
|
||||||
action_dim: 4
|
action_dim: 4
|
||||||
|
fps: ${fps}
|
||||||
|
episode_length: 25
|
||||||
|
gym:
|
||||||
|
obs_type: pixels_agent_pos
|
||||||
|
render_mode: rgb_array
|
||||||
|
visualization_width: 384
|
||||||
|
visualization_height: 384
|
||||||
|
|||||||
@@ -281,8 +281,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||||||
logging.info("make_dataset")
|
logging.info("make_dataset")
|
||||||
offline_dataset = make_dataset(cfg)
|
offline_dataset = make_dataset(cfg)
|
||||||
|
|
||||||
logging.info("make_env")
|
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||||
eval_env = make_env(cfg)
|
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||||
|
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||||
|
if cfg.training.eval_freq > 0:
|
||||||
|
logging.info("make_env")
|
||||||
|
eval_env = make_env(cfg)
|
||||||
|
|
||||||
logging.info("make_policy")
|
logging.info("make_policy")
|
||||||
policy = make_policy(
|
policy = make_policy(
|
||||||
@@ -315,7 +319,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||||||
|
|
||||||
# Note: this helper will be used in offline and online training loops.
|
# Note: this helper will be used in offline and online training loops.
|
||||||
def evaluate_and_checkpoint_if_needed(step):
|
def evaluate_and_checkpoint_if_needed(step):
|
||||||
if step % cfg.training.eval_freq == 0:
|
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info(f"Eval policy at step {step}")
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||||
eval_info = eval_policy(
|
eval_info = eval_policy(
|
||||||
@@ -349,7 +353,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||||||
# create dataloader for offline training
|
# create dataloader for offline training
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
offline_dataset,
|
offline_dataset,
|
||||||
num_workers=4,
|
num_workers=cfg.training.num_workers,
|
||||||
batch_size=cfg.training.batch_size,
|
batch_size=cfg.training.batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
pin_memory=device.type != "cpu",
|
pin_memory=device.type != "cpu",
|
||||||
@@ -386,6 +390,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||||||
|
|
||||||
step += 1
|
step += 1
|
||||||
|
|
||||||
|
logging.info("End of offline training")
|
||||||
|
|
||||||
|
if cfg.training.online_steps == 0:
|
||||||
|
if cfg.training.eval_freq > 0:
|
||||||
|
eval_env.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
# create an env dedicated to online episodes collection from policy rollout
|
||||||
|
online_training_env = make_env(cfg, n_envs=1)
|
||||||
|
|
||||||
# create an empty online dataset similar to offline dataset
|
# create an empty online dataset similar to offline dataset
|
||||||
online_dataset = deepcopy(offline_dataset)
|
online_dataset = deepcopy(offline_dataset)
|
||||||
online_dataset.hf_dataset = {}
|
online_dataset.hf_dataset = {}
|
||||||
@@ -406,8 +420,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||||||
drop_last=False,
|
drop_last=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
eval_env.close()
|
logging.info("End of online training")
|
||||||
logging.info("End of training")
|
|
||||||
|
if cfg.training.eval_freq > 0:
|
||||||
|
eval_env.close()
|
||||||
|
online_training_env.close()
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||||
|
|||||||
Reference in New Issue
Block a user