Fix unit tests, Refactor, Add pusht env, (TODO pusht replay buffer, image preprocessing)

This commit is contained in:
Cadene
2024-02-20 12:26:57 +00:00
parent fdfb2010fd
commit 3da6ffb2cb
10 changed files with 559 additions and 89 deletions

View File

@@ -167,18 +167,21 @@ class SimxarmEnv(EnvBase):
device=self.device,
)
self.done_spec = DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
)
self.success_spec = DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
self.done_spec = CompositeSpec(
{
"done": DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
),
"success": DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
),
}
)
def _set_seed(self, seed: Optional[int]):