fix training
This commit is contained in:
@@ -14,6 +14,8 @@ from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
|
||||
from rl.torchrl.envs.transforms.transforms import Compose
|
||||
|
||||
|
||||
class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
def __init__(
|
||||
@@ -78,7 +80,11 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
return self._transform
|
||||
|
||||
def set_transform(self, transform):
|
||||
self._transform = transform
|
||||
if not isinstance(transform, Compose):
|
||||
# required since torchrl calls `len(self._transform)` downstream
|
||||
self._transform = Compose(transform)
|
||||
else:
|
||||
self._transform = transform
|
||||
|
||||
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
|
||||
stats_path = self.data_dir / "stats.pth"
|
||||
|
||||
Reference in New Issue
Block a user