Added normalization schemes and style checks

This commit is contained in:
Michel Aractingi
2024-12-29 12:51:21 +00:00
parent 08ec971086
commit dc54d357ca
10 changed files with 150 additions and 156 deletions

View File

@@ -24,7 +24,7 @@ python lerobot/scripts/eval_on_robot.py \
```
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
for running training on the real robot.
for running training on the real robot.
"""
import argparse
@@ -47,7 +47,7 @@ from lerobot.common.utils.utils import (
def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
"""Run a batched policy rollout on the real robot.
"""Run a batched policy rollout on the real robot.
The return dictionary contains:
"robot": A dictionary of (batch, sequence + 1, *) tensors mapped to observation
@@ -64,7 +64,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
extraneous elements from the sequences above.
Args:
robot: The robot class that defines the interface with the real robot.
robot: The robot class that defines the interface with the real robot.
policy: The policy. Must be a PyTorch nn module.
Returns:
@@ -77,7 +77,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
listener, events = init_keyboard_listener()
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
# policy.reset()
# policy.reset()
# Get observation from real robot
observation = robot.capture_observation()

View File

@@ -95,12 +95,14 @@ def make_optimizer_and_scheduler(cfg, policy):
lr_scheduler = None
elif policy.name == "sac":
optimizer = torch.optim.Adam([
{'params': policy.actor.parameters(), 'lr': policy.config.actor_lr},
{'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr},
{'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr},
])
lr_scheduler = None
optimizer = torch.optim.Adam(
[
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
]
)
lr_scheduler = None
elif cfg.policy.name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler

View File

@@ -22,7 +22,6 @@ from pprint import pformat
import hydra
import torch
import torch.nn as nn
import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
@@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm
import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger