diff --git a/lerobot/scripts/control_sim_robot.py b/lerobot/scripts/control_sim_robot.py index 00e1694f..45a57821 100644 --- a/lerobot/scripts/control_sim_robot.py +++ b/lerobot/scripts/control_sim_robot.py @@ -242,6 +242,7 @@ def create_rl_hf_dataset(data_dict): ) features["reward"] = Value(dtype="float32", id=None) + features["seed"] = Value(dtype="int64", id=None) features["episode_index"] = Value(dtype="int64", id=None) features["frame_index"] = Value(dtype="int64", id=None) features["timestamp"] = Value(dtype="float32", id=None) @@ -360,24 +361,29 @@ def record( show_images = multiprocessing.Process(target=show_image_observations, args=(observations_queue, )) show_images.start() + state_keys_dict = env_cfg.state_keys + image_keys = env_cfg.image_keys with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor: # Start recording all episodes # start reading from leader, disable stop flag in leader process while episode_index < num_episodes: logging.info(f"Recording episode {episode_index}") say(f"Recording episode {episode_index}") - ep_dict = {'action':[], 'observation.state': [], 'reward':[]} + ep_dict = {'action':[], 'reward':[]} + for k in state_keys_dict: + ep_dict[k] = [] frame_index = 0 timestamp = 0 start_episode_t = time.perf_counter() - observation, info = env.reset() + + # save seed so we can restore the environment state when we want to replay the trajectories + seed = np.random.randint(0,1e5) + observation, info = env.reset(seed=seed) #with stop_reading_leader.get_lock(): #stop_reading_leader.Value = 0 read_leader.start() while timestamp < episode_time_s: action = command_queue.get() - image_keys = [key for key in observation if "image" in key] - state_keys = [key for key in observation if "image" not in key] for key in image_keys: str_key = key if key.startswith('observation.images.') else 'observation.images.' + key futures += [ @@ -388,16 +394,14 @@ def record( if not is_headless() and visualize_images: observations_queue.put(observation) - state_obs = [] - for key in state_keys: - state_obs.append(torch.from_numpy(observation[key])) - ep_dict['observation.state'].append(torch.hstack(state_obs) * 180.0 / np.pi) + for key, obs_key in state_keys_dict.items(): + ep_dict[key].append(torch.from_numpy(observation[obs_key])) # Advance the sim environment if len(action.shape) == 1: action = np.expand_dims(action, 0) observation, reward, _, _ , info = env.step(action) - ep_dict['action'].append(torch.from_numpy(action) * 180.0 / np.pi) + ep_dict['action'].append(torch.from_numpy(action)) ep_dict['reward'].append(torch.tensor(reward)) print(reward) @@ -444,10 +448,12 @@ def record( img_path = imgs_dir / f"frame_{i:06d}.png" ep_dict[key].append({"path": str(img_path)}) - ep_dict['observation.state'] = torch.vstack(ep_dict['observation.state']) - ep_dict['action'] = torch.vstack(ep_dict['action']) + for key in state_keys_dict: + ep_dict[key] = torch.vstack(ep_dict[key]) * 180.0 / np.pi + ep_dict['action'] = torch.vstack(ep_dict['action']) * 180.0 / np.pi ep_dict['reward'] = torch.stack(ep_dict['reward']) + ep_dict["seed"] = torch.tensor([seed] * num_frames) ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames) ep_dict["frame_index"] = torch.arange(0, num_frames, 1) ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps @@ -489,7 +495,8 @@ def record( observations_queue.close() break else: - print('Waiting for ten seconds before starting the next recording session.....') + print('Waiting for two seconds before starting the next recording session.....') + busy_wait(2) num_episodes = episode_index @@ -579,10 +586,11 @@ def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="le dataset = LeRobotDataset(repo_id, root=root) items = dataset.hf_dataset.select_columns("action") + seeds = dataset.hf_dataset.select_columns("seed")['seed'] for episode in episodes: - env.reset() from_idx = dataset.episode_data_index["from"][episode].item() to_idx = dataset.episode_data_index["to"][episode].item() + env.reset(seed=seeds[from_idx].item()) logging.info("Replaying episode") say("Replaying episode", blocking=True)