Added support for checkpointing the policy. We can save and load the policy state dict, optimizers state, optimization step and interaction step
Added functions for converting the replay buffer from and to LeRobotDataset. When we want to save the replay buffer, we convert it first to LeRobotDataset format and save it locally and vice-versa. Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -152,7 +152,7 @@ def serve_actor_service(port=50052):
|
||||
server.wait_for_termination()
|
||||
|
||||
|
||||
def act_with_policy(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
def act_with_policy(cfg: DictConfig):
|
||||
"""
|
||||
Executes policy interaction within the environment.
|
||||
|
||||
@@ -161,8 +161,6 @@ def act_with_policy(cfg: DictConfig, out_dir: str | None = None, job_name: str |
|
||||
|
||||
Args:
|
||||
cfg (DictConfig): Configuration settings for the interaction process.
|
||||
out_dir (Optional[str]): Directory to store output logs or results. Defaults to None.
|
||||
job_name (Optional[str]): Name of the job for logging or tracking purposes. Defaults to None.
|
||||
"""
|
||||
|
||||
logging.info("make_env online")
|
||||
@@ -189,9 +187,10 @@ def act_with_policy(cfg: DictConfig, out_dir: str | None = None, job_name: str |
|
||||
# Hack: But if we do online training, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
# TODO: Handle resume training
|
||||
pretrained_policy_name_or_path=None,
|
||||
device=device,
|
||||
)
|
||||
# pretrained_policy_name_or_path=None,
|
||||
# device=device,
|
||||
# )
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
# HACK for maniskill
|
||||
@@ -295,11 +294,7 @@ def actor_cli(cfg: dict):
|
||||
policy_thread = Thread(
|
||||
target=act_with_policy,
|
||||
daemon=True,
|
||||
args=(
|
||||
cfg,
|
||||
hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
),
|
||||
args=(cfg,),
|
||||
)
|
||||
policy_thread.start()
|
||||
policy_thread.join()
|
||||
|
||||
Reference in New Issue
Block a user