Ran pre-commit run --all-files
This commit is contained in:
@@ -3,16 +3,17 @@ import copy
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
|
||||
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
||||
from diffusion_policy.model.vision.model_getter import get_resnet
|
||||
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
|
||||
FIRST_ACTION = 0
|
||||
|
||||
|
||||
class DiffusionPolicy(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
@@ -105,7 +106,6 @@ class DiffusionPolicy(nn.Module):
|
||||
out = self.diffusion.predict_action(obs_dict)
|
||||
|
||||
# TODO(rcadene): add possibility to return >1 timestemps
|
||||
FIRST_ACTION = 0
|
||||
action = out["action"].squeeze(0)[FIRST_ACTION]
|
||||
return action
|
||||
|
||||
@@ -132,10 +132,7 @@ class DiffusionPolicy(nn.Module):
|
||||
}
|
||||
return out
|
||||
|
||||
if self.cfg.balanced_sampling:
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
else:
|
||||
batch = replay_buffer.sample()
|
||||
batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
|
||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
|
||||
Reference in New Issue
Block a user