Add possibility for the policy to provide a sequence of actions to the env

This commit is contained in:
Remi Cadene
2024-03-03 14:02:24 +00:00
parent 4c400b41a5
commit fddd9f0311
2 changed files with 12 additions and 8 deletions

View File

@@ -12,8 +12,6 @@ from diffusion_policy.model.vision.model_getter import get_resnet
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
from .multi_image_obs_encoder import MultiImageObsEncoder
FIRST_ACTION = 0
class DiffusionPolicy(nn.Module):
def __init__(
@@ -110,8 +108,7 @@ class DiffusionPolicy(nn.Module):
}
out = self.diffusion.predict_action(obs_dict)
# TODO(rcadene): add possibility to return >1 timestemps
action = out["action"].squeeze(0)[FIRST_ACTION]
action = out["action"].squeeze(0)
return action
def update(self, replay_buffer, step):