Added number of steps after success as parameter in config
This commit is contained in:
committed by
AdilZouitine
parent
db86586530
commit
aa793cbd4a
@@ -225,6 +225,9 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: Optional[str] = None
|
||||
reward_classifier_pretrained_path: Optional[str] = None
|
||||
number_of_steps_after_success: int = (
|
||||
0 # For the reward classifier, to record more positive examples after a success
|
||||
)
|
||||
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
@@ -1923,7 +1923,7 @@ def init_reward_classifier(cfg):
|
||||
###########################################################
|
||||
|
||||
|
||||
def record_dataset(env, policy, cfg, success_collection_steps=0):
|
||||
def record_dataset(env, policy, cfg):
|
||||
"""
|
||||
Record a dataset of robot interactions using either a policy or teleop.
|
||||
|
||||
@@ -1940,7 +1940,7 @@ def record_dataset(env, policy, cfg, success_collection_steps=0):
|
||||
- fps: Frames per second for recording
|
||||
- push_to_hub: Whether to push dataset to Hugging Face Hub
|
||||
- task: Name/description of the task being recorded
|
||||
success_collection_steps: Number of additional steps to continue recording after
|
||||
- number_of_steps_after_success: Number of additional steps to continue recording after
|
||||
a success (reward=1) is detected. This helps collect
|
||||
more positive examples for reward classifier training.
|
||||
"""
|
||||
@@ -2047,7 +2047,7 @@ def record_dataset(env, policy, cfg, success_collection_steps=0):
|
||||
really_done = terminated or truncated
|
||||
if success_detected:
|
||||
success_steps_collected += 1
|
||||
really_done = success_steps_collected >= success_collection_steps
|
||||
really_done = success_steps_collected >= cfg.number_of_steps_after_success
|
||||
|
||||
frame["next.done"] = np.array([really_done], dtype=bool)
|
||||
frame["task"] = cfg.task
|
||||
@@ -2065,7 +2065,7 @@ def record_dataset(env, policy, cfg, success_collection_steps=0):
|
||||
if (terminated or truncated) and not success_detected:
|
||||
# Regular termination without success
|
||||
break
|
||||
elif success_detected and success_steps_collected >= success_collection_steps:
|
||||
elif success_detected and success_steps_collected >= cfg.number_of_steps_after_success:
|
||||
# We've collected enough success states
|
||||
logging.info(f"Collected {success_steps_collected} additional success states")
|
||||
break
|
||||
@@ -2139,12 +2139,10 @@ def main(cfg: EnvConfig):
|
||||
policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
# Get success_collection_steps from config or default to 15
|
||||
record_dataset(
|
||||
env,
|
||||
policy=policy,
|
||||
cfg=cfg,
|
||||
success_collection_steps=0,
|
||||
)
|
||||
exit()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user