import torch
from torch import nn


def populate_queues(queues, batch):
    """Push the latest entries of ``batch`` onto the matching queues.

    For every key in ``batch`` the corresponding queue in ``queues`` is
    updated in place:

    * A bounded queue that is not yet full is filled up to ``maxlen`` by
      appending ``batch[key]`` repeatedly. The same object is appended each
      time (no copy is made).
    * A bounded queue that is already full receives ``batch[key]`` once.
    * An unbounded queue (``maxlen is None``) receives ``batch[key]`` once,
      since there is no capacity to fill up to.

    Args:
        queues: Mapping from key to a queue-like object exposing ``maxlen``,
            ``__len__`` and ``append`` (e.g. ``collections.deque``).
        batch: Mapping from key to the newest value for that key.

    Returns:
        The same ``queues`` object that was passed in.

    Raises:
        KeyError: If ``queues`` is a dict and ``batch`` contains a key that
            ``queues`` does not.
    """
    for key in batch:
        queue = queues[key]
        capacity = queue.maxlen

        # An unbounded queue can never "become full"; comparing its length
        # with ``None`` would loop forever, so just record the observation.
        if capacity is None or len(queue) == capacity:
            queue.append(batch[key])
            continue

        # Warm-up: repeat the observation until the queue is full.
        while len(queue) != capacity:
            queue.append(batch[key])
    return queues


def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Get a module's device by checking one of its parameters.

    Note: assumes that all parameters have the same device; only the first
    parameter yielded by ``module.parameters()`` is inspected.

    Raises:
        StopIteration: If the module has no parameters.
    """
    first_param = next(module.parameters())
    return first_param.device


def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
    """Get a module's parameter dtype by checking one of its parameters.

    Note: assumes that all parameters have the same dtype; only the first
    parameter yielded by ``module.parameters()`` is inspected.

    Raises:
        StopIteration: If the module has no parameters.
    """
    first_param = next(module.parameters())
    return first_param.dtype