Refactor SACPolicy and learner server for improved replay buffer management

- Updated SACPolicy to build its critic heads with a list comprehension for better readability (see the first sketch below).
- Simplified model saving and loading by switching to the `save_model` and `load_model` functions from the safetensors library (see the second sketch below).
- Introduced an `initialize_offline_replay_buffer` function in the learner server to streamline offline dataset handling and replay-buffer initialization.
- Improved logging around dataset loading so training runs are easier to trace.
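For reference, a minimal sketch of the list-comprehension pattern the first bullet describes; the helper name and layer sizes are illustrative, not the repository's actual code:

import torch.nn as nn

def build_critic_heads(input_dim: int, hidden_dim: int, num_critics: int) -> nn.ModuleList:
    # Illustrative only: build num_critics independent Q-heads in one
    # comprehension instead of appending to a list in a loop.
    return nn.ModuleList(
        [
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1),  # each head outputs a scalar Q-value
            )
            for _ in range(num_critics)
        ]
    )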
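And a hedged sketch of the safetensors helpers the second bullet refers to; save_model and load_model are real safetensors.torch functions, while the policy object and file name here are placeholders:

from safetensors.torch import load_model, save_model

save_model(policy, "model.safetensors")  # writes the module's weights to disk
load_model(policy, "model.safetensors")  # restores them into the module in place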
Author:    AdilZouitine
Committer: Michel Aractingi
Date:      2025-03-18 14:57:15 +00:00
Commit:    0959694bab (parent 7b01e16439)
2 changed files with 159 additions and 138 deletions


@@ -121,7 +121,7 @@ def load_training_state(
         return None, None

     training_state = torch.load(
-        logger.last_checkpoint_dir / logger.training_state_file_name
+        logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False
     )

     if isinstance(training_state["optimizer"], dict):
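A note on the change above: recent PyTorch releases switched torch.load's default to weights_only=True, which restricts unpickling to tensors and plain containers and can reject a checkpoint that embeds richer Python objects (such as the optimizer state kept in the training-state file). A minimal illustration, with a hypothetical file name:

import torch

# Opt out of the weights-only restriction so the full training state,
# including non-tensor objects, can be unpickled.
training_state = torch.load("training_state.pt", weights_only=False)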
@@ -160,6 +160,7 @@ def initialize_replay_buffer(
             optimize_memory=True,
         )

+    logging.info("Resume training load the online dataset")
     dataset = LeRobotDataset(
         repo_id=cfg.dataset_repo_id,
         local_files_only=True,
@@ -174,6 +175,37 @@ def initialize_replay_buffer(
     )


+def initialize_offline_replay_buffer(
+    cfg: DictConfig,
+    logger: Logger,
+    device: str,
+    storage_device: str,
+    active_action_dims: list[int] | None = None,
+) -> ReplayBuffer:
+    if not cfg.resume:
+        logging.info("make_dataset offline buffer")
+        offline_dataset = make_dataset(cfg)
+    if cfg.resume:
+        logging.info("load offline dataset")
+        offline_dataset = LeRobotDataset(
+            repo_id=cfg.dataset_repo_id,
+            local_files_only=True,
+            root=logger.log_dir / "dataset_offline",
+        )
+
+    logging.info("Convert to a offline replay buffer")
+    offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+        offline_dataset,
+        device=device,
+        state_keys=cfg.policy.input_shapes.keys(),
+        action_mask=active_action_dims,
+        action_delta=cfg.env.wrapper.delta_action,
+        storage_device=storage_device,
+        optimize_memory=True,
+    )
+    return offline_replay_buffer
+
+
 def get_observation_features(
     policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
 ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
@@ -447,9 +479,6 @@ def add_actor_information_and_train(
     offline_replay_buffer = None

     if cfg.dataset_repo_id is not None:
-        logging.info("make_dataset offline buffer")
-        offline_dataset = make_dataset(cfg)
-        logging.info("Convertion to a offline replay buffer")
         active_action_dims = None
         if cfg.env.wrapper.joint_masking_action_space is not None:
             active_action_dims = [
@@ -457,14 +486,12 @@ def add_actor_information_and_train(
                 for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
                 if mask
             ]
-        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
-            offline_dataset,
+        offline_replay_buffer = initialize_offline_replay_buffer(
+            cfg=cfg,
+            logger=logger,
             device=device,
-            state_keys=cfg.policy.input_shapes.keys(),
-            action_mask=active_action_dims,
-            action_delta=cfg.env.wrapper.delta_action,
             storage_device=storage_device,
-            optimize_memory=True,
+            active_action_dims=active_action_dims,
         )

         batch_size: int = batch_size // 2  # We will sample from both replay buffer
@@ -714,6 +741,19 @@ def add_actor_information_and_train(
             replay_buffer.to_lerobot_dataset(
                 cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
             )
+            if offline_replay_buffer is not None:
+                dataset_dir = logger.log_dir / "dataset_offline"
+                if dataset_dir.exists() and dataset_dir.is_dir():
+                    shutil.rmtree(
+                        dataset_dir,
+                    )
+
+                offline_replay_buffer.to_lerobot_dataset(
+                    cfg.dataset_repo_id,
+                    fps=cfg.fps,
+                    root=logger.log_dir / "dataset_offline",
+                )
+
             logging.info("Resume training")