diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py
index 06415764..66edde6b 100644
--- a/lerobot/common/envs/configs.py
+++ b/lerobot/common/envs/configs.py
@@ -224,6 +224,9 @@ class HILSerlRobotEnvConfig(EnvConfig):
     push_to_hub: bool = True
     pretrained_policy_name_or_path: Optional[str] = None
     reward_classifier_pretrained_path: Optional[str] = None
+    number_of_steps_after_success: int = (
+        0  # For the reward classifier, to record more positive examples after a success
+    )
 
     def gym_kwargs(self) -> dict:
         return {}
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 7232b6f3..2b221d71 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -1923,7 +1923,7 @@ def init_reward_classifier(cfg):
 ###########################################################
 
 
-def record_dataset(env, policy, cfg, success_collection_steps=0):
+def record_dataset(env, policy, cfg):
     """
     Record a dataset of robot interactions using either a policy or teleop.
 
@@ -1940,7 +1940,7 @@ def record_dataset(env, policy, cfg, success_collection_steps=0):
     - fps: Frames per second for recording
     - push_to_hub: Whether to push dataset to Hugging Face Hub
     - task: Name/description of the task being recorded
-    success_collection_steps: Number of additional steps to continue recording after
+    - number_of_steps_after_success: Number of additional steps to continue recording after
         a success (reward=1) is detected. This helps collect more positive examples
         for reward classifier training.
     """
@@ -2047,7 +2047,7 @@ def record_dataset(env, policy, cfg, success_collection_steps=0):
         really_done = terminated or truncated
         if success_detected:
             success_steps_collected += 1
-            really_done = success_steps_collected >= success_collection_steps
+            really_done = success_steps_collected >= cfg.number_of_steps_after_success
 
         frame["next.done"] = np.array([really_done], dtype=bool)
         frame["task"] = cfg.task
@@ -2065,7 +2065,7 @@ def record_dataset(env, policy, cfg, success_collection_steps=0):
         if (terminated or truncated) and not success_detected:
             # Regular termination without success
             break
-        elif success_detected and success_steps_collected >= success_collection_steps:
+        elif success_detected and success_steps_collected >= cfg.number_of_steps_after_success:
             # We've collected enough success states
             logging.info(f"Collected {success_steps_collected} additional success states")
             break
@@ -2139,12 +2139,10 @@ def main(cfg: EnvConfig):
         policy.to(cfg.device)
         policy.eval()
 
-        # Get success_collection_steps from config or default to 15
         record_dataset(
             env,
             policy=policy,
             cfg=cfg,
-            success_collection_steps=0,
         )
         exit()
 
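For context, a minimal sketch of the stop condition this diff moves into the config: once a success (reward=1) is observed, recording continues for cfg.number_of_steps_after_success extra frames so the reward classifier gets more positive examples. RecordingConfig and should_stop below are hypothetical stand-ins, not the repository's API; the real checks live inline in record_dataset and read the new HILSerlRobotEnvConfig field.

    # Sketch of the post-success recording behavior; names here are
    # illustrative stand-ins for HILSerlRobotEnvConfig / record_dataset.
    from dataclasses import dataclass


    @dataclass
    class RecordingConfig:
        # Mirrors the new config field: how many extra frames to record
        # after a success is detected.
        number_of_steps_after_success: int = 0


    def should_stop(
        terminated: bool,
        truncated: bool,
        success_detected: bool,
        success_steps_collected: int,
        cfg: RecordingConfig,
    ) -> bool:
        # Regular termination without success: stop immediately.
        if (terminated or truncated) and not success_detected:
            return True
        # After a success, keep recording until enough extra frames
        # have been collected.
        if success_detected and success_steps_collected >= cfg.number_of_steps_after_success:
            return True
        return False

With the default of 0, the episode still ends on the step where the success is detected, so the previous behavior is unchanged unless the field is set.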