diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py index 9564fb59..85dd6830 100644 --- a/lerobot/common/datasets/image_writer.py +++ b/lerobot/common/datasets/image_writer.py @@ -28,7 +28,7 @@ def safe_stop_image_writer(func): try: return func(*args, **kwargs) except Exception as e: - dataset = kwargs.get("dataset", None) + dataset = kwargs.get("dataset") image_writer = getattr(dataset, "image_writer", None) if dataset else None if image_writer is not None: print("Waiting for image writer to terminate...") diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index dec8b465..4015492d 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -25,13 +25,13 @@ from glob import glob from pathlib import Path import torch +import wandb from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from omegaconf import DictConfig, OmegaConf from termcolor import colored from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler -import wandb from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import get_global_random_state, set_global_random_state diff --git a/lerobot/common/policies/hilserl/configuration_hilserl.py b/lerobot/common/policies/hilserl/configuration_hilserl.py index c1bd52cd..f1bc850f 100644 --- a/lerobot/common/policies/hilserl/configuration_hilserl.py +++ b/lerobot/common/policies/hilserl/configuration_hilserl.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field +from dataclasses import dataclass @dataclass diff --git a/lerobot/common/policies/hilserl/modeling_hilserl.py b/lerobot/common/policies/hilserl/modeling_hilserl.py index f130c7ad..236ed433 100644 --- a/lerobot/common/policies/hilserl/modeling_hilserl.py +++ b/lerobot/common/policies/hilserl/modeling_hilserl.py @@ -15,11 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch import torch.nn as nn -import torch.nn.functional as F # noqa: N812 from huggingface_hub import PyTorchModelHubMixin + class HILSerlPolicy( nn.Module, PyTorchModelHubMixin, diff --git a/lerobot/common/policies/sac/configuration_hilserl.py b/lerobot/common/policies/sac/configuration_hilserl.py deleted file mode 100644 index e0e2f05b..00000000 --- a/lerobot/common/policies/sac/configuration_hilserl.py +++ /dev/null @@ -1,23 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass, field - - -@dataclass -class SACConfig: - pass diff --git a/lerobot/common/policies/sac/modeling_hilserl.py b/lerobot/common/policies/sac/modeling_hilserl.py deleted file mode 100644 index d7b87fca..00000000 --- a/lerobot/common/policies/sac/modeling_hilserl.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F # noqa: N812 -from huggingface_hub import PyTorchModelHubMixin - -class SACPolicy( - nn.Module, - PyTorchModelHubMixin, - library_name="lerobot", - repo_url="https://github.com/huggingface/lerobot", - tags=["robotics", "SAC"], -): - pass \ No newline at end of file diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index b876cd1e..3d9fe542 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -224,12 +224,7 @@ def record( else: raise NotImplementedError("Only single-task recording is supported for now") - if single_task: - task = single_task - else: - raise NotImplementedError("Only single-task recording is supported for now") - - # Load pretrained policy + # Load pretrained policy if pretrained_policy_name_or_path is not None: policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) @@ -263,6 +258,7 @@ def record( use_videos=video, image_writer_processes=num_image_writer_processes, image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras), + features=extra_features, ) if not robot.is_connected: diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index ea8336a9..78659dc8 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -22,6 +22,7 @@ from pprint import pformat import hydra import torch import torch.nn as nn +import wandb from deepdiff import DeepDiff from omegaconf import DictConfig, OmegaConf from termcolor import colored @@ -30,7 +31,6 @@ from torch.cuda.amp import GradScaler from torch.utils.data import DataLoader, WeightedRandomSampler, random_split from tqdm import tqdm -import wandb from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.logger import Logger