Refactor SACPolicy and learner server for improved replay buffer management
- Updated SACPolicy to create critic heads using a list comprehension for better readability.
- Simplified the saving and loading of models using the `save_model` and `load_model` functions from the safetensors library.
- Introduced the `initialize_offline_replay_buffer` function in the learner server to streamline offline dataset handling and replay buffer initialization.
- Enhanced logging for dataset loading processes to improve traceability during training.
This commit is contained in:
committed by
Michel Aractingi
parent
7b01e16439
commit
0959694bab
@@ -121,7 +121,7 @@ def load_training_state(
|
||||
return None, None
|
||||
|
||||
training_state = torch.load(
|
||||
logger.last_checkpoint_dir / logger.training_state_file_name
|
||||
logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False
|
||||
)
|
||||
|
||||
if isinstance(training_state["optimizer"], dict):
|
||||
@@ -160,6 +160,7 @@ def initialize_replay_buffer(
|
||||
optimize_memory=True,
|
||||
)
|
||||
|
||||
logging.info("Resume training load the online dataset")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=cfg.dataset_repo_id,
|
||||
local_files_only=True,
|
||||
@@ -174,6 +175,37 @@ def initialize_replay_buffer(
|
||||
)
|
||||
|
||||
|
||||
def initialize_offline_replay_buffer(
    cfg: DictConfig,
    logger: Logger,
    device: str,
    storage_device: str,
    active_action_dims: list[int] | None = None,
) -> ReplayBuffer:
    """Build the offline replay buffer from a LeRobot dataset.

    On a fresh run the dataset is created with ``make_dataset(cfg)``. When
    ``cfg.resume`` is truthy it is instead loaded from the
    ``dataset_offline`` directory under ``logger.log_dir``, using local
    files only.

    Args:
        cfg: Training configuration. This function reads ``resume``,
            ``dataset_repo_id``, ``policy.input_shapes`` and
            ``env.wrapper.delta_action``.
        logger: Provides ``log_dir``, the root under which the offline
            dataset is looked up when resuming.
        device: Forwarded to ``ReplayBuffer.from_lerobot_dataset`` as
            ``device``.
        storage_device: Forwarded to ``ReplayBuffer.from_lerobot_dataset``
            as ``storage_device``.
        active_action_dims: Forwarded as ``action_mask``; ``None`` is
            passed through unchanged.

    Returns:
        The replay buffer built from the offline dataset.
    """
    if cfg.resume:
        # Resuming: read the dataset from the run's log directory.
        logging.info("load offline dataset")
        source_dataset = LeRobotDataset(
            repo_id=cfg.dataset_repo_id,
            local_files_only=True,
            root=logger.log_dir / "dataset_offline",
        )
    else:
        logging.info("make_dataset offline buffer")
        source_dataset = make_dataset(cfg)

    logging.info("Convert to a offline replay buffer")
    return ReplayBuffer.from_lerobot_dataset(
        source_dataset,
        device=device,
        state_keys=cfg.policy.input_shapes.keys(),
        action_mask=active_action_dims,
        action_delta=cfg.env.wrapper.delta_action,
        storage_device=storage_device,
        optimize_memory=True,
    )
|
||||
|
||||
|
||||
def get_observation_features(
|
||||
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
@@ -447,9 +479,6 @@ def add_actor_information_and_train(
|
||||
offline_replay_buffer = None
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
logging.info("Convertion to a offline replay buffer")
|
||||
active_action_dims = None
|
||||
if cfg.env.wrapper.joint_masking_action_space is not None:
|
||||
active_action_dims = [
|
||||
@@ -457,14 +486,12 @@ def add_actor_information_and_train(
|
||||
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
|
||||
if mask
|
||||
]
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset,
|
||||
offline_replay_buffer = initialize_offline_replay_buffer(
|
||||
cfg=cfg,
|
||||
logger=logger,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
action_mask=active_action_dims,
|
||||
action_delta=cfg.env.wrapper.delta_action,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=True,
|
||||
active_action_dims=active_action_dims,
|
||||
)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
@@ -714,6 +741,19 @@ def add_actor_information_and_train(
|
||||
replay_buffer.to_lerobot_dataset(
|
||||
cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
|
||||
)
|
||||
if offline_replay_buffer is not None:
|
||||
dataset_dir = logger.log_dir / "dataset_offline"
|
||||
|
||||
if dataset_dir.exists() and dataset_dir.is_dir():
|
||||
shutil.rmtree(
|
||||
dataset_dir,
|
||||
)
|
||||
|
||||
offline_replay_buffer.to_lerobot_dataset(
|
||||
cfg.dataset_repo_id,
|
||||
fps=cfg.fps,
|
||||
root=logger.log_dir / "dataset_offline",
|
||||
)
|
||||
|
||||
logging.info("Resume training")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user