Refactor configs to have env in separate yaml + Fix training
@@ -10,28 +10,29 @@ def make_offline_buffer(cfg, sampler=None):
     overwrite_sampler = sampler is not None

     if not overwrite_sampler:
-        num_traj_per_batch = cfg.batch_size # // cfg.horizon
+        # TODO(rcadene): move batch_size outside
+        num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
         # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
         # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
         sampler = PrioritizedSliceSampler(
             max_capacity=100_000,
-            alpha=cfg.per_alpha,
-            beta=cfg.per_beta,
+            alpha=cfg.policy.per_alpha,
+            beta=cfg.policy.per_beta,
             num_slices=num_traj_per_batch,
             strict_length=False,
         )

-    if cfg.env == "simxarm":
+    if cfg.env.name == "simxarm":
         # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
         offline_buffer = SimxarmExperienceReplay(
-            f"xarm_{cfg.task}_medium",
+            f"xarm_{cfg.env.task}_medium",
             # download="force",
             download=True,
             streaming=False,
             root="data",
             sampler=sampler,
         )
-    elif cfg.env == "pusht":
+    elif cfg.env.name == "pusht":
         offline_buffer = PushtExperienceReplay(
             "pusht",
             # download="force",
@@ -41,7 +42,7 @@ def make_offline_buffer(cfg, sampler=None):
             sampler=sampler,
         )
     else:
-        raise ValueError(cfg.env)
+        raise ValueError(cfg.env.name)

     if not overwrite_sampler:
         num_steps = len(offline_buffer)
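For reference, a minimal sketch of the sampler configured above, pulled out of make_offline_buffer. The import path and all numeric values are assumptions for illustration, not values from this commit; the constructor keywords mirror the diff.

import torch
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler  # import path assumed

# Same keyword arguments as in make_offline_buffer, with placeholder values
# for the prioritization exponents and the number of slices per batch.
sampler = PrioritizedSliceSampler(
    max_capacity=100_000,
    alpha=0.7,          # placeholder for cfg.policy.per_alpha
    beta=0.9,           # placeholder for cfg.policy.per_beta
    num_slices=16,      # placeholder for cfg.policy.batch_size
    strict_length=False,
)

# The TODO about `sampler.extend(index)` refers to this step: the replay buffer
# has to register the indices of the frames it stores so the sampler can assign
# and later update per-frame priorities.
sampler.extend(torch.arange(1_000))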
@@ -1,7 +1,5 @@
 from torchrl.envs.transforms import StepCounter, TransformedEnv

-from lerobot.common.envs.pusht import PushtEnv
-from lerobot.common.envs.simxarm import SimxarmEnv
 from lerobot.common.envs.transforms import Prod


@@ -14,9 +12,13 @@ def make_env(cfg):
     }

     if cfg.env.name == "simxarm":
+        from lerobot.common.envs.simxarm import SimxarmEnv
+
         kwargs["task"] = cfg.env.task
         clsfunc = SimxarmEnv
     elif cfg.env.name == "pusht":
+        from lerobot.common.envs.pusht import PushtEnv
+
         clsfunc = PushtEnv
     else:
         raise ValueError(cfg.env.name)
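Taken together, these hunks move fields that used to sit at the top level of the config (cfg.task, cfg.batch_size, cfg.per_alpha, cfg.per_beta, cfg.env as a plain string) under nested env and policy sub-configs, matching the commit title. Below is a minimal sketch of the layout those accesses imply, built with omegaconf rather than the actual yaml files (which are not part of the hunks shown); every concrete value is a placeholder, not a value from this commit.

from omegaconf import OmegaConf

# Hypothetical nested config implied by the cfg.env.* / cfg.policy.* accesses above.
cfg = OmegaConf.create(
    {
        "seed": 1337,  # placeholder
        "env": {
            "name": "simxarm",
            "task": "lift",        # placeholder task name
            "action_repeat": 2,    # placeholder
        },
        "policy": {
            "name": "tdmpc",
            "batch_size": 256,     # placeholder, was top-level cfg.batch_size
            "per_alpha": 0.6,      # placeholder, was top-level cfg.per_alpha
            "per_beta": 0.4,       # placeholder, was top-level cfg.per_beta
        },
    }
)

assert cfg.env.name == "simxarm"     # branch taken in make_env / make_offline_buffer
assert cfg.policy.batch_size == 256  # replaces the old top-level cfg.batch_size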
@@ -50,7 +50,7 @@ def print_run(cfg, reward=None):
     )

     kvs = [
-        ("task", cfg.task),
+        ("task", cfg.env.task),
         ("train steps", f"{int(cfg.train_steps * cfg.env.action_repeat):,}"),
         # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
         # ('actions', cfg.action_dim),
@@ -72,7 +72,7 @@ def cfg_to_group(cfg, return_list=False):
     """Return a wandb-safe group name for logging. Optionally returns group name as list."""
     # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
     lst = [
-        f"env:{cfg.env}",
+        f"env:{cfg.env.name}",
         f"seed:{cfg.seed}",
     ]
     return lst if return_list else "-".join(lst)
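With a config shaped like the sketch above, the updated cfg_to_group yields, for example (placeholder values again):

cfg_to_group(cfg)                    # -> "env:simxarm-seed:1337"
cfg_to_group(cfg, return_list=True)  # -> ["env:simxarm", "seed:1337"]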
lerobot/common/policies/diffusion.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
+
+
+class DiffusionPolicy(nn.Module):
+
+    def __init__(
+        self,
+        shape_meta: dict,
+        noise_scheduler: DDPMScheduler,
+        obs_encoder: MultiImageObsEncoder,
+        horizon,
+        n_action_steps,
+        n_obs_steps,
+        num_inference_steps=None,
+        obs_as_global_cond=True,
+        diffusion_step_embed_dim=256,
+        down_dims=(256, 512, 1024),
+        kernel_size=5,
+        n_groups=8,
+        cond_predict_scale=True,
+        # parameters passed to step
+        **kwargs,
+    ):
+        super().__init__()
+        self.diffusion = DiffusionUnetImagePolicy(
+            shape_meta=shape_meta,
+            noise_scheduler=noise_scheduler,
+            obs_encoder=obs_encoder,
+            horizon=horizon,
+            n_action_steps=n_action_steps,
+            n_obs_steps=n_obs_steps,
+            num_inference_steps=num_inference_steps,
+            obs_as_global_cond=obs_as_global_cond,
+            diffusion_step_embed_dim=diffusion_step_embed_dim,
+            down_dims=down_dims,
+            kernel_size=kernel_size,
+            n_groups=n_groups,
+            cond_predict_scale=cond_predict_scale,
+            # parameters passed to step
+            **kwargs,
+        )
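Note that the hunk above annotates noise_scheduler and obs_encoder with DDPMScheduler and MultiImageObsEncoder but does not import those names. Assuming they are not defined elsewhere in the file, imports along these lines would be needed (module paths taken from the upstream diffusers and diffusion_policy packages, not from this commit):

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder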
@@ -1,9 +1,12 @@
-from lerobot.common.policies.tdmpc import TDMPC
-
-
 def make_policy(cfg):
     if cfg.policy.name == "tdmpc":
+        from lerobot.common.policies.tdmpc import TDMPC
+
         policy = TDMPC(cfg.policy)
+    elif cfg.policy.name == "diffusion":
+        from lerobot.common.policies.diffusion import DiffusionPolicy
+
+        policy = DiffusionPolicy(cfg.policy)
     else:
         raise ValueError(cfg.policy.name)

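Moving the imports inside each branch keeps heavy optional dependencies (the diffusion_policy stack in particular) from being imported unless the matching policy is actually selected. A hypothetical call site, reusing the nested cfg sketched earlier:

policy = make_policy(cfg)  # cfg.policy.name == "tdmpc" here, so this builds TDMPC(cfg.policy)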