Compare commits
4 Commits
my-fix-bas ... user/rcade
| Author | SHA1 | Date |
|---|---|---|
| | b502a82005 | |
| | 12a1b8f55a | |
| | 205e0c9dde | |
| | 5b74205e16 | |
```diff
@@ -233,9 +233,6 @@ class Logger:
         if self._wandb is not None:
             for k, v in d.items():
                 if not isinstance(v, (int, float, str)):
-                    logging.warning(
-                        f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
-                    )
                     continue
                 self._wandb.log({f"{mode}/{k}": v}, step=step)
 
```
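For context, the hunk above touches `Logger.log_dict`, which forwards scalar metrics to Weights & Biases and skips values it cannot handle. Below is a minimal, hypothetical sketch of that pattern, not the repository's actual class (the `MiniLogger` name, constructor, and `wandb_run` argument are illustrative):

```python
import logging


class MiniLogger:
    """Hypothetical sketch of the wandb-forwarding pattern shown in the hunk above."""

    def __init__(self, wandb_run=None):
        # `wandb_run` is an already-initialised wandb run, or None to disable logging entirely.
        self._wandb = wandb_run

    def log_dict(self, d: dict, step: int, mode: str = "train"):
        if self._wandb is None:
            return
        for k, v in d.items():
            # Only plain scalars and strings are forwarded; per the hunk, other
            # values are now skipped silently instead of emitting a warning.
            if not isinstance(v, (int, float, str)):
                continue
            self._wandb.log({f"{mode}/{k}": v}, step=step)
```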
```diff
@@ -139,25 +139,26 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        bsize = actions_hat.shape[0]
+        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
+        l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
+        l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
 
-        loss_dict = {"l1_loss": l1_loss.item()}
+        out_dict = {}
+        out_dict["l1_loss"] = l1_loss
+
         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
             # KL-divergence per batch element, then take the mean over the batch.
             # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-            mean_kld = (
-                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-            )
-            loss_dict["kld_loss"] = mean_kld.item()
-            loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
+            kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
+            out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
         else:
-            loss_dict["loss"] = l1_loss
+            out_dict["loss"] = l1_loss
 
-        return loss_dict
+        out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
+        return out_dict
 
 
 class ACT(nn.Module):
```
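The substantive change in this hunk is that `ACTPolicy.forward` now returns unreduced, per-batch-element losses in `out_dict` (tensors of shape `(batch,)`) instead of Python scalars in `loss_dict`. A small self-contained sketch of the per-sample masked L1 and KL terms with toy tensors (all shapes and values below are illustrative, not taken from the repository):

```python
import torch
import torch.nn.functional as F

# Toy shapes: batch of 2, chunk of 3 actions, action dim 4.
actions = torch.randn(2, 3, 4)       # ground-truth action chunk
actions_hat = torch.randn(2, 3, 4)   # predicted action chunk
action_is_pad = torch.tensor([[False, False, True],
                              [False, False, False]])  # padded timesteps to ignore

bsize = actions_hat.shape[0]
l1 = F.l1_loss(actions, actions_hat, reduction="none")  # element-wise, shape (2, 3, 4)
l1 = l1 * ~action_is_pad.unsqueeze(-1)                  # zero out padded timesteps
l1_per_sample = l1.view(bsize, -1).mean(dim=1)          # shape (2,): one loss per batch element

# VAE head outputs (illustrative): posterior mean and log(sigma^2) per sample.
mu_hat = torch.randn(2, 32)
log_sigma_x2_hat = torch.randn(2, 32)
# KL(q || N(0, I)) summed over latent dims, again one value per batch element.
kld_per_sample = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - log_sigma_x2_hat.exp())).sum(-1)

loss_per_sample = l1_per_sample + 10.0 * kld_per_sample  # 10.0 stands in for config.kl_weight
print(loss_per_sample.shape)   # torch.Size([2])
print(loss_per_sample.mean())  # the scalar that the training loop now takes via .mean()
```

Because every sample contributes the same number of action elements here, `loss_per_sample.mean()` reproduces the scalar the old code computed directly with a single `.mean()`.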
```diff
@@ -25,7 +25,7 @@ training:
   online_steps_between_rollouts: 1
 
   delta_timestamps:
-    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+    action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
 
 eval:
   n_episodes: 50
```

```diff
@@ -51,7 +51,7 @@ training:
   online_steps_between_rollouts: 1
 
   delta_timestamps:
-    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+    action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
 
 eval:
   n_episodes: 50
```

```diff
@@ -49,7 +49,7 @@ training:
   online_steps_between_rollouts: 1
 
   delta_timestamps:
-    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+    action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
 
 eval:
   n_episodes: 50
```
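All three config hunks make the same one-line change: the supervised action chunk now starts one frame after the current observation rather than at the current frame. A quick illustration of the two interpolated timestamp lists (the `fps` and `chunk_size` values below are chosen for readability, not taken from the configs):

```python
fps = 50         # assumed frame rate, for illustration only
chunk_size = 5   # kept small so the lists are easy to read

old_timestamps = [i / fps for i in range(chunk_size)]          # [0.0, 0.02, 0.04, 0.06, 0.08]
new_timestamps = [i / fps for i in range(1, chunk_size + 1)]   # [0.02, 0.04, 0.06, 0.08, 0.1]

print(old_timestamps)  # includes the current frame (t = 0)
print(new_timestamps)  # shifted one frame into the future
```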
```diff
@@ -108,7 +108,7 @@ def update_policy(
     with torch.autocast(device_type=device.type) if use_amp else nullcontext():
         output_dict = policy.forward(batch)
         # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-        loss = output_dict["loss"]
+        loss = output_dict["loss"].mean()
     grad_scaler.scale(loss).backward()
 
     # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
```
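Since the policy now returns a per-sample loss tensor, the training step reduces it with `.mean()` before backpropagation. Below is a hedged sketch of one such AMP training step; the function name, argument list, and defaults are illustrative, and only the scale → unscale → clip → step pattern follows the hunk above:

```python
import torch
from contextlib import nullcontext


def update_step(policy, batch, optimizer, grad_scaler, device, use_amp=True, grad_clip_norm=10.0):
    """Illustrative single optimisation step in the style of the hunk above."""
    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
        output_dict = policy.forward(batch)
        # The loss is now per batch element, so reduce it to a scalar before backprop.
        loss = output_dict["loss"].mean()

    grad_scaler.scale(loss).backward()
    # Unscale gradients in place before clipping so the threshold applies to the true gradients.
    grad_scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(policy.parameters(), grad_clip_norm)
    grad_scaler.step(optimizer)  # skips the optimizer step if non-finite gradients were found
    grad_scaler.update()
    optimizer.zero_grad()
    return loss.item()
```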
```diff
@@ -27,10 +27,9 @@ from lerobot.scripts.visualize_dataset import visualize_dataset
 def test_visualize_dataset(tmpdir, repo_id):
     rrd_path = visualize_dataset(
         repo_id,
-        episode_index=0,
-        batch_size=32,
-        save=True,
+        episode_indices=[0],
         output_dir=tmpdir,
+        serve=False,
     )
     assert rrd_path.exists()
 
```
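Outside of pytest, a call with the updated signature would look roughly as follows; the keyword arguments simply mirror the new test above, while `"lerobot/pusht"` and the output directory are placeholder example values:

```python
from pathlib import Path

from lerobot.scripts.visualize_dataset import visualize_dataset

# Example values only: any dataset repo id and writable directory should work the same way.
rrd_path = visualize_dataset(
    "lerobot/pusht",
    episode_indices=[0],            # visualise just the first episode
    output_dir=Path("./viz_out"),
    serve=False,                    # write the .rrd file without starting a viewer, per the test's usage
)
assert rrd_path.exists()
```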