Small fix, Refactor diffusion, Diffusion runs (TODO: remove normalization in diffusion)

This commit is contained in:
Remi Cadene
2024-03-02 17:04:39 +00:00
parent 45b4ecb727
commit 80785f8d0e
6 changed files with 449 additions and 10 deletions

View File

@@ -138,9 +138,6 @@ class TDMPC(nn.Module):
"state": observation["state"].contiguous(),
}
action = self.act(obs, t0=t0, step=self.step.item())
# TODO(rcadene): hack to postprocess action (e.g. unnormalize)
# action = action * self.action_std + self.action_mean
return action
@torch.no_grad()