Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion)
This commit is contained in:
@@ -100,7 +100,8 @@ class PushtEnv(EnvBase):
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
action = td["action"].numpy()
|
||||
# remove batch dim
|
||||
action = td["action"].squeeze(0).numpy()
|
||||
# step expects shape=(4,) so we pad if necessary
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
|
||||
Reference in New Issue
Block a user