Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion)

This commit is contained in:
Remi Cadene
2024-03-02 15:53:29 +00:00
parent b5a2f460ea
commit 1ae6205269
9 changed files with 183 additions and 67 deletions

View File

@@ -100,7 +100,8 @@ class PushtEnv(EnvBase):
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# remove batch dim
action = td["action"].squeeze(0).numpy()
# step expects shape=(4,) so we pad if necessary
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0