fix(policies): remove action from batch for offline evaluation (#1609)
* fix(policies): remove action from batch for offline evaluation in diffusion, tdmpc, and vqbet policies * style(diffusion): correct comment capitalization for clarity in modeling_diffusion.py
This commit is contained in:
@@ -133,11 +133,15 @@ class DiffusionPolicy(PreTrainedPolicy):
|
|||||||
"horizon" may not be the best name to describe what the variable actually means, because this period is
|
"horizon" may not be the best name to describe what the variable actually means, because this period is
|
||||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||||
"""
|
"""
|
||||||
|
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||||
|
if ACTION in batch:
|
||||||
|
batch.pop(ACTION)
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
if self.config.image_features:
|
if self.config.image_features:
|
||||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||||
# Note: It's important that this happens after stacking the images into a single key.
|
# NOTE: It's important that this happens after stacking the images into a single key.
|
||||||
self._queues = populate_queues(self._queues, batch)
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
if len(self._queues[ACTION]) == 0:
|
if len(self._queues[ACTION]) == 0:
|
||||||
|
|||||||
@@ -143,7 +143,12 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Select a single action given environment observations."""
|
"""Select a single action given environment observations."""
|
||||||
|
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||||
|
if ACTION in batch:
|
||||||
|
batch.pop(ACTION)
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
|
||||||
if self.config.image_features:
|
if self.config.image_features:
|
||||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||||
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
|
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
|
||||||
|
|||||||
@@ -139,11 +139,14 @@ class VQBeTPolicy(PreTrainedPolicy):
|
|||||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||||
queue is empty.
|
queue is empty.
|
||||||
"""
|
"""
|
||||||
|
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||||
|
if ACTION in batch:
|
||||||
|
batch.pop(ACTION)
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||||
|
# NOTE: It's important that this happens after stacking the images into a single key.
|
||||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||||
# Note: It's important that this happens after stacking the images into a single key.
|
|
||||||
self._queues = populate_queues(self._queues, batch)
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||||
|
|||||||
Reference in New Issue
Block a user