PushtEnv inheriates AbstractEnv, Improve factory Normalization

This commit is contained in:
Cadene
2024-03-11 14:05:23 +00:00
parent ebd5c786f1
commit bdd2c801bc
3 changed files with 19 additions and 29 deletions

View File

@@ -49,7 +49,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
self.kl_weight = self.cfg.kl_weight
logging.info(f"KL Weight {self.kl_weight}")
self.to(self.device)
def update(self, replay_buffer, step):
@@ -156,7 +155,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
# observation["state"] = observation["state"].unsqueeze(0)
# TODO(rcadene): remove hack
# add 1 camera dimension