Solve conflicts + pre-commit run -a
This commit is contained in:
@@ -9,19 +9,14 @@ import pymunk
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import (
|
||||
TensorDictReplayBuffer,
|
||||
)
|
||||
from torchrl.data.replay_buffers.samplers import (
|
||||
Sampler,
|
||||
)
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
|
||||
# as define in env
|
||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||
|
||||
@@ -8,9 +8,7 @@ import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import (
|
||||
TensorDictReplayBuffer,
|
||||
)
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import (
|
||||
Sampler,
|
||||
SliceSampler,
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import contextlib
|
||||
import datetime
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from omegaconf import OmegaConf
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
def make_dir(dir_path):
|
||||
"""Create directory if it does not already exist."""
|
||||
with contextlib.suppress(OSError):
|
||||
|
||||
@@ -5,7 +5,6 @@ import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
|
||||
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
||||
from diffusion_policy.model.vision.model_getter import get_resnet
|
||||
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
@@ -128,12 +127,8 @@ class DiffusionPolicy(nn.Module):
|
||||
|
||||
out = {
|
||||
"obs": {
|
||||
"image": batch["observation", "image"].to(
|
||||
self.device, non_blocking=True
|
||||
),
|
||||
"agent_pos": batch["observation", "state"].to(
|
||||
self.device, non_blocking=True
|
||||
),
|
||||
"image": batch["observation", "image"].to(self.device, non_blocking=True),
|
||||
"agent_pos": batch["observation", "state"].to(self.device, non_blocking=True),
|
||||
},
|
||||
"action": batch["action"].to(self.device, non_blocking=True),
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ def init_logging():
|
||||
logging.getLogger().addHandler(console_handler)
|
||||
|
||||
|
||||
def format_number_KMB(num):
|
||||
def format_big_number(num):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
divisor = 1000.0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user