WIP
This commit is contained in:
@@ -44,7 +44,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
root: Path | None = DATA_DIR,
|
root: Path | None = None,
|
||||||
split: str = "train",
|
split: str = "train",
|
||||||
image_transforms: Callable | None = None,
|
image_transforms: Callable | None = None,
|
||||||
delta_timestamps: dict[list[float]] | None = None,
|
delta_timestamps: dict[list[float]] | None = None,
|
||||||
@@ -53,22 +53,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self.root = root
|
self.root = root
|
||||||
|
if self.root is None and DATA_DIR is not None:
|
||||||
|
self.root = DATA_DIR
|
||||||
self.split = split
|
self.split = split
|
||||||
self.image_transforms = image_transforms
|
self.image_transforms = image_transforms
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = delta_timestamps
|
||||||
# load data from hub or locally when root is provided
|
# load data from hub or locally when root is provided
|
||||||
# TODO(rcadene, aliberts): implement faster transfer
|
# TODO(rcadene, aliberts): implement faster transfer
|
||||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||||
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
|
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, self.root, split)
|
||||||
if split == "train":
|
if split == "train":
|
||||||
self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
|
self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
|
||||||
else:
|
else:
|
||||||
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
|
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
|
||||||
self.hf_dataset = reset_episode_index(self.hf_dataset)
|
self.hf_dataset = reset_episode_index(self.hf_dataset)
|
||||||
self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
|
self.stats = load_stats(repo_id, CODEBASE_VERSION, self.root)
|
||||||
self.info = load_info(repo_id, CODEBASE_VERSION, root)
|
self.info = load_info(repo_id, CODEBASE_VERSION, self.root)
|
||||||
if self.video:
|
if self.video:
|
||||||
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
|
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, self.root)
|
||||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -233,9 +233,6 @@ class Logger:
|
|||||||
if self._wandb is not None:
|
if self._wandb is not None:
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
if not isinstance(v, (int, float, str)):
|
if not isinstance(v, (int, float, str)):
|
||||||
logging.warning(
|
|
||||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||||
|
|
||||||
|
|||||||
@@ -134,25 +134,26 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||||||
batch = self.normalize_targets(batch)
|
batch = self.normalize_targets(batch)
|
||||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||||
|
|
||||||
l1_loss = (
|
bsize = actions_hat.shape[0]
|
||||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
|
||||||
).mean()
|
l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
|
||||||
|
l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
|
||||||
|
|
||||||
|
out_dict = {}
|
||||||
|
out_dict["l1_loss"] = l1_loss
|
||||||
|
|
||||||
loss_dict = {"l1_loss": l1_loss.item()}
|
|
||||||
if self.config.use_vae:
|
if self.config.use_vae:
|
||||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||||
# each dimension independently, we sum over the latent dimension to get the total
|
# each dimension independently, we sum over the latent dimension to get the total
|
||||||
# KL-divergence per batch element, then take the mean over the batch.
|
# KL-divergence per batch element, then take the mean over the batch.
|
||||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||||
mean_kld = (
|
kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
|
||||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
|
||||||
)
|
|
||||||
loss_dict["kld_loss"] = mean_kld.item()
|
|
||||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
|
||||||
else:
|
else:
|
||||||
loss_dict["loss"] = l1_loss
|
out_dict["loss"] = l1_loss
|
||||||
|
|
||||||
return loss_dict
|
out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
|
||||||
|
return out_dict
|
||||||
|
|
||||||
|
|
||||||
class ACTTemporalEnsembler:
|
class ACTTemporalEnsembler:
|
||||||
|
|||||||
@@ -341,7 +341,11 @@ class DiffusionModel(nn.Module):
|
|||||||
in_episode_bound = ~batch["action_is_pad"]
|
in_episode_bound = ~batch["action_is_pad"]
|
||||||
loss = loss * in_episode_bound.unsqueeze(-1)
|
loss = loss * in_episode_bound.unsqueeze(-1)
|
||||||
|
|
||||||
return loss.mean()
|
# Compute average per item in the batch
|
||||||
|
bsize = loss.shape[0]
|
||||||
|
loss = loss.reshape(bsize, -1).mean(1)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
class SpatialSoftmax(nn.Module):
|
class SpatialSoftmax(nn.Module):
|
||||||
|
|||||||
@@ -396,51 +396,39 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||||||
# Compute consistency loss as MSE loss between latents predicted from the rollout and latents
|
# Compute consistency loss as MSE loss between latents predicted from the rollout and latents
|
||||||
# predicted from the (target model's) observation encoder.
|
# predicted from the (target model's) observation encoder.
|
||||||
consistency_loss = (
|
consistency_loss = (
|
||||||
(
|
temporal_loss_coeffs
|
||||||
temporal_loss_coeffs
|
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
|
||||||
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
|
# `z_preds` depends on the current observation and the actions.
|
||||||
# `z_preds` depends on the current observation and the actions.
|
* ~batch["observation.state_is_pad"][0]
|
||||||
* ~batch["observation.state_is_pad"][0]
|
* ~batch["action_is_pad"]
|
||||||
* ~batch["action_is_pad"]
|
# `z_targets` depends on the next observation.
|
||||||
# `z_targets` depends on the next observation.
|
* ~batch["observation.state_is_pad"][1:]
|
||||||
* ~batch["observation.state_is_pad"][1:]
|
).sum(0)
|
||||||
)
|
|
||||||
.sum(0)
|
|
||||||
.mean()
|
|
||||||
)
|
|
||||||
# Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
|
# Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
|
||||||
# rewards.
|
# rewards.
|
||||||
reward_loss = (
|
reward_loss = (
|
||||||
(
|
temporal_loss_coeffs
|
||||||
temporal_loss_coeffs
|
* F.mse_loss(reward_preds, reward, reduction="none")
|
||||||
* F.mse_loss(reward_preds, reward, reduction="none")
|
* ~batch["next.reward_is_pad"]
|
||||||
* ~batch["next.reward_is_pad"]
|
# `reward_preds` depends on the current observation and the actions.
|
||||||
# `reward_preds` depends on the current observation and the actions.
|
* ~batch["observation.state_is_pad"][0]
|
||||||
* ~batch["observation.state_is_pad"][0]
|
* ~batch["action_is_pad"]
|
||||||
* ~batch["action_is_pad"]
|
).sum(0)
|
||||||
)
|
|
||||||
.sum(0)
|
|
||||||
.mean()
|
|
||||||
)
|
|
||||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||||
q_value_loss = (
|
q_value_loss = (
|
||||||
(
|
temporal_loss_coeffs
|
||||||
temporal_loss_coeffs
|
* F.mse_loss(
|
||||||
* F.mse_loss(
|
q_preds_ensemble,
|
||||||
q_preds_ensemble,
|
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
|
||||||
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
|
reduction="none",
|
||||||
reduction="none",
|
).sum(0) # sum over ensemble
|
||||||
).sum(0) # sum over ensemble
|
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
* ~batch["observation.state_is_pad"][0]
|
||||||
* ~batch["observation.state_is_pad"][0]
|
* ~batch["action_is_pad"]
|
||||||
* ~batch["action_is_pad"]
|
# q_targets depends on the reward and the next observations.
|
||||||
# q_targets depends on the reward and the next observations.
|
* ~batch["next.reward_is_pad"]
|
||||||
* ~batch["next.reward_is_pad"]
|
* ~batch["observation.state_is_pad"][1:]
|
||||||
* ~batch["observation.state_is_pad"][1:]
|
).sum(0)
|
||||||
)
|
|
||||||
.sum(0)
|
|
||||||
.mean()
|
|
||||||
)
|
|
||||||
# Compute state value loss as in eqn 3 of FOWM.
|
# Compute state value loss as in eqn 3 of FOWM.
|
||||||
diff = v_targets - v_preds
|
diff = v_targets - v_preds
|
||||||
# Expectile loss penalizes:
|
# Expectile loss penalizes:
|
||||||
@@ -450,16 +438,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||||||
diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight)
|
diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight)
|
||||||
) * (diff**2)
|
) * (diff**2)
|
||||||
v_value_loss = (
|
v_value_loss = (
|
||||||
(
|
temporal_loss_coeffs
|
||||||
temporal_loss_coeffs
|
* raw_v_value_loss
|
||||||
* raw_v_value_loss
|
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
|
||||||
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
|
* ~batch["observation.state_is_pad"][0]
|
||||||
* ~batch["observation.state_is_pad"][0]
|
* ~batch["action_is_pad"]
|
||||||
* ~batch["action_is_pad"]
|
).sum(0)
|
||||||
)
|
|
||||||
.sum(0)
|
|
||||||
.mean()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
|
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
|
||||||
# We won't need these gradients again so detach.
|
# We won't need these gradients again so detach.
|
||||||
@@ -492,7 +476,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||||||
# `action_preds` depends on the first observation and the actions.
|
# `action_preds` depends on the first observation and the actions.
|
||||||
* ~batch["observation.state_is_pad"][0]
|
* ~batch["observation.state_is_pad"][0]
|
||||||
* ~batch["action_is_pad"]
|
* ~batch["action_is_pad"]
|
||||||
).mean()
|
).sum(0)
|
||||||
|
|
||||||
loss = (
|
loss = (
|
||||||
self.config.consistency_coeff * consistency_loss
|
self.config.consistency_coeff * consistency_loss
|
||||||
@@ -504,13 +488,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||||||
|
|
||||||
info.update(
|
info.update(
|
||||||
{
|
{
|
||||||
"consistency_loss": consistency_loss.item(),
|
"consistency_loss": consistency_loss,
|
||||||
"reward_loss": reward_loss.item(),
|
"reward_loss": reward_loss,
|
||||||
"Q_value_loss": q_value_loss.item(),
|
"Q_value_loss": q_value_loss,
|
||||||
"V_value_loss": v_value_loss.item(),
|
"V_value_loss": v_value_loss,
|
||||||
"pi_loss": pi_loss.item(),
|
"pi_loss": pi_loss,
|
||||||
"loss": loss,
|
"loss": loss,
|
||||||
"sum_loss": loss.item() * self.config.horizon,
|
"sum_loss": loss * self.config.horizon,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,13 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||||
|
from huggingface_hub.utils._validators import HFValidationError
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
@@ -47,3 +53,26 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
|
|||||||
Note: assumes that all parameters have the same dtype.
|
Note: assumes that all parameters have the same dtype.
|
||||||
"""
|
"""
|
||||||
return next(iter(module.parameters())).dtype
|
return next(iter(module.parameters())).dtype
|
||||||
|
|
||||||
|
|
||||||
|
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
|
||||||
|
try:
|
||||||
|
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
|
||||||
|
except (HFValidationError, RepositoryNotFoundError) as e:
|
||||||
|
if isinstance(e, HFValidationError):
|
||||||
|
error_message = (
|
||||||
|
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
error_message = (
|
||||||
|
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.warning(f"{error_message} Treating it as a local directory.")
|
||||||
|
pretrained_policy_path = Path(pretrained_policy_name_or_path)
|
||||||
|
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
|
||||||
|
raise ValueError(
|
||||||
|
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
|
||||||
|
"repo ID, nor is it an existing local directory."
|
||||||
|
)
|
||||||
|
return pretrained_policy_path
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ training:
|
|||||||
online_steps_between_rollouts: 1
|
online_steps_between_rollouts: 1
|
||||||
|
|
||||||
delta_timestamps:
|
delta_timestamps:
|
||||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
|
||||||
|
|
||||||
eval:
|
eval:
|
||||||
n_episodes: 50
|
n_episodes: 50
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ training:
|
|||||||
online_steps_between_rollouts: 1
|
online_steps_between_rollouts: 1
|
||||||
|
|
||||||
delta_timestamps:
|
delta_timestamps:
|
||||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
|
||||||
|
|
||||||
eval:
|
eval:
|
||||||
n_episodes: 50
|
n_episodes: 50
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ training:
|
|||||||
online_steps_between_rollouts: 1
|
online_steps_between_rollouts: 1
|
||||||
|
|
||||||
delta_timestamps:
|
delta_timestamps:
|
||||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
|
||||||
|
|
||||||
eval:
|
eval:
|
||||||
n_episodes: 50
|
n_episodes: 50
|
||||||
|
|||||||
@@ -56,9 +56,6 @@ import einops
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
|
||||||
from huggingface_hub.utils._validators import HFValidationError
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
@@ -68,7 +65,7 @@ from lerobot.common.envs.utils import preprocess_observation
|
|||||||
from lerobot.common.logger import log_output_dir
|
from lerobot.common.logger import log_output_dir
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
from lerobot.common.policies.utils import get_device_from_parameters
|
from lerobot.common.policies.utils import get_device_from_parameters, get_pretrained_policy_path
|
||||||
from lerobot.common.utils.io_utils import write_video
|
from lerobot.common.utils.io_utils import write_video
|
||||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||||
|
|
||||||
@@ -501,29 +498,6 @@ def main(
|
|||||||
logging.info("End of eval")
|
logging.info("End of eval")
|
||||||
|
|
||||||
|
|
||||||
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
|
|
||||||
try:
|
|
||||||
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
|
|
||||||
except (HFValidationError, RepositoryNotFoundError) as e:
|
|
||||||
if isinstance(e, HFValidationError):
|
|
||||||
error_message = (
|
|
||||||
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
error_message = (
|
|
||||||
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.warning(f"{error_message} Treating it as a local directory.")
|
|
||||||
pretrained_policy_path = Path(pretrained_policy_name_or_path)
|
|
||||||
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
|
|
||||||
raise ValueError(
|
|
||||||
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
|
|
||||||
"repo ID, nor is it an existing local directory."
|
|
||||||
)
|
|
||||||
return pretrained_policy_path
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
init_logging()
|
init_logging()
|
||||||
|
|
||||||
|
|||||||
@@ -120,8 +120,7 @@ def update_policy(
|
|||||||
policy.train()
|
policy.train()
|
||||||
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||||
output_dict = policy.forward(batch)
|
output_dict = policy.forward(batch)
|
||||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
loss = output_dict["loss"].mean()
|
||||||
loss = output_dict["loss"]
|
|
||||||
grad_scaler.scale(loss).backward()
|
grad_scaler.scale(loss).backward()
|
||||||
|
|
||||||
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
|
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
|
||||||
@@ -150,14 +149,12 @@ def update_policy(
|
|||||||
policy.update()
|
policy.update()
|
||||||
|
|
||||||
info = {
|
info = {
|
||||||
"loss": loss.item(),
|
|
||||||
"grad_norm": float(grad_norm),
|
"grad_norm": float(grad_norm),
|
||||||
"lr": optimizer.param_groups[0]["lr"],
|
"lr": optimizer.param_groups[0]["lr"],
|
||||||
"update_s": time.perf_counter() - start_time,
|
"update_s": time.perf_counter() - start_time,
|
||||||
**{k: v for k, v in output_dict.items() if k != "loss"},
|
**{k: v.detach().mean().item() for k, v in output_dict.items() if "loss" in k},
|
||||||
|
**{k: v for k, v in output_dict.items() if "loss" not in k},
|
||||||
}
|
}
|
||||||
info.update({k: v for k, v in output_dict.items() if k not in info})
|
|
||||||
|
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,18 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Save the policy tests artifacts.
|
||||||
|
|
||||||
|
Note: Run on the cluster
|
||||||
|
|
||||||
|
Example of usage:
|
||||||
|
```bash
|
||||||
|
DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
import platform
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -54,7 +66,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
|||||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||||
loss = output_dict["loss"]
|
loss = output_dict["loss"]
|
||||||
|
|
||||||
loss.backward()
|
loss.mean().backward()
|
||||||
grad_stats = {}
|
grad_stats = {}
|
||||||
for key, param in policy.named_parameters():
|
for key, param in policy.named_parameters():
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
@@ -96,10 +108,21 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
|||||||
print(f"Overwrite existing safetensors in '{env_policy_dir}':")
|
print(f"Overwrite existing safetensors in '{env_policy_dir}':")
|
||||||
print(f" - Validate with: `git add {env_policy_dir}`")
|
print(f" - Validate with: `git add {env_policy_dir}`")
|
||||||
print(f" - Revert with: `git checkout -- {env_policy_dir}`")
|
print(f" - Revert with: `git checkout -- {env_policy_dir}`")
|
||||||
|
|
||||||
|
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
||||||
|
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
if (env_policy_dir / "output_dict.safetensors").exists():
|
||||||
|
prev_loss = load_file(env_policy_dir / "output_dict.safetensors")["loss"]
|
||||||
|
print(f"Previous loss={prev_loss}")
|
||||||
|
print(f"New loss={output_dict['loss'].mean()}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
if env_policy_dir.exists():
|
||||||
shutil.rmtree(env_policy_dir)
|
shutil.rmtree(env_policy_dir)
|
||||||
|
|
||||||
env_policy_dir.mkdir(parents=True, exist_ok=True)
|
env_policy_dir.mkdir(parents=True, exist_ok=True)
|
||||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
|
||||||
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
|
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
|
||||||
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
|
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
|
||||||
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
|
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
|
||||||
@@ -107,27 +130,32 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
if platform.machine() != "x86_64":
|
||||||
|
raise OSError("Generate policy artifacts on x86_64 machine since it is used for the unit tests. ")
|
||||||
|
|
||||||
env_policies = [
|
env_policies = [
|
||||||
# ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
|
("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
|
||||||
# ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
|
("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
|
||||||
# (
|
(
|
||||||
# "pusht",
|
"pusht",
|
||||||
# "diffusion",
|
"diffusion",
|
||||||
# [
|
[
|
||||||
# "policy.n_action_steps=8",
|
"policy.n_action_steps=8",
|
||||||
# "policy.num_inference_steps=10",
|
"policy.num_inference_steps=10",
|
||||||
# "policy.down_dims=[128, 256, 512]",
|
"policy.down_dims=[128, 256, 512]",
|
||||||
# ],
|
],
|
||||||
# "",
|
"",
|
||||||
# ),
|
),
|
||||||
# ("aloha", "act", ["policy.n_action_steps=10"], ""),
|
("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||||
# ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||||
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
|
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
|
||||||
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
|
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
|
||||||
]
|
]
|
||||||
if len(env_policies) == 0:
|
if len(env_policies) == 0:
|
||||||
raise RuntimeError("No policies were provided!")
|
raise RuntimeError("No policies were provided!")
|
||||||
for env, policy, extra_overrides, file_name_extra in env_policies:
|
for env, policy, extra_overrides, file_name_extra in env_policies:
|
||||||
|
print(f"env={env} policy={policy} extra_overrides={extra_overrides}")
|
||||||
save_policy_to_safetensors(
|
save_policy_to_safetensors(
|
||||||
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
|
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
|
||||||
)
|
)
|
||||||
|
print()
|
||||||
|
|||||||
@@ -147,10 +147,11 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||||||
# Check that we run select_actions and get the appropriate output.
|
# Check that we run select_actions and get the appropriate output.
|
||||||
env = make_env(cfg, n_envs=2)
|
env = make_env(cfg, n_envs=2)
|
||||||
|
|
||||||
|
batch_size = 2
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
num_workers=0,
|
num_workers=0,
|
||||||
batch_size=2,
|
batch_size=batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
pin_memory=DEVICE != "cpu",
|
pin_memory=DEVICE != "cpu",
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
@@ -164,12 +165,19 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||||||
|
|
||||||
# Test updating the policy (and test that it does not mutate the batch)
|
# Test updating the policy (and test that it does not mutate the batch)
|
||||||
batch_ = deepcopy(batch)
|
batch_ = deepcopy(batch)
|
||||||
policy.forward(batch)
|
out = policy.forward(batch)
|
||||||
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
|
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
|
||||||
assert all(
|
assert all(
|
||||||
torch.equal(batch[k], batch_[k]) for k in batch
|
torch.equal(batch[k], batch_[k]) for k in batch
|
||||||
), "Batch values are not the same after a forward pass."
|
), "Batch values are not the same after a forward pass."
|
||||||
|
|
||||||
|
# Test loss can be visualized using visualize_dataset_html.py
|
||||||
|
for key in out:
|
||||||
|
if "loss" in key:
|
||||||
|
assert (
|
||||||
|
out[key].ndim == 1 and out[key].shape[0] == batch_size
|
||||||
|
), f"1 loss value per item in the batch is expected, but {out[key].shape} provided instead."
|
||||||
|
|
||||||
# reset the policy and environment
|
# reset the policy and environment
|
||||||
policy.reset()
|
policy.reset()
|
||||||
observation, _ = env.reset(seed=cfg.seed)
|
observation, _ = env.reset(seed=cfg.seed)
|
||||||
@@ -234,6 +242,7 @@ def test_policy_defaults(policy_name: str):
|
|||||||
[
|
[
|
||||||
("xarm", "tdmpc"),
|
("xarm", "tdmpc"),
|
||||||
("pusht", "diffusion"),
|
("pusht", "diffusion"),
|
||||||
|
("pusht", "vqbet"),
|
||||||
("aloha", "act"),
|
("aloha", "act"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -250,7 +259,7 @@ def test_yaml_matches_dataclass(env_name: str, policy_name: str):
|
|||||||
def test_save_and_load_pretrained(policy_name: str):
|
def test_save_and_load_pretrained(policy_name: str):
|
||||||
policy_cls, _ = get_policy_and_config_classes(policy_name)
|
policy_cls, _ = get_policy_and_config_classes(policy_name)
|
||||||
policy: Policy = policy_cls()
|
policy: Policy = policy_cls()
|
||||||
save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
|
save_dir = f"/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||||
policy.save_pretrained(save_dir)
|
policy.save_pretrained(save_dir)
|
||||||
policy_ = policy_cls.from_pretrained(save_dir)
|
policy_ = policy_cls.from_pretrained(save_dir)
|
||||||
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
|
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
|
||||||
@@ -365,6 +374,7 @@ def test_normalize(insert_temporal_dim):
|
|||||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||||
"",
|
"",
|
||||||
),
|
),
|
||||||
|
("pusht", "vqbet", "[]", ""),
|
||||||
("aloha", "act", ["policy.n_action_steps=10"], ""),
|
("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||||
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
|
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
|
||||||
@@ -461,7 +471,3 @@ def test_act_temporal_ensembler():
|
|||||||
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
|
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
|
||||||
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
|
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
|
||||||
assert torch.allclose(online_avg, offline_avg, atol=1e-4)
|
assert torch.allclose(online_avg, offline_avg, atol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_act_temporal_ensembler()
|
|
||||||
|
|||||||
@@ -25,13 +25,13 @@ from lerobot.scripts.visualize_dataset import visualize_dataset
|
|||||||
["lerobot/pusht"],
|
["lerobot/pusht"],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
|
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
|
||||||
def test_visualize_local_dataset(tmpdir, repo_id, root):
|
def test_visualize_dataset_root(tmpdir, repo_id, root):
|
||||||
rrd_path = visualize_dataset(
|
rrd_path = visualize_dataset(
|
||||||
repo_id,
|
repo_id,
|
||||||
|
root=root,
|
||||||
episode_index=0,
|
episode_index=0,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
save=True,
|
save=True,
|
||||||
output_dir=tmpdir,
|
output_dir=tmpdir,
|
||||||
root=root,
|
|
||||||
)
|
)
|
||||||
assert rrd_path.exists()
|
assert rrd_path.exists()
|
||||||
|
|||||||
Reference in New Issue
Block a user