Ran pre-commit run --all-files

This commit is contained in:
Simon Alibert
2024-02-29 13:37:48 +01:00
parent 0e0e760e43
commit 7e024fdce6
16 changed files with 124 additions and 237 deletions

View File

@@ -3,16 +3,17 @@ import copy
import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.common.lr_scheduler import get_scheduler
from diffusion_policy.model.vision.model_getter import get_resnet
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
FIRST_ACTION = 0
class DiffusionPolicy(nn.Module):
def __init__(
self,
cfg,
@@ -105,7 +106,6 @@ class DiffusionPolicy(nn.Module):
out = self.diffusion.predict_action(obs_dict)
# TODO(rcadene): add possibility to return >1 timestemps
FIRST_ACTION = 0
action = out["action"].squeeze(0)[FIRST_ACTION]
return action
@@ -132,10 +132,7 @@ class DiffusionPolicy(nn.Module):
}
return out
if self.cfg.balanced_sampling:
batch = replay_buffer.sample(batch_size)
else:
batch = replay_buffer.sample()
batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
batch = process_batch(batch, self.cfg.horizon, num_slices)
loss = self.diffusion.compute_loss(batch)