forked from tangger/lerobot
Compare commits
13 Commits
main
...
temp_branc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
def42ff487 | ||
|
|
c9af8e36a7 | ||
|
|
ed66c92383 | ||
|
|
668d493bf9 | ||
|
|
67f4d7ea7a | ||
|
|
4b0c88ff8e | ||
|
|
b19fef9d18 | ||
|
|
1612e00e63 | ||
|
|
c3bc136420 | ||
|
|
1020bc3108 | ||
|
|
7fcf638c0d | ||
|
|
e35546f58e | ||
|
|
1aa8d4ac91 |
83
examples/12_train_hilserl_classifier.md
Normal file
83
examples/12_train_hilserl_classifier.md
Normal file
@@ -0,0 +1,83 @@
|
||||
# Training a HIL-SERL Reward Classifier with LeRobot
|
||||
|
||||
This tutorial provides step-by-step instructions for training a reward classifier using LeRobot.
|
||||
|
||||
---
|
||||
|
||||
## Training Script Overview
|
||||
|
||||
LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow:
|
||||
|
||||
1. **Configuration Loading**
|
||||
The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.)
|
||||
|
||||
2. **Dataset Initialization**
|
||||
It loads a `LeRobotDataset` containing images and rewards. To optimize performance, a weighted random sampler is used to balance class sampling.
|
||||
|
||||
3. **Classifier Initialization**
|
||||
A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either:
|
||||
- A single probability (binary classification), or
|
||||
- Logits (multi-class classification).
|
||||
|
||||
4. **Training Loop Execution**
|
||||
The script performs:
|
||||
- Forward and backward passes,
|
||||
- Optimization steps,
|
||||
- Periodic logging, evaluation, and checkpoint saving.
|
||||
|
||||
---
|
||||
|
||||
## Configuring with Hydra
|
||||
|
||||
For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file.
|
||||
|
||||
### Config File Setup
|
||||
|
||||
The default `default.yaml` cannot launch the reward classifier training directly. Instead, you need a configuration file like [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment:
|
||||
|
||||
Replace the `dataset_repo_id` field with the identifier for your dataset, which contains images and sparse rewards:
|
||||
|
||||
```yaml
|
||||
# Example: lerobot/configs/policy/reward_classifier.yaml
|
||||
dataset_repo_id: "my_dataset_repo_id"
|
||||
## Typical logs and metrics
|
||||
```
|
||||
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.
|
||||
|
||||
After that, you will see training log like this one:
|
||||
|
||||
```
|
||||
[2024-11-29 18:26:36,999][root][INFO] -
|
||||
Epoch 5/5
|
||||
Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%]
|
||||
```
|
||||
|
||||
or evaluation log like:
|
||||
|
||||
```
|
||||
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s]
|
||||
```
|
||||
|
||||
### Metrics Tracking with Weights & Biases (WandB)
|
||||
|
||||
If `wandb.enable` is set to `true`, the training and evaluation logs will also be saved in WandB. This allows you to track key metrics in real-time, including:
|
||||
|
||||
- **Training Metrics**:
|
||||
- `train/accuracy`
|
||||
- `train/loss`
|
||||
- `train/dataloading_s`
|
||||
- **Evaluation Metrics**:
|
||||
- `eval/accuracy`
|
||||
- `eval/loss`
|
||||
- `eval/eval_s`
|
||||
|
||||
#### Additional Features
|
||||
|
||||
You can also log sample predictions during evaluation. Each logged sample will include:
|
||||
|
||||
- The **input image**.
|
||||
- The **predicted label**.
|
||||
- The **true label**.
|
||||
- The **classifier's "confidence" (logits/probability)**.
|
||||
|
||||
These logs can be useful for diagnosing and debugging performance issues.
|
||||
@@ -291,7 +291,7 @@ class LeRobotDatasetMetadata:
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
if robot is not None:
|
||||
features = get_features_from_robot(robot, use_videos)
|
||||
features = {**(features or {}), **get_features_from_robot(robot)}
|
||||
robot_type = robot.robot_type
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warning(
|
||||
|
||||
@@ -25,6 +25,7 @@ from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import wandb
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
@@ -107,8 +108,6 @@ class Logger:
|
||||
self._wandb = None
|
||||
else:
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
import wandb
|
||||
|
||||
wandb_run_id = None
|
||||
if cfg.resume:
|
||||
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
|
||||
@@ -232,7 +231,7 @@ class Logger:
|
||||
# TODO(alexander-soare): Add local text log.
|
||||
if self._wandb is not None:
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
if not isinstance(v, (int, float, str, wandb.Table)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassifierConfig:
|
||||
"""Configuration for the Classifier model."""
|
||||
|
||||
num_classes: int = 2
|
||||
hidden_dim: int = 256
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "microsoft/resnet-50"
|
||||
device: str = "cuda" if torch.cuda.is_available() else "mps"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
|
||||
def save_pretrained(self, save_dir):
|
||||
"""Save config to json file."""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Convert to dict and save as JSON
|
||||
config_dict = asdict(self)
|
||||
with open(os.path.join(save_dir, "config.json"), "w") as f:
|
||||
json.dump(config_dict, f, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path):
|
||||
"""Load config from json file."""
|
||||
config_file = os.path.join(pretrained_model_name_or_path, "config.json")
|
||||
|
||||
with open(config_file) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
return cls(**config_dict)
|
||||
@@ -0,0 +1,134 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
from .configuration_classifier import ClassifierConfig
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
def __init__(
|
||||
self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None
|
||||
):
|
||||
self.logits = logits
|
||||
self.probabilities = probabilities
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
|
||||
class Classifier(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
# Add Hub metadata
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "vision-classifier"],
|
||||
):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
# Add name attribute for factory
|
||||
name = "classifier"
|
||||
|
||||
def __init__(self, config: ClassifierConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
if hasattr(encoder, "vision_model"):
|
||||
logging.info("Multimodal model detected - using vision encoder only")
|
||||
self.encoder = encoder.vision_model
|
||||
self.vision_config = encoder.config.vision_config
|
||||
else:
|
||||
self.encoder = encoder
|
||||
self.vision_config = getattr(encoder, "config", None)
|
||||
|
||||
# Model type from config
|
||||
self.is_cnn = self.config.model_type == "cnn"
|
||||
|
||||
# For CNNs, initialize backbone
|
||||
if self.is_cnn:
|
||||
self._setup_cnn_backbone()
|
||||
|
||||
self._freeze_encoder()
|
||||
self._build_classifier_head()
|
||||
|
||||
def _setup_cnn_backbone(self):
|
||||
"""Set up CNN encoder"""
|
||||
if hasattr(self.encoder, "fc"):
|
||||
self.feature_dim = self.encoder.fc.in_features
|
||||
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
|
||||
elif hasattr(self.encoder.config, "hidden_sizes"):
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
def _freeze_encoder(self) -> None:
|
||||
"""Freeze the encoder parameters."""
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _build_classifier_head(self) -> None:
|
||||
"""Initialize the classifier head architecture."""
|
||||
# Get input dimension based on model type
|
||||
if self.is_cnn:
|
||||
input_dim = self.feature_dim
|
||||
else: # Transformer models
|
||||
if hasattr(self.encoder.config, "hidden_size"):
|
||||
input_dim = self.encoder.config.hidden_size
|
||||
else:
|
||||
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
|
||||
|
||||
self.classifier_head = nn.Sequential(
|
||||
nn.Linear(input_dim, self.config.hidden_dim),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.LayerNorm(self.config.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
|
||||
)
|
||||
|
||||
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Extract the appropriate output from the encoder."""
|
||||
# Process images with the processor (handles resizing and normalization)
|
||||
processed = self.processor(
|
||||
images=x, # LeRobotDataset already provides proper tensor format
|
||||
return_tensors="pt",
|
||||
)
|
||||
processed = processed["pixel_values"].to(x.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.is_cnn:
|
||||
# The HF ResNet applies pooling internally
|
||||
outputs = self.encoder(processed)
|
||||
# Get pooled output directly
|
||||
features = outputs.pooler_output
|
||||
|
||||
if features.dim() > 2:
|
||||
features = features.squeeze(-1).squeeze(-1)
|
||||
return features
|
||||
else: # Transformer models
|
||||
outputs = self.encoder(processed)
|
||||
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
|
||||
return outputs.pooler_output
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> ClassifierOutput:
|
||||
"""Forward pass of the classifier."""
|
||||
# For training, we expect input to be a tensor directly from LeRobotDataset
|
||||
encoder_output = self._get_encoder_output(x)
|
||||
logits = self.classifier_head(encoder_output)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
logits = logits.squeeze(-1)
|
||||
probabilities = torch.sigmoid(logits)
|
||||
else:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_output)
|
||||
23
lerobot/common/policies/hilserl/configuration_hilserl.py
Normal file
23
lerobot/common/policies/hilserl/configuration_hilserl.py
Normal file
@@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILSerlConfig:
|
||||
pass
|
||||
29
lerobot/common/policies/hilserl/modeling_hilserl.py
Normal file
29
lerobot/common/policies/hilserl/modeling_hilserl.py
Normal file
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
|
||||
|
||||
class HILSerlPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "hilserl"],
|
||||
):
|
||||
pass
|
||||
39
lerobot/common/policies/sac/configuration_sac.py
Normal file
39
lerobot/common/policies/sac/configuration_sac.py
Normal file
@@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SACConfig:
|
||||
discount = 0.99
|
||||
temperature_init = 1.0
|
||||
num_critics = 2
|
||||
critic_lr = 3e-4
|
||||
actor_lr = 3e-4
|
||||
critic_network_kwargs = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
actor_network_kwargs = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
policy_kwargs = {
|
||||
"tanh_squash_distribution": True,
|
||||
"std_parameterization": "uniform",
|
||||
}
|
||||
683
lerobot/common/policies/sac/modeling_sac.py
Normal file
683
lerobot/common/policies/sac/modeling_sac.py
Normal file
@@ -0,0 +1,683 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO: (1) better device management
|
||||
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import einops
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
import numpy as np
|
||||
from typing import Callable, Optional, Tuple, Sequence
|
||||
|
||||
|
||||
|
||||
class SACPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "RL", "SAC"],
|
||||
):
|
||||
|
||||
def __init__(
|
||||
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = SACConfig()
|
||||
self.config = config
|
||||
|
||||
if config.input_normalization_modes is not None:
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
encoder = SACObservationEncoder(config)
|
||||
# Define networks
|
||||
critic_nets = []
|
||||
for _ in range(config.num_critics):
|
||||
critic_net = Critic(
|
||||
encoder=encoder,
|
||||
network=MLP(**config.critic_network_kwargs)
|
||||
)
|
||||
critic_nets.append(critic_net)
|
||||
|
||||
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
||||
self.critic_target = deepcopy(self.critic_ensemble)
|
||||
|
||||
self.actor_network = Policy(
|
||||
encoder=encoder,
|
||||
network=MLP(**config.actor_network_kwargs),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
**config.policy_kwargs
|
||||
)
|
||||
|
||||
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
"""
|
||||
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=1),
|
||||
}
|
||||
if self._use_image:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
if self._use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
actions, _ = self.actor_network(batch['observations'])###
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
# batch shape is (b, 2, ...) where index 1 returns the current observation and
|
||||
# the next observation for caluculating the right td index.
|
||||
actions = batch["action"][:, 0]
|
||||
rewards = batch["next.reward"][:, 0]
|
||||
observations = {}
|
||||
next_observations = {}
|
||||
for k in batch:
|
||||
if k.startswith("observation."):
|
||||
observations[k] = batch[k][:, 0]
|
||||
next_observations[k] = batch[k][:, 1]
|
||||
|
||||
# perform image augmentation
|
||||
|
||||
# reward bias
|
||||
# from HIL-SERL code base
|
||||
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
|
||||
|
||||
|
||||
# calculate critics loss
|
||||
# 1- compute actions from policy
|
||||
action_preds, log_probs = self.actor_network(observations)
|
||||
# 2- compute q targets
|
||||
q_targets = self.target_qs(next_observations, action_preds)
|
||||
|
||||
# critics subsample size
|
||||
min_q = q_targets.min(dim=0)
|
||||
|
||||
# backup entropy
|
||||
td_target = rewards + self.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_ensemble(observations, actions)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
q_preds,
|
||||
einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
|
||||
reduction="none",
|
||||
).sum(0) # sum over ensemble
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# q_targets depends on the reward and the next observations.
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
).sum(0).mean()
|
||||
|
||||
# calculate actors loss
|
||||
# 1- temperature
|
||||
temperature = self.temperature()
|
||||
|
||||
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
||||
actions, log_probs = self.actor_network(observations) \
|
||||
|
||||
# 3- get q-value predictions
|
||||
with torch.no_grad():
|
||||
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
|
||||
actor_loss = (
|
||||
-(q_preds - temperature * log_probs).mean()
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
).mean()
|
||||
|
||||
|
||||
# calculate temperature loss
|
||||
# 1- calculate entropy
|
||||
entropy = -log_probs.mean()
|
||||
temperature_loss = temperature * (entropy - self.target_entropy).mean()
|
||||
|
||||
loss = critics_loss + actor_loss + temperature_loss
|
||||
|
||||
return {
|
||||
"critics_loss": critics_loss.item(),
|
||||
"actor_loss": actor_loss.item(),
|
||||
"temperature_loss": temperature_loss.item(),
|
||||
"temperature": temperature.item(),
|
||||
"entropy": entropy.item(),
|
||||
"loss": loss,
|
||||
|
||||
}
|
||||
|
||||
def update(self):
|
||||
self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
|
||||
#for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
|
||||
# target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig,
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.activate_final = config.activate_final
|
||||
layers = []
|
||||
|
||||
for i, size in enumerate(config.network_hidden_dims):
|
||||
layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size))
|
||||
|
||||
if i + 1 < len(config.network_hidden_dims) or activate_final:
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(size))
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
|
||||
# in training mode or not. TODO: find better way to do this
|
||||
self.train(train)
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
activate_final: bool = False,
|
||||
device: str = "cuda"
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.init_final = init_final
|
||||
self.activate_final = activate_final
|
||||
|
||||
# Output layer
|
||||
if init_final is not None:
|
||||
if self.activate_final:
|
||||
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
|
||||
else:
|
||||
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
if self.activate_final:
|
||||
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
|
||||
else:
|
||||
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
train: bool = False
|
||||
) -> torch.Tensor:
|
||||
self.train(train)
|
||||
|
||||
observations = observations.to(self.device)
|
||||
actions = actions.to(self.device)
|
||||
|
||||
if self.encoder is not None:
|
||||
obs_enc = self.encoder(observations)
|
||||
else:
|
||||
obs_enc = observations
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
x = self.network(inputs)
|
||||
value = self.output_layer(x)
|
||||
return value.squeeze(-1)
|
||||
|
||||
def q_value_ensemble(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
train: bool = False
|
||||
) -> torch.Tensor:
|
||||
observations = observations.to(self.device)
|
||||
actions = actions.to(self.device)
|
||||
|
||||
if len(actions.shape) == 3: # [batch_size, num_actions, action_dim]
|
||||
batch_size, num_actions = actions.shape[:2]
|
||||
obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
|
||||
obs_flat = obs_expanded.reshape(-1, observations.shape[-1])
|
||||
actions_flat = actions.reshape(-1, actions.shape[-1])
|
||||
q_values = self(obs_flat, actions_flat, train)
|
||||
return q_values.reshape(batch_size, num_actions)
|
||||
else:
|
||||
return self(observations, actions, train)
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
action_dim: int,
|
||||
std_parameterization: str = "exp",
|
||||
std_min: float = 1e-5,
|
||||
std_max: float = 10.0,
|
||||
tanh_squash_distribution: bool = False,
|
||||
fixed_std: Optional[torch.Tensor] = None,
|
||||
init_final: Optional[float] = None,
|
||||
activate_final: bool = False,
|
||||
device: str = "cuda"
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.action_dim = action_dim
|
||||
self.std_parameterization = std_parameterization
|
||||
self.std_min = std_min
|
||||
self.std_max = std_max
|
||||
self.tanh_squash_distribution = tanh_squash_distribution
|
||||
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
||||
self.activate_final = activate_final
|
||||
|
||||
# Mean layer
|
||||
if self.activate_final:
|
||||
self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
|
||||
else:
|
||||
self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.mean_layer.weight)
|
||||
|
||||
# Standard deviation layer or parameter
|
||||
if fixed_std is None:
|
||||
if std_parameterization == "uniform":
|
||||
self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
|
||||
else:
|
||||
if self.activate_final:
|
||||
self.std_layer = nn.Linear(network.net[-3].out_features, action_dim)
|
||||
else:
|
||||
self.std_layer = nn.Linear(network.net[-2].out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
temperature: float = 1.0,
|
||||
train: bool = False,
|
||||
non_squash_distribution: bool = False
|
||||
) -> torch.distributions.Distribution:
|
||||
self.train(train)
|
||||
|
||||
# Encode observations if encoder exists
|
||||
if self.encoder is not None:
|
||||
with torch.set_grad_enabled(train):
|
||||
obs_enc = self.encoder(observations, train=train)
|
||||
else:
|
||||
obs_enc = observations
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
means = self.mean_layer(outputs)
|
||||
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
if self.std_parameterization == "exp":
|
||||
log_stds = self.std_layer(outputs)
|
||||
stds = torch.exp(log_stds)
|
||||
elif self.std_parameterization == "softplus":
|
||||
stds = torch.nn.functional.softplus(self.std_layer(outputs))
|
||||
elif self.std_parameterization == "uniform":
|
||||
stds = torch.exp(self.log_stds).expand_as(means)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid std_parameterization: {self.std_parameterization}"
|
||||
)
|
||||
else:
|
||||
assert self.std_parameterization == "fixed"
|
||||
stds = self.fixed_std.expand_as(means)
|
||||
|
||||
# Clip standard deviations and scale with temperature
|
||||
temperature = torch.tensor(temperature, device=self.device)
|
||||
stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)
|
||||
|
||||
# Create distribution
|
||||
if self.tanh_squash_distribution and not non_squash_distribution:
|
||||
distribution = TanhMultivariateNormalDiag(
|
||||
loc=means,
|
||||
scale_diag=stds,
|
||||
)
|
||||
else:
|
||||
distribution = torch.distributions.Normal(
|
||||
loc=means,
|
||||
scale=stds,
|
||||
)
|
||||
|
||||
return distribution
|
||||
|
||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
"""Get encoded features from observations"""
|
||||
observations = observations.to(self.device)
|
||||
if self.encoder is not None:
|
||||
with torch.no_grad():
|
||||
return self.encoder(observations, train=False)
|
||||
return observations
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations.
|
||||
TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders.
|
||||
"""
|
||||
|
||||
def __init__(self, config: SACConfig):
|
||||
"""
|
||||
Creates encoders for pixel and/or state modalities.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
if "observation.image" in config.input_shapes:
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
||||
with torch.inference_mode():
|
||||
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(np.prod(out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
)
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
|
||||
),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
over all features.
|
||||
"""
|
||||
feat = []
|
||||
# Concatenate all images along the channel dimension.
|
||||
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
|
||||
for image_key in image_keys:
|
||||
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
|
||||
class LagrangeMultiplier(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
init_value: float = 1.0,
|
||||
constraint_shape: Sequence[int] = (),
|
||||
device: str = "cuda"
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
|
||||
|
||||
# Initialize the Lagrange multiplier as a parameter
|
||||
self.lagrange = nn.Parameter(
|
||||
torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
|
||||
)
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
lhs: Optional[torch.Tensor] = None,
|
||||
rhs: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
# Get the multiplier value based on parameterization
|
||||
multiplier = torch.nn.functional.softplus(self.lagrange)
|
||||
|
||||
# Return the raw multiplier if no constraint values provided
|
||||
if lhs is None:
|
||||
return multiplier
|
||||
|
||||
# Move inputs to device
|
||||
lhs = lhs.to(self.device)
|
||||
if rhs is not None:
|
||||
rhs = rhs.to(self.device)
|
||||
|
||||
# Use the multiplier to compute the Lagrange penalty
|
||||
if rhs is None:
|
||||
rhs = torch.zeros_like(lhs, device=self.device)
|
||||
|
||||
diff = lhs - rhs
|
||||
|
||||
assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
|
||||
|
||||
return multiplier * diff
|
||||
|
||||
|
||||
# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
|
||||
# 1. The base distribution is a diagonal multivariate normal distribution
|
||||
# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
|
||||
# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
|
||||
# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
|
||||
class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
||||
def __init__(
|
||||
self,
|
||||
loc: torch.Tensor,
|
||||
scale_diag: torch.Tensor,
|
||||
low: Optional[torch.Tensor] = None,
|
||||
high: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# Create base normal distribution
|
||||
base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)
|
||||
|
||||
# Create list of transforms
|
||||
transforms = []
|
||||
|
||||
# Add tanh transform
|
||||
transforms.append(torch.distributions.transforms.TanhTransform())
|
||||
|
||||
# Add rescaling transform if bounds are provided
|
||||
if low is not None and high is not None:
|
||||
transforms.append(
|
||||
torch.distributions.transforms.AffineTransform(
|
||||
loc=(high + low) / 2,
|
||||
scale=(high - low) / 2
|
||||
)
|
||||
)
|
||||
|
||||
# Initialize parent class
|
||||
super().__init__(
|
||||
base_distribution=base_distribution,
|
||||
transforms=transforms
|
||||
)
|
||||
|
||||
# Store parameters
|
||||
self.loc = loc
|
||||
self.scale_diag = scale_diag
|
||||
self.low = low
|
||||
self.high = high
|
||||
|
||||
def mode(self) -> torch.Tensor:
|
||||
"""Get the mode of the transformed distribution"""
|
||||
# The mode of a normal distribution is its mean
|
||||
mode = self.loc
|
||||
|
||||
# Apply transforms
|
||||
for transform in self.transforms:
|
||||
mode = transform(mode)
|
||||
|
||||
return mode
|
||||
|
||||
def rsample(self, sample_shape=torch.Size()) -> torch.Tensor:
|
||||
"""
|
||||
Reparameterized sample from the distribution
|
||||
"""
|
||||
# Sample from base distribution
|
||||
x = self.base_dist.rsample(sample_shape)
|
||||
|
||||
# Apply transforms
|
||||
for transform in self.transforms:
|
||||
x = transform(x)
|
||||
|
||||
return x
|
||||
|
||||
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute log probability of a value
|
||||
Includes the log det jacobian for the transforms
|
||||
"""
|
||||
# Initialize log prob
|
||||
log_prob = torch.zeros_like(value[..., 0])
|
||||
|
||||
# Inverse transforms to get back to normal distribution
|
||||
q = value
|
||||
for transform in reversed(self.transforms):
|
||||
q = transform.inv(q)
|
||||
log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q))
|
||||
|
||||
# Add base distribution log prob
|
||||
log_prob = log_prob + self.base_dist.log_prob(q).sum(-1)
|
||||
|
||||
return log_prob
|
||||
|
||||
def sample_and_log_prob(self, sample_shape=torch.Size()) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Sample from the distribution and compute log probability
|
||||
"""
|
||||
x = self.rsample(sample_shape)
|
||||
log_prob = self.log_prob(x)
|
||||
return x, log_prob
|
||||
|
||||
def entropy(self) -> torch.Tensor:
|
||||
"""
|
||||
Compute entropy of the distribution
|
||||
"""
|
||||
# Start with base distribution entropy
|
||||
entropy = self.base_dist.entropy().sum(-1)
|
||||
|
||||
# Add log det jacobian for each transform
|
||||
x = self.rsample()
|
||||
for transform in self.transforms:
|
||||
entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
|
||||
x = transform(x)
|
||||
|
||||
return entropy
|
||||
|
||||
|
||||
def create_critic_ensemble(critic_class, num_critics: int, device: str = "cuda") -> nn.ModuleList:
|
||||
"""Creates an ensemble of critic networks"""
|
||||
critics = nn.ModuleList([critic_class() for _ in range(num_critics)])
|
||||
return critics.to(device)
|
||||
|
||||
|
||||
def orthogonal_init():
|
||||
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
|
||||
|
||||
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
Args:
|
||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
||||
(B, *), where * is any number of dimensions.
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||
can be more than 1 dimensions, generally different from *.
|
||||
Returns:
|
||||
A return value from the callable reshaped to (**, *).
|
||||
"""
|
||||
if image_tensor.ndim == 4:
|
||||
return fn(image_tensor)
|
||||
start_dims = image_tensor.shape[:-3]
|
||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||
flat_out = fn(inp)
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
||||
|
||||
@@ -120,14 +120,22 @@ def predict_action(observation, policy, device, use_amp):
|
||||
return action
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
def init_keyboard_listener(assign_rewards=False):
|
||||
"""
|
||||
Initializes a keyboard listener to enable early termination of an episode
|
||||
or environment reset by pressing the right arrow key ('->'). This may require
|
||||
sudo permissions to allow the terminal to monitor keyboard events.
|
||||
|
||||
Args:
|
||||
assign_rewards (bool): If True, allows annotating the collected trajectory
|
||||
with a binary reward at the end of the episode to indicate success.
|
||||
"""
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
events["stop_recording"] = False
|
||||
if assign_rewards:
|
||||
events["next.reward"] = 0
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
@@ -152,6 +160,13 @@ def init_keyboard_listener():
|
||||
print("Escape key pressed. Stopping data recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
elif assign_rewards and key == keyboard.Key.space:
|
||||
events["next.reward"] = 1 if events["next.reward"] == 0 else 0
|
||||
print(
|
||||
"Space key pressed. Assigning new reward to the subsequent frames. New reward:",
|
||||
events["next.reward"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
@@ -272,6 +287,8 @@ def control_loop(
|
||||
|
||||
if dataset is not None:
|
||||
frame = {**observation, **action}
|
||||
if "next.reward" in events:
|
||||
frame["next.reward"] = events["next.reward"]
|
||||
dataset.add_frame(frame)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
@@ -301,6 +318,8 @@ def reset_environment(robot, events, reset_time_s):
|
||||
|
||||
timestamp = 0
|
||||
start_vencod_t = time.perf_counter()
|
||||
if "next.reward" in events:
|
||||
events["next.reward"] = 0
|
||||
|
||||
# Wait if necessary
|
||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||
|
||||
48
lerobot/configs/policy/hilserl_classifier.yaml
Normal file
48
lerobot/configs/policy/hilserl_classifier.yaml
Normal file
@@ -0,0 +1,48 @@
|
||||
# @package _global_
|
||||
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
seed: 13
|
||||
dataset_repo_id: "dataset_repo_id"
|
||||
train_split_proportion: 0.8
|
||||
|
||||
# Required by logger
|
||||
env:
|
||||
name: "classifier"
|
||||
task: "binary_classification"
|
||||
|
||||
|
||||
training:
|
||||
num_epochs: 5
|
||||
batch_size: 16
|
||||
learning_rate: 1e-4
|
||||
num_workers: 4
|
||||
grad_clip_norm: 10
|
||||
use_amp: true
|
||||
log_freq: 1
|
||||
eval_freq: 1 # How often to run validation (in epochs)
|
||||
save_freq: 1 # How often to save checkpoints (in epochs)
|
||||
save_checkpoint: true
|
||||
image_key: "observation.images.phone"
|
||||
label_key: "next.reward"
|
||||
|
||||
eval:
|
||||
batch_size: 16
|
||||
num_samples_to_log: 30 # Number of validation samples to log in the table
|
||||
|
||||
policy:
|
||||
name: "hilserl/classifier"
|
||||
model_name: "facebook/convnext-base-224"
|
||||
model_type: "cnn"
|
||||
|
||||
wandb:
|
||||
enable: false
|
||||
project: "classifier-training"
|
||||
entity: "wandb_entity"
|
||||
job_name: "classifier_training_0"
|
||||
disable_artifact: false
|
||||
|
||||
device: "mps"
|
||||
resume: false
|
||||
output_dir: "output"
|
||||
@@ -10,7 +10,7 @@ max_relative_target: null
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0031751
|
||||
port: /dev/tty.usbmodem58760430441
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl330-m077"]
|
||||
@@ -23,7 +23,7 @@ leader_arms:
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0032081
|
||||
port: /dev/tty.usbmodem585A0083391
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl430-w250"]
|
||||
|
||||
@@ -18,7 +18,7 @@ max_relative_target: null
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
|
||||
port: /dev/tty.usbmodem585A0077581
|
||||
port: /dev/tty.usbmodem58760433331
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "sts3215"]
|
||||
|
||||
@@ -191,6 +191,7 @@ def record(
|
||||
single_task: str,
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: List[str] | None = None,
|
||||
assign_rewards: bool = False,
|
||||
fps: int | None = None,
|
||||
warmup_time_s: int | float = 2,
|
||||
episode_time_s: int | float = 10,
|
||||
@@ -214,6 +215,9 @@ def record(
|
||||
policy = None
|
||||
device = None
|
||||
use_amp = None
|
||||
extra_features = (
|
||||
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
|
||||
)
|
||||
|
||||
if single_task:
|
||||
task = single_task
|
||||
@@ -254,12 +258,12 @@ def record(
|
||||
use_videos=video,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
features=extra_features,
|
||||
)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
|
||||
|
||||
# Execute a few seconds without recording to:
|
||||
# 1. teleoperate the robot to move it in starting position if no policy provided,
|
||||
@@ -469,12 +473,12 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
help="Upload dataset to Hugging Face hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--tags",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Add tags to your dataset on the hub.",
|
||||
)
|
||||
# parser_record.add_argument(
|
||||
# "--tags",
|
||||
# type=str,
|
||||
# nargs="*",
|
||||
# help="Add tags to your dataset on the hub.",
|
||||
# )
|
||||
parser_record.add_argument(
|
||||
"--num-image-writer-processes",
|
||||
type=int,
|
||||
@@ -517,6 +521,12 @@ if __name__ == "__main__":
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--assign-rewards",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
|
||||
)
|
||||
|
||||
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
||||
parser_replay.add_argument(
|
||||
|
||||
335
lerobot/scripts/eval_on_robot.py
Normal file
335
lerobot/scripts/eval_on_robot.py
Normal file
@@ -0,0 +1,335 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Evaluate a policy by running rollouts on the real robot and computing metrics.
|
||||
|
||||
Usage examples: evaluate a checkpoint from the LeRobot training script for 10 episodes.
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval_on_robot.py \
|
||||
-p outputs/train/model/checkpoints/005000/pretrained_model \
|
||||
eval.n_episodes=10
|
||||
```
|
||||
|
||||
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
|
||||
for running training on the real robot.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless
|
||||
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
|
||||
from lerobot.common.utils.utils import (
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
log_say,
|
||||
)
|
||||
|
||||
|
||||
def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
|
||||
"""Run a batched policy rollout on the real robot.
|
||||
|
||||
The return dictionary contains:
|
||||
"robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
|
||||
keys. NOTE the that this has an extra sequence element relative to the other keys in the
|
||||
dictionary. This is because an extra observation is included for after the environment is
|
||||
terminated or truncated.
|
||||
"action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
|
||||
including the last observations).
|
||||
"reward": A (batch, sequence) tensor of rewards received for applying the actions.
|
||||
"success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
|
||||
environment termination/truncation).
|
||||
"done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
|
||||
the first True is followed by True's all the way till the end. This can be used for masking
|
||||
extraneous elements from the sequences above.
|
||||
|
||||
Args:
|
||||
robot: The robot class that defines the interface with the real robot.
|
||||
policy: The policy. Must be a PyTorch nn module.
|
||||
|
||||
Returns:
|
||||
The dictionary described above.
|
||||
"""
|
||||
# assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
|
||||
# device = get_device_from_parameters(policy)
|
||||
|
||||
# define keyboard listener
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
|
||||
# policy.reset()
|
||||
|
||||
# Get observation from real robot
|
||||
observation = robot.capture_observation()
|
||||
|
||||
# Calculate reward. TODO (michel-aractingi)
|
||||
# in HIL-SERL it will be with a reward classifier
|
||||
reward = calculate_reward(observation)
|
||||
all_observations = []
|
||||
all_actions = []
|
||||
all_rewards = []
|
||||
all_successes = []
|
||||
|
||||
start_episode_t = time.perf_counter()
|
||||
timestamp = 0.0
|
||||
while timestamp < control_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
all_observations.append(deepcopy(observation))
|
||||
# observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
|
||||
|
||||
# Apply the next action.
|
||||
while events["pause_policy"] and not events["human_intervention_step"]:
|
||||
busy_wait(0.5)
|
||||
|
||||
if events["human_intervention_step"]:
|
||||
# take over the robot's actions
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
action = action["action"] # teleop step returns torch tensors but in a dict
|
||||
else:
|
||||
# explore with policy
|
||||
with torch.inference_mode():
|
||||
action = robot.follower_arms["main"].read("Present_Position")
|
||||
action = torch.from_numpy(action)
|
||||
robot.send_action(action)
|
||||
# action = predict_action(observation, policy, device, use_amp)
|
||||
|
||||
observation = robot.capture_observation()
|
||||
# Calculate reward
|
||||
# in HIL-SERL it will be with a reward classifier
|
||||
reward = calculate_reward(observation)
|
||||
|
||||
all_actions.append(action)
|
||||
all_rewards.append(torch.from_numpy(reward))
|
||||
all_successes.append(torch.tensor([False]))
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["human_intervention_step"] = False
|
||||
events["pause_policy"] = False
|
||||
break
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
dones = torch.tensor([False] * len(all_actions))
|
||||
dones[-1] = True
|
||||
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
|
||||
ret = {
|
||||
"action": torch.stack(all_actions, dim=1),
|
||||
"next.reward": torch.stack(all_rewards, dim=1),
|
||||
"next.success": torch.stack(all_successes, dim=1),
|
||||
"done": dones,
|
||||
}
|
||||
stacked_observations = {}
|
||||
for key in all_observations[0]:
|
||||
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
|
||||
ret["observation"] = stacked_observations
|
||||
|
||||
listener.stop()
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def eval_policy(
|
||||
robot: Robot,
|
||||
policy: torch.nn.Module,
|
||||
fps: float,
|
||||
n_episodes: int,
|
||||
control_time_s: int = 20,
|
||||
use_amp: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
env: The batch of environments.
|
||||
policy: The policy.
|
||||
n_episodes: The number of episodes to evaluate.
|
||||
Returns:
|
||||
Dictionary with metrics and data regarding the rollouts.
|
||||
"""
|
||||
# TODO (michel-aractingi) comment this out for testing with a fixed policy
|
||||
# assert isinstance(policy, Policy)
|
||||
# policy.eval()
|
||||
|
||||
sum_rewards = []
|
||||
max_rewards = []
|
||||
successes = []
|
||||
rollouts = []
|
||||
|
||||
start_eval = time.perf_counter()
|
||||
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
|
||||
for _batch_idx in progbar:
|
||||
rollout_data = rollout(robot, policy, fps, control_time_s, use_amp)
|
||||
|
||||
rollouts.append(rollout_data)
|
||||
sum_rewards.append(sum(rollout_data["next.reward"]))
|
||||
max_rewards.append(max(rollout_data["next.reward"]))
|
||||
successes.append(rollout_data["next.success"][-1])
|
||||
|
||||
info = {
|
||||
"per_episode": [
|
||||
{
|
||||
"episode_ix": i,
|
||||
"sum_reward": sum_reward,
|
||||
"max_reward": max_reward,
|
||||
"pc_success": success * 100,
|
||||
}
|
||||
for i, (sum_reward, max_reward, success) in enumerate(
|
||||
zip(
|
||||
sum_rewards[:n_episodes],
|
||||
max_rewards[:n_episodes],
|
||||
successes[:n_episodes],
|
||||
strict=False,
|
||||
)
|
||||
)
|
||||
],
|
||||
"aggregated": {
|
||||
"avg_sum_reward": float(np.nanmean(torch.cat(sum_rewards[:n_episodes]))),
|
||||
"avg_max_reward": float(np.nanmean(torch.cat(max_rewards[:n_episodes]))),
|
||||
"pc_success": float(np.nanmean(torch.cat(successes[:n_episodes])) * 100),
|
||||
"eval_s": time.time() - start_eval,
|
||||
"eval_ep_s": (time.time() - start_eval) / n_episodes,
|
||||
},
|
||||
}
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def calculate_reward(observation):
|
||||
"""
|
||||
Method to calculate reward function in some way.
|
||||
In HIL-SERL this is done through defining a reward classifier
|
||||
"""
|
||||
# reward = reward_classifier(observation)
|
||||
return np.array([0.0])
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
events["pause_policy"] = False
|
||||
events["human_intervention_step"] = False
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
)
|
||||
listener = None
|
||||
return listener, events
|
||||
|
||||
# Only import pynput if not in a headless environment
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.right:
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.space:
|
||||
# check if first space press then pause the policy for the user to get ready
|
||||
# if second space press then the user is ready to start intervention
|
||||
if not events["pause_policy"]:
|
||||
print(
|
||||
"Space key pressed. Human intervention required.\n"
|
||||
"Place the leader in similar pose to the follower and press space again."
|
||||
)
|
||||
events["pause_policy"] = True
|
||||
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
|
||||
else:
|
||||
events["human_intervention_step"] = True
|
||||
print("Space key pressed. Human intervention starting.")
|
||||
log_say("Starting human intervention.", play_sounds=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
return listener, events
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="lerobot/configs/robot/koch.yaml",
|
||||
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--robot-overrides",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
group.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
|
||||
"(useful for debugging). This argument is mutually exclusive with `--config`."
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
"--config",
|
||||
help=(
|
||||
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
|
||||
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
help=(
|
||||
"Where to save the evaluation outputs. If not provided, outputs are saved in "
|
||||
"outputs/eval/{timestamp}_{env_name}_{policy_name}"
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
|
||||
robot = make_robot(robot_cfg)
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
eval_policy(robot, None, fps=40, n_episodes=2, control_time_s=100)
|
||||
310
lerobot/scripts/train_hilserl_classifier.py
Normal file
310
lerobot/scripts/train_hilserl_classifier.py
Normal file
@@ -0,0 +1,310 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import wandb
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import optim
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
set_global_seed,
|
||||
)
|
||||
|
||||
|
||||
def get_model(cfg, logger):
|
||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||
model = Classifier(classifier_config)
|
||||
if cfg.resume:
|
||||
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
|
||||
return model
|
||||
|
||||
|
||||
def create_balanced_sampler(dataset, cfg):
|
||||
# Creates a weighted sampler to handle class imbalance
|
||||
|
||||
labels = torch.tensor([item[cfg.training.label_key] for item in dataset])
|
||||
_, counts = torch.unique(labels, return_counts=True)
|
||||
class_weights = 1.0 / counts.float()
|
||||
sample_weights = class_weights[labels]
|
||||
|
||||
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
|
||||
|
||||
|
||||
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
|
||||
# Single epoch training loop with AMP support and progress tracking
|
||||
model.train()
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
pbar = tqdm(train_loader, desc="Training")
|
||||
for batch_idx, batch in enumerate(pbar):
|
||||
start_time = time.perf_counter()
|
||||
images = batch[cfg.training.image_key].to(device)
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
# Forward pass with optional AMP
|
||||
with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs.logits, labels)
|
||||
|
||||
# Backward pass with gradient scaling if AMP enabled
|
||||
optimizer.zero_grad()
|
||||
if cfg.training.use_amp:
|
||||
grad_scaler.scale(loss).backward()
|
||||
grad_scaler.step(optimizer)
|
||||
grad_scaler.update()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Track metrics
|
||||
if model.config.num_classes == 2:
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
correct += (predictions == labels).sum().item()
|
||||
total += labels.size(0)
|
||||
|
||||
current_acc = 100 * correct / total
|
||||
train_info = {
|
||||
"loss": loss.item(),
|
||||
"accuracy": current_acc,
|
||||
"dataloading_s": time.perf_counter() - start_time,
|
||||
}
|
||||
|
||||
logger.log_dict(train_info, step + batch_idx, mode="train")
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})
|
||||
|
||||
|
||||
def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
|
||||
# Validation loop with metric tracking and sample logging
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
batch_start_time = time.perf_counter()
|
||||
samples = []
|
||||
running_loss = 0
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
|
||||
for batch in tqdm(val_loader, desc="Validation"):
|
||||
images = batch[cfg.training.image_key].to(device)
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs.logits, labels)
|
||||
|
||||
# Track metrics
|
||||
if model.config.num_classes == 2:
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
correct += (predictions == labels).sum().item()
|
||||
total += labels.size(0)
|
||||
running_loss += loss.item()
|
||||
|
||||
# Log sample predictions for visualization
|
||||
if len(samples) < num_samples_to_log:
|
||||
for i in range(min(num_samples_to_log - len(samples), len(images))):
|
||||
if model.config.num_classes == 2:
|
||||
confidence = round(outputs.probabilities[i].item(), 3)
|
||||
else:
|
||||
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
|
||||
samples.append(
|
||||
{
|
||||
"image": wandb.Image(images[i].cpu()),
|
||||
"true_label": labels[i].item(),
|
||||
"predicted": predictions[i].item(),
|
||||
"confidence": confidence,
|
||||
}
|
||||
)
|
||||
|
||||
accuracy = 100 * correct / total
|
||||
avg_loss = running_loss / len(val_loader)
|
||||
|
||||
eval_info = {
|
||||
"loss": avg_loss,
|
||||
"accuracy": accuracy,
|
||||
"eval_s": time.perf_counter() - batch_start_time,
|
||||
"eval/prediction_samples": wandb.Table(
|
||||
data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples],
|
||||
columns=["Image", "True Label", "Predicted", "Confidence"],
|
||||
)
|
||||
if logger._cfg.wandb.enable
|
||||
else None,
|
||||
}
|
||||
|
||||
return accuracy, eval_info
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_path="../configs", config_name="hilserl_classifier")
|
||||
def train(cfg: DictConfig) -> None:
|
||||
# Main training pipeline with support for resuming training
|
||||
logging.info(OmegaConf.to_yaml(cfg))
|
||||
|
||||
# Initialize training environment
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
out_dir = Path(cfg.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
|
||||
|
||||
# Setup dataset and dataloaders
|
||||
dataset = LeRobotDataset(cfg.dataset_repo_id)
|
||||
logging.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
train_size = int(cfg.train_split_proportion * len(dataset))
|
||||
val_size = len(dataset) - train_size
|
||||
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
|
||||
|
||||
sampler = create_balanced_sampler(train_dataset, cfg)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=cfg.training.batch_size,
|
||||
num_workers=cfg.training.num_workers,
|
||||
sampler=sampler,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=cfg.eval.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=cfg.training.num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
# Resume training if requested
|
||||
step = 0
|
||||
best_val_acc = 0
|
||||
|
||||
if cfg.resume:
|
||||
if not Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||
raise RuntimeError(
|
||||
"You have set resume=True, but there is no model checkpoint in "
|
||||
f"{Logger.get_last_checkpoint_dir(out_dir)}"
|
||||
)
|
||||
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
|
||||
logging.info(
|
||||
colored(
|
||||
"You have set resume=True, indicating that you wish to resume a run",
|
||||
color="yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
# Load and validate checkpoint configuration
|
||||
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
|
||||
# Check for differences between the checkpoint configuration and provided configuration.
|
||||
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
|
||||
resolve_delta_timestamps(cfg)
|
||||
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
|
||||
# Ignore the `resume` and parameters.
|
||||
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
|
||||
del diff["values_changed"]["root['resume']"]
|
||||
if len(diff) > 0:
|
||||
logging.warning(
|
||||
"At least one difference was detected between the checkpoint configuration and "
|
||||
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
|
||||
"takes precedence.",
|
||||
)
|
||||
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
|
||||
cfg = checkpoint_cfg
|
||||
cfg.resume = True
|
||||
|
||||
# Initialize model and training components
|
||||
model = get_model(cfg=cfg, logger=logger).to(device)
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
|
||||
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
|
||||
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
|
||||
grad_scaler = GradScaler(enabled=cfg.training.use_amp)
|
||||
|
||||
# Log model parameters
|
||||
num_learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in model.parameters())
|
||||
logging.info(f"Learnable parameters: {format_big_number(num_learnable_params)}")
|
||||
logging.info(f"Total parameters: {format_big_number(num_total_params)}")
|
||||
|
||||
if cfg.resume:
|
||||
step = logger.load_last_training_state(optimizer, None)
|
||||
|
||||
# Training loop with validation and checkpointing
|
||||
for epoch in range(cfg.training.num_epochs):
|
||||
logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")
|
||||
|
||||
train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)
|
||||
|
||||
# Periodic validation
|
||||
if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:
|
||||
val_acc, eval_info = validate(
|
||||
model,
|
||||
val_loader,
|
||||
criterion,
|
||||
device,
|
||||
logger,
|
||||
cfg,
|
||||
)
|
||||
logger.log_dict(eval_info, step + len(train_loader), mode="eval")
|
||||
|
||||
# Save best model
|
||||
if val_acc > best_val_acc:
|
||||
best_val_acc = val_acc
|
||||
logger.save_checkpoint(
|
||||
train_step=step + len(train_loader),
|
||||
policy=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=None,
|
||||
identifier="best",
|
||||
)
|
||||
|
||||
# Periodic checkpointing
|
||||
if cfg.training.save_checkpoint and (epoch + 1) % cfg.training.save_freq == 0:
|
||||
logger.save_checkpoint(
|
||||
train_step=step + len(train_loader),
|
||||
policy=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=None,
|
||||
identifier=f"{epoch+1:06d}",
|
||||
)
|
||||
|
||||
step += len(train_loader)
|
||||
|
||||
logging.info("Training completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
251
tests/test_train_hilserl_classifier.py
Normal file
251
tests/test_train_hilserl_classifier.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from hydra import compose, initialize_config_dir
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.scripts.train_hilserl_classifier import (
|
||||
create_balanced_sampler,
|
||||
train,
|
||||
train_epoch,
|
||||
validate,
|
||||
)
|
||||
|
||||
|
||||
class MockDataset(Dataset):
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
self.meta = MagicMock()
|
||||
self.meta.stats = {}
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
def make_dummy_model():
|
||||
model_config = ClassifierConfig(num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel")
|
||||
model = Classifier(config=model_config)
|
||||
return model
|
||||
|
||||
|
||||
def test_create_balanced_sampler():
|
||||
# Mock dataset with imbalanced classes
|
||||
data = [
|
||||
{"label": 0},
|
||||
{"label": 0},
|
||||
{"label": 1},
|
||||
{"label": 0},
|
||||
{"label": 1},
|
||||
{"label": 1},
|
||||
{"label": 1},
|
||||
{"label": 1},
|
||||
]
|
||||
dataset = MockDataset(data)
|
||||
cfg = MagicMock()
|
||||
cfg.training.label_key = "label"
|
||||
|
||||
sampler = create_balanced_sampler(dataset, cfg)
|
||||
|
||||
# Get weights from the sampler
|
||||
weights = sampler.weights.float()
|
||||
|
||||
# Check that samples have appropriate weights
|
||||
labels = [item["label"] for item in data]
|
||||
class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32)
|
||||
class_weights = 1.0 / class_counts
|
||||
expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32)
|
||||
|
||||
# Test that the weights are correct
|
||||
assert torch.allclose(weights, expected_weights)
|
||||
|
||||
|
||||
def test_train_epoch():
|
||||
model = make_dummy_model()
|
||||
# Mock components
|
||||
model.train = MagicMock()
|
||||
|
||||
train_loader = [
|
||||
{
|
||||
"image": torch.rand(2, 3, 224, 224),
|
||||
"label": torch.tensor([0.0, 1.0]),
|
||||
}
|
||||
]
|
||||
|
||||
criterion = nn.BCEWithLogitsLoss()
|
||||
optimizer = MagicMock()
|
||||
grad_scaler = MagicMock()
|
||||
device = torch.device("cpu")
|
||||
logger = MagicMock()
|
||||
step = 0
|
||||
cfg = MagicMock()
|
||||
cfg.training.image_key = "image"
|
||||
cfg.training.label_key = "label"
|
||||
cfg.training.use_amp = False
|
||||
|
||||
# Call the function under test
|
||||
train_epoch(
|
||||
model,
|
||||
train_loader,
|
||||
criterion,
|
||||
optimizer,
|
||||
grad_scaler,
|
||||
device,
|
||||
logger,
|
||||
step,
|
||||
cfg,
|
||||
)
|
||||
|
||||
# Check that model.train() was called
|
||||
model.train.assert_called_once()
|
||||
|
||||
# Check that optimizer.zero_grad() was called
|
||||
optimizer.zero_grad.assert_called()
|
||||
|
||||
# Check that logger.log_dict was called
|
||||
logger.log_dict.assert_called()
|
||||
|
||||
|
||||
def test_validate():
|
||||
model = make_dummy_model()
|
||||
|
||||
# Mock components
|
||||
model.eval = MagicMock()
|
||||
val_loader = [
|
||||
{
|
||||
"image": torch.rand(2, 3, 224, 224),
|
||||
"label": torch.tensor([0.0, 1.0]),
|
||||
}
|
||||
]
|
||||
criterion = nn.BCEWithLogitsLoss()
|
||||
device = torch.device("cpu")
|
||||
logger = MagicMock()
|
||||
cfg = MagicMock()
|
||||
cfg.training.image_key = "image"
|
||||
cfg.training.label_key = "label"
|
||||
cfg.training.use_amp = False
|
||||
|
||||
# Call validate
|
||||
accuracy, eval_info = validate(model, val_loader, criterion, device, logger, cfg)
|
||||
|
||||
# Check that model.eval() was called
|
||||
model.eval.assert_called_once()
|
||||
|
||||
# Check accuracy/eval_info are calculated and of the correct type
|
||||
assert isinstance(accuracy, float)
|
||||
assert isinstance(eval_info, dict)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("resume", [True, False])
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.init_hydra_config")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_checkpoint_dir")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.make_policy")
|
||||
def test_resume_function(
|
||||
mock_make_policy,
|
||||
mock_dataset,
|
||||
mock_logger,
|
||||
mock_get_last_pretrained_model_dir,
|
||||
mock_get_last_checkpoint_dir,
|
||||
mock_init_hydra_config,
|
||||
resume,
|
||||
):
|
||||
# Initialize Hydra
|
||||
test_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy"))
|
||||
assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}"
|
||||
|
||||
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
|
||||
cfg = compose(
|
||||
config_name="reward_classifier",
|
||||
overrides=[
|
||||
"device=cpu",
|
||||
"seed=42",
|
||||
f"output_dir={tempfile.mkdtemp()}",
|
||||
"wandb.enable=False",
|
||||
f"resume={resume}",
|
||||
"dataset_repo_id=dataset_repo_id",
|
||||
"train_split_proportion=0.8",
|
||||
"training.num_workers=0",
|
||||
"training.batch_size=2",
|
||||
"training.image_key=image",
|
||||
"training.label_key=label",
|
||||
"training.use_amp=False",
|
||||
"training.num_epochs=1",
|
||||
"eval.batch_size=2",
|
||||
],
|
||||
)
|
||||
|
||||
# Mock the init_hydra_config function to return cfg
|
||||
mock_init_hydra_config.return_value = cfg
|
||||
|
||||
# Mock dataset
|
||||
dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)])
|
||||
mock_dataset.return_value = dataset
|
||||
|
||||
# Mock checkpoint handling
|
||||
mock_checkpoint_dir = MagicMock(spec=Path)
|
||||
mock_checkpoint_dir.exists.return_value = resume # Only exists if resuming
|
||||
mock_get_last_checkpoint_dir.return_value = mock_checkpoint_dir
|
||||
mock_get_last_pretrained_model_dir.return_value = Path(tempfile.mkdtemp())
|
||||
|
||||
# Mock logger
|
||||
logger = MagicMock()
|
||||
resumed_step = 1000
|
||||
if resume:
|
||||
logger.load_last_training_state.return_value = resumed_step
|
||||
else:
|
||||
logger.load_last_training_state.return_value = 0
|
||||
mock_logger.return_value = logger
|
||||
|
||||
# Instantiate the model and set make_policy to return it
|
||||
model = make_dummy_model()
|
||||
mock_make_policy.return_value = model
|
||||
|
||||
# Call train
|
||||
train(cfg)
|
||||
|
||||
# Check that checkpoint handling methods were called
|
||||
if resume:
|
||||
mock_get_last_checkpoint_dir.assert_called_once_with(Path(cfg.output_dir))
|
||||
mock_get_last_pretrained_model_dir.assert_called_once_with(Path(cfg.output_dir))
|
||||
mock_checkpoint_dir.exists.assert_called_once()
|
||||
logger.load_last_training_state.assert_called_once()
|
||||
else:
|
||||
mock_get_last_checkpoint_dir.assert_not_called()
|
||||
mock_get_last_pretrained_model_dir.assert_not_called()
|
||||
mock_checkpoint_dir.exists.assert_not_called()
|
||||
logger.load_last_training_state.assert_not_called()
|
||||
|
||||
# Collect the steps from logger.log_dict calls
|
||||
train_log_calls = logger.log_dict.call_args_list
|
||||
|
||||
# Extract the steps used in the train logging
|
||||
steps = []
|
||||
for call in train_log_calls:
|
||||
mode = call.kwargs.get("mode", call.args[2] if len(call.args) > 2 else None)
|
||||
if mode == "train":
|
||||
step = call.kwargs.get("step", call.args[1] if len(call.args) > 1 else None)
|
||||
steps.append(step)
|
||||
|
||||
expected_start_step = resumed_step if resume else 0
|
||||
|
||||
# Calculate expected_steps
|
||||
train_size = int(cfg.train_split_proportion * len(dataset))
|
||||
batch_size = cfg.training.batch_size
|
||||
num_batches = (train_size + batch_size - 1) // batch_size
|
||||
|
||||
expected_steps = [expected_start_step + i for i in range(num_batches)]
|
||||
|
||||
assert steps == expected_steps, f"Expected steps {expected_steps}, got {steps}"
|
||||
Reference in New Issue
Block a user