From 9ddbbd8e804cf698394784b4a484f52705dff282 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Tue, 6 Aug 2024 17:17:07 +0300 Subject: [PATCH] WIP --- lerobot/common/datasets/lerobot_dataset.py | 12 ++- lerobot/common/logger.py | 3 - lerobot/common/policies/act/modeling_act.py | 23 ++--- .../policies/diffusion/modeling_diffusion.py | 6 +- .../common/policies/tdmpc/modeling_tdmpc.py | 98 ++++++++----------- lerobot/common/policies/utils.py | 29 ++++++ lerobot/configs/policy/act.yaml | 2 +- lerobot/configs/policy/act_real.yaml | 2 +- lerobot/configs/policy/act_real_no_state.yaml | 2 +- lerobot/scripts/eval.py | 28 +----- lerobot/scripts/train.py | 9 +- tests/scripts/save_policy_to_safetensors.py | 64 ++++++++---- tests/test_policies.py | 20 ++-- tests/test_visualize_dataset.py | 4 +- 14 files changed, 162 insertions(+), 140 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index eb76f78d6..056000a73 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -44,7 +44,7 @@ class LeRobotDataset(torch.utils.data.Dataset): def __init__( self, repo_id: str, - root: Path | None = DATA_DIR, + root: Path | None = None, split: str = "train", image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, @@ -53,22 +53,24 @@ class LeRobotDataset(torch.utils.data.Dataset): super().__init__() self.repo_id = repo_id self.root = root + if self.root is None and DATA_DIR is not None: + self.root = DATA_DIR self.split = split self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps # load data from hub or locally when root is provided # TODO(rcadene, aliberts): implement faster transfer # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads - self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split) + self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, self.root, split) if split == "train": self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root) else: self.episode_data_index = calculate_episode_data_index(self.hf_dataset) self.hf_dataset = reset_episode_index(self.hf_dataset) - self.stats = load_stats(repo_id, CODEBASE_VERSION, root) - self.info = load_info(repo_id, CODEBASE_VERSION, root) + self.stats = load_stats(repo_id, CODEBASE_VERSION, self.root) + self.info = load_info(repo_id, CODEBASE_VERSION, self.root) if self.video: - self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root) + self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, self.root) self.video_backend = video_backend if video_backend is not None else "pyav" @property diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index bf578fcc5..b76d9b673 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -233,9 +233,6 @@ class Logger: if self._wandb is not None: for k, v in d.items(): if not isinstance(v, (int, float, str)): - logging.warning( - f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.' - ) continue self._wandb.log({f"{mode}/{k}": v}, step=step) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 026917018..54a76e769 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -134,25 +134,26 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) - l1_loss = ( - F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) - ).mean() + bsize = actions_hat.shape[0] + l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none") + l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1) + l1_loss = l1_loss.view(bsize, -1).mean(dim=1) + + out_dict = {} + out_dict["l1_loss"] = l1_loss - loss_dict = {"l1_loss": l1_loss.item()} if self.config.use_vae: # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for # each dimension independently, we sum over the latent dimension to get the total # KL-divergence per batch element, then take the mean over the batch. # (See App. B of https://arxiv.org/abs/1312.6114 for more details). - mean_kld = ( - (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() - ) - loss_dict["kld_loss"] = mean_kld.item() - loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight + kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1) + out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight else: - loss_dict["loss"] = l1_loss + out_dict["loss"] = l1_loss - return loss_dict + out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"] + return out_dict class ACTTemporalEnsembler: diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 486085374..9f7c15f75 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -341,7 +341,11 @@ class DiffusionModel(nn.Module): in_episode_bound = ~batch["action_is_pad"] loss = loss * in_episode_bound.unsqueeze(-1) - return loss.mean() + # Compute average per item in the batch + bsize = loss.shape[0] + loss = loss.reshape(bsize, -1).mean(1) + + return loss class SpatialSoftmax(nn.Module): diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 7dbffcefc..94a50bbf8 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -396,51 +396,39 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): # Compute consistency loss as MSE loss between latents predicted from the rollout and latents # predicted from the (target model's) observation encoder. consistency_loss = ( - ( - temporal_loss_coeffs - * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1) - # `z_preds` depends on the current observation and the actions. - * ~batch["observation.state_is_pad"][0] - * ~batch["action_is_pad"] - # `z_targets` depends on the next observation. - * ~batch["observation.state_is_pad"][1:] - ) - .sum(0) - .mean() - ) + temporal_loss_coeffs + * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1) + # `z_preds` depends on the current observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + # `z_targets` depends on the next observation. + * ~batch["observation.state_is_pad"][1:] + ).sum(0) # Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset # rewards. reward_loss = ( - ( - temporal_loss_coeffs - * F.mse_loss(reward_preds, reward, reduction="none") - * ~batch["next.reward_is_pad"] - # `reward_preds` depends on the current observation and the actions. - * ~batch["observation.state_is_pad"][0] - * ~batch["action_is_pad"] - ) - .sum(0) - .mean() - ) + temporal_loss_coeffs + * F.mse_loss(reward_preds, reward, reduction="none") + * ~batch["next.reward_is_pad"] + # `reward_preds` depends on the current observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ).sum(0) # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. q_value_loss = ( - ( - temporal_loss_coeffs - * F.mse_loss( - q_preds_ensemble, - einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]), - reduction="none", - ).sum(0) # sum over ensemble - # `q_preds_ensemble` depends on the first observation and the actions. - * ~batch["observation.state_is_pad"][0] - * ~batch["action_is_pad"] - # q_targets depends on the reward and the next observations. - * ~batch["next.reward_is_pad"] - * ~batch["observation.state_is_pad"][1:] - ) - .sum(0) - .mean() - ) + temporal_loss_coeffs + * F.mse_loss( + q_preds_ensemble, + einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]), + reduction="none", + ).sum(0) # sum over ensemble + # `q_preds_ensemble` depends on the first observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + # q_targets depends on the reward and the next observations. + * ~batch["next.reward_is_pad"] + * ~batch["observation.state_is_pad"][1:] + ).sum(0) # Compute state value loss as in eqn 3 of FOWM. diff = v_targets - v_preds # Expectile loss penalizes: @@ -450,16 +438,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight) ) * (diff**2) v_value_loss = ( - ( - temporal_loss_coeffs - * raw_v_value_loss - # `v_targets` depends on the first observation and the actions, as does `v_preds`. - * ~batch["observation.state_is_pad"][0] - * ~batch["action_is_pad"] - ) - .sum(0) - .mean() - ) + temporal_loss_coeffs + * raw_v_value_loss + # `v_targets` depends on the first observation and the actions, as does `v_preds`. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ).sum(0) # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1. # We won't need these gradients again so detach. @@ -492,7 +476,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): # `action_preds` depends on the first observation and the actions. * ~batch["observation.state_is_pad"][0] * ~batch["action_is_pad"] - ).mean() + ).sum(0) loss = ( self.config.consistency_coeff * consistency_loss @@ -504,13 +488,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): info.update( { - "consistency_loss": consistency_loss.item(), - "reward_loss": reward_loss.item(), - "Q_value_loss": q_value_loss.item(), - "V_value_loss": v_value_loss.item(), - "pi_loss": pi_loss.item(), + "consistency_loss": consistency_loss, + "reward_loss": reward_loss, + "Q_value_loss": q_value_loss, + "V_value_loss": v_value_loss, + "pi_loss": pi_loss, "loss": loss, - "sum_loss": loss.item() * self.config.horizon, + "sum_loss": loss * self.config.horizon, } ) diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py index 5a62daa2a..c99452da4 100644 --- a/lerobot/common/policies/utils.py +++ b/lerobot/common/policies/utils.py @@ -13,7 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +from pathlib import Path + import torch +from huggingface_hub import snapshot_download +from huggingface_hub.utils._errors import RepositoryNotFoundError +from huggingface_hub.utils._validators import HFValidationError from torch import nn @@ -47,3 +53,26 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype: Note: assumes that all parameters have the same dtype. """ return next(iter(module.parameters())).dtype + + +def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None): + try: + pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision)) + except (HFValidationError, RepositoryNotFoundError) as e: + if isinstance(e, HFValidationError): + error_message = ( + "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID." + ) + else: + error_message = ( + "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub." + ) + + logging.warning(f"{error_message} Treating it as a local directory.") + pretrained_policy_path = Path(pretrained_policy_name_or_path) + if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists(): + raise ValueError( + "The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub " + "repo ID, nor is it an existing local directory." + ) + return pretrained_policy_path diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index 28883936a..bf7dc3b50 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -24,7 +24,7 @@ training: online_steps_between_rollouts: 1 delta_timestamps: - action: "[i / ${fps} for i in range(${policy.chunk_size})]" + action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]" eval: n_episodes: 50 diff --git a/lerobot/configs/policy/act_real.yaml b/lerobot/configs/policy/act_real.yaml index 058104f4d..1d832bf0c 100644 --- a/lerobot/configs/policy/act_real.yaml +++ b/lerobot/configs/policy/act_real.yaml @@ -50,7 +50,7 @@ training: online_steps_between_rollouts: 1 delta_timestamps: - action: "[i / ${fps} for i in range(${policy.chunk_size})]" + action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]" eval: n_episodes: 50 diff --git a/lerobot/configs/policy/act_real_no_state.yaml b/lerobot/configs/policy/act_real_no_state.yaml index 082610503..a6abaccbb 100644 --- a/lerobot/configs/policy/act_real_no_state.yaml +++ b/lerobot/configs/policy/act_real_no_state.yaml @@ -48,7 +48,7 @@ training: online_steps_between_rollouts: 1 delta_timestamps: - action: "[i / ${fps} for i in range(${policy.chunk_size})]" + action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]" eval: n_episodes: 50 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index a07f35304..76cfe7816 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -56,9 +56,6 @@ import einops import gymnasium as gym import numpy as np import torch -from huggingface_hub import snapshot_download -from huggingface_hub.utils._errors import RepositoryNotFoundError -from huggingface_hub.utils._validators import HFValidationError from torch import Tensor, nn from tqdm import trange @@ -68,7 +65,7 @@ from lerobot.common.envs.utils import preprocess_observation from lerobot.common.logger import log_output_dir from lerobot.common.policies.factory import make_policy from lerobot.common.policies.policy_protocol import Policy -from lerobot.common.policies.utils import get_device_from_parameters +from lerobot.common.policies.utils import get_device_from_parameters, get_pretrained_policy_path from lerobot.common.utils.io_utils import write_video from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed @@ -501,29 +498,6 @@ def main( logging.info("End of eval") -def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None): - try: - pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision)) - except (HFValidationError, RepositoryNotFoundError) as e: - if isinstance(e, HFValidationError): - error_message = ( - "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID." - ) - else: - error_message = ( - "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub." - ) - - logging.warning(f"{error_message} Treating it as a local directory.") - pretrained_policy_path = Path(pretrained_policy_name_or_path) - if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists(): - raise ValueError( - "The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub " - "repo ID, nor is it an existing local directory." - ) - return pretrained_policy_path - - if __name__ == "__main__": init_logging() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index d8fdfc1f0..0231ef776 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -120,8 +120,7 @@ def update_policy( policy.train() with torch.autocast(device_type=device.type) if use_amp else nullcontext(): output_dict = policy.forward(batch) - # TODO(rcadene): policy.unnormalize_outputs(out_dict) - loss = output_dict["loss"] + loss = output_dict["loss"].mean() grad_scaler.scale(loss).backward() # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**. @@ -150,14 +149,12 @@ def update_policy( policy.update() info = { - "loss": loss.item(), "grad_norm": float(grad_norm), "lr": optimizer.param_groups[0]["lr"], "update_s": time.perf_counter() - start_time, - **{k: v for k, v in output_dict.items() if k != "loss"}, + **{k: v.detach().mean().item() for k, v in output_dict.items() if "loss" in k}, + **{k: v for k, v in output_dict.items() if "loss" not in k}, } - info.update({k: v for k, v in output_dict.items() if k not in info}) - return info diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py index 5236b7ae5..c69472618 100644 --- a/tests/scripts/save_policy_to_safetensors.py +++ b/tests/scripts/save_policy_to_safetensors.py @@ -13,6 +13,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Save the policy tests artifacts. + +Note: Run on the cluster + +Example of usage: +```bash +DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py +``` +""" + +import platform import shutil from pathlib import Path @@ -54,7 +66,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides): output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)} loss = output_dict["loss"] - loss.backward() + loss.mean().backward() grad_stats = {} for key, param in policy.named_parameters(): if param.requires_grad: @@ -96,10 +108,21 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override print(f"Overwrite existing safetensors in '{env_policy_dir}':") print(f" - Validate with: `git add {env_policy_dir}`") print(f" - Revert with: `git checkout -- {env_policy_dir}`") + + output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides) + + from safetensors.torch import load_file + + if (env_policy_dir / "output_dict.safetensors").exists(): + prev_loss = load_file(env_policy_dir / "output_dict.safetensors")["loss"] + print(f"Previous loss={prev_loss}") + print(f"New loss={output_dict['loss'].mean()}") + print() + + if env_policy_dir.exists(): shutil.rmtree(env_policy_dir) env_policy_dir.mkdir(parents=True, exist_ok=True) - output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides) save_file(output_dict, env_policy_dir / "output_dict.safetensors") save_file(grad_stats, env_policy_dir / "grad_stats.safetensors") save_file(param_stats, env_policy_dir / "param_stats.safetensors") @@ -107,27 +130,32 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override if __name__ == "__main__": + if platform.machine() != "x86_64": + raise OSError("Generate policy artifacts on x86_64 machine since it is used for the unit tests. ") + env_policies = [ - # ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"), - # ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"), - # ( - # "pusht", - # "diffusion", - # [ - # "policy.n_action_steps=8", - # "policy.num_inference_steps=10", - # "policy.down_dims=[128, 256, 512]", - # ], - # "", - # ), - # ("aloha", "act", ["policy.n_action_steps=10"], ""), - # ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"), - # ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""), - # ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""), + ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"), + ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"), + ( + "pusht", + "diffusion", + [ + "policy.n_action_steps=8", + "policy.num_inference_steps=10", + "policy.down_dims=[128, 256, 512]", + ], + "", + ), + ("aloha", "act", ["policy.n_action_steps=10"], ""), + ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"), + ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""), + ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""), ] if len(env_policies) == 0: raise RuntimeError("No policies were provided!") for env, policy, extra_overrides, file_name_extra in env_policies: + print(f"env={env} policy={policy} extra_overrides={extra_overrides}") save_policy_to_safetensors( "tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra ) + print() diff --git a/tests/test_policies.py b/tests/test_policies.py index d90f00716..1209bd124 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -147,10 +147,11 @@ def test_policy(env_name, policy_name, extra_overrides): # Check that we run select_actions and get the appropriate output. env = make_env(cfg, n_envs=2) + batch_size = 2 dataloader = torch.utils.data.DataLoader( dataset, num_workers=0, - batch_size=2, + batch_size=batch_size, shuffle=True, pin_memory=DEVICE != "cpu", drop_last=True, @@ -164,12 +165,19 @@ def test_policy(env_name, policy_name, extra_overrides): # Test updating the policy (and test that it does not mutate the batch) batch_ = deepcopy(batch) - policy.forward(batch) + out = policy.forward(batch) assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass." assert all( torch.equal(batch[k], batch_[k]) for k in batch ), "Batch values are not the same after a forward pass." + # Test loss can be visualized using visualize_dataset_html.py + for key in out: + if "loss" in key: + assert ( + out[key].ndim == 1 and out[key].shape[0] == batch_size + ), f"1 loss value per item in the batch is expected, but {out[key].shape} provided instead." + # reset the policy and environment policy.reset() observation, _ = env.reset(seed=cfg.seed) @@ -234,6 +242,7 @@ def test_policy_defaults(policy_name: str): [ ("xarm", "tdmpc"), ("pusht", "diffusion"), + ("pusht", "vqbet"), ("aloha", "act"), ], ) @@ -250,7 +259,7 @@ def test_yaml_matches_dataclass(env_name: str, policy_name: str): def test_save_and_load_pretrained(policy_name: str): policy_cls, _ = get_policy_and_config_classes(policy_name) policy: Policy = policy_cls() - save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}" + save_dir = f"/tmp/test_save_and_load_pretrained_{policy_cls.__name__}" policy.save_pretrained(save_dir) policy_ = policy_cls.from_pretrained(save_dir) assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True)) @@ -365,6 +374,7 @@ def test_normalize(insert_temporal_dim): ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], "", ), + ("pusht", "vqbet", "[]", ""), ("aloha", "act", ["policy.n_action_steps=10"], ""), ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"), ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""), @@ -461,7 +471,3 @@ def test_act_temporal_ensembler(): assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max")) # Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error. assert torch.allclose(online_avg, offline_avg, atol=1e-4) - - -if __name__ == "__main__": - test_act_temporal_ensembler() diff --git a/tests/test_visualize_dataset.py b/tests/test_visualize_dataset.py index 075e2b372..422175ef0 100644 --- a/tests/test_visualize_dataset.py +++ b/tests/test_visualize_dataset.py @@ -25,13 +25,13 @@ from lerobot.scripts.visualize_dataset import visualize_dataset ["lerobot/pusht"], ) @pytest.mark.parametrize("root", [Path(__file__).parent / "data"]) -def test_visualize_local_dataset(tmpdir, repo_id, root): +def test_visualize_dataset_root(tmpdir, repo_id, root): rrd_path = visualize_dataset( repo_id, + root=root, episode_index=0, batch_size=32, save=True, output_dir=tmpdir, - root=root, ) assert rrd_path.exists()