Add storage device configuration for SAC policy and replay buffer
- Introduce `storage_device` parameter in SAC configuration and training settings - Update learner server to use configurable storage device for replay buffer - Reduce online buffer capacity in ManiSkill configuration - Modify replay buffer initialization to support custom storage device
This commit is contained in:
committed by
Michel Aractingi
parent
ae51c19b3c
commit
bb69cb3c8c
@@ -146,14 +146,14 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
|
||||
|
||||
|
||||
def initialize_replay_buffer(
|
||||
cfg: DictConfig, logger: Logger, device: str
|
||||
cfg: DictConfig, logger: Logger, device: str, storage_device:str
|
||||
) -> ReplayBuffer:
|
||||
if not cfg.resume:
|
||||
return ReplayBuffer(
|
||||
capacity=cfg.training.online_buffer_capacity,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
storage_device=device,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=True,
|
||||
)
|
||||
|
||||
@@ -596,6 +596,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
@@ -628,7 +629,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
log_training_info(cfg, out_dir, policy)
|
||||
|
||||
replay_buffer = initialize_replay_buffer(cfg, logger, device)
|
||||
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
|
||||
batch_size = cfg.training.batch_size
|
||||
offline_replay_buffer = None
|
||||
|
||||
@@ -649,7 +650,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
action_mask=active_action_dims,
|
||||
action_delta=cfg.env.wrapper.delta_action,
|
||||
storage_device=device,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=True,
|
||||
)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
Reference in New Issue
Block a user