Added number of steps after success as parameter in config

This commit is contained in:
Michel Aractingi
2025-05-09 18:09:10 +02:00
parent fb9bb89cb4
commit b104f8b012
2 changed files with 7 additions and 6 deletions

View File

@@ -224,6 +224,9 @@ class HILSerlRobotEnvConfig(EnvConfig):
push_to_hub: bool = True
pretrained_policy_name_or_path: Optional[str] = None
reward_classifier_pretrained_path: Optional[str] = None
number_of_steps_after_success: int = (
0 # For the reward classifier, to record more positive examples after a success
)
def gym_kwargs(self) -> dict:
return {}

View File

@@ -1923,7 +1923,7 @@ def init_reward_classifier(cfg):
###########################################################
def record_dataset(env, policy, cfg, success_collection_steps=0):
def record_dataset(env, policy, cfg):
"""
Record a dataset of robot interactions using either a policy or teleop.
@@ -1940,7 +1940,7 @@ def record_dataset(env, policy, cfg, success_collection_steps=0):
- fps: Frames per second for recording
- push_to_hub: Whether to push dataset to Hugging Face Hub
- task: Name/description of the task being recorded
success_collection_steps: Number of additional steps to continue recording after
- number_of_steps_after_success: Number of additional steps to continue recording after
a success (reward=1) is detected. This helps collect
more positive examples for reward classifier training.
"""
@@ -2047,7 +2047,7 @@ def record_dataset(env, policy, cfg, success_collection_steps=0):
really_done = terminated or truncated
if success_detected:
success_steps_collected += 1
really_done = success_steps_collected >= success_collection_steps
really_done = success_steps_collected >= cfg.number_of_steps_after_success
frame["next.done"] = np.array([really_done], dtype=bool)
frame["task"] = cfg.task
@@ -2065,7 +2065,7 @@ def record_dataset(env, policy, cfg, success_collection_steps=0):
if (terminated or truncated) and not success_detected:
# Regular termination without success
break
elif success_detected and success_steps_collected >= success_collection_steps:
elif success_detected and success_steps_collected >= cfg.number_of_steps_after_success:
# We've collected enough success states
logging.info(f"Collected {success_steps_collected} additional success states")
break
@@ -2139,12 +2139,10 @@ def main(cfg: EnvConfig):
policy.to(cfg.device)
policy.eval()
# Get success_collection_steps from config or default to 15
record_dataset(
env,
policy=policy,
cfg=cfg,
success_collection_steps=0,
)
exit()