Solve conflicts + pre-commit run -a

This commit is contained in:
Cadene
2024-02-29 23:31:32 +00:00
parent 0b9027f05e
commit ae050d2e94
8 changed files with 26 additions and 41 deletions

View File

@@ -9,19 +9,14 @@ import pymunk
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import (
Sampler,
)
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
# as define in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,

View File

@@ -8,9 +8,7 @@ import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import (
Sampler,
SliceSampler,

View File

@@ -1,13 +1,12 @@
import contextlib
import datetime
import os
from pathlib import Path
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from termcolor import colored
def make_dir(dir_path):
"""Create directory if it does not already exist."""
with contextlib.suppress(OSError):

View File

@@ -5,7 +5,6 @@ import hydra
import torch
import torch.nn as nn
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.common.lr_scheduler import get_scheduler
from diffusion_policy.model.vision.model_getter import get_resnet
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
@@ -128,12 +127,8 @@ class DiffusionPolicy(nn.Module):
out = {
"obs": {
"image": batch["observation", "image"].to(
self.device, non_blocking=True
),
"agent_pos": batch["observation", "state"].to(
self.device, non_blocking=True
),
"image": batch["observation", "image"].to(self.device, non_blocking=True),
"agent_pos": batch["observation", "state"].to(self.device, non_blocking=True),
},
"action": batch["action"].to(self.device, non_blocking=True),
}

View File

@@ -33,7 +33,7 @@ def init_logging():
logging.getLogger().addHandler(console_handler)
def format_number_KMB(num):
def format_big_number(num):
suffixes = ["", "K", "M", "B", "T", "Q"]
divisor = 1000.0