nit in control_robot.py

This commit is contained in:
Michel Aractingi
2024-12-11 00:30:33 +01:00
parent e9ef46f134
commit 3d7e74d162
8 changed files with 7 additions and 65 deletions

View File

@@ -28,7 +28,7 @@ def safe_stop_image_writer(func):
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
except Exception as e: except Exception as e:
dataset = kwargs.get("dataset", None) dataset = kwargs.get("dataset")
image_writer = getattr(dataset, "image_writer", None) if dataset else None image_writer = getattr(dataset, "image_writer", None) if dataset else None
if image_writer is not None: if image_writer is not None:
print("Waiting for image writer to terminate...") print("Waiting for image writer to terminate...")

View File

@@ -25,13 +25,13 @@ from glob import glob
from pathlib import Path from pathlib import Path
import torch import torch
import wandb
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from termcolor import colored from termcolor import colored
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
import wandb
from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state from lerobot.common.utils.utils import get_global_random_state, set_global_random_state

View File

@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass, field from dataclasses import dataclass
@dataclass @dataclass

View File

@@ -15,11 +15,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin from huggingface_hub import PyTorchModelHubMixin
class HILSerlPolicy( class HILSerlPolicy(
nn.Module, nn.Module,
PyTorchModelHubMixin, PyTorchModelHubMixin,

View File

@@ -1,23 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@dataclass
class SACConfig:
pass

View File

@@ -1,30 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
class SACPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "SAC"],
):
pass

View File

@@ -224,12 +224,7 @@ def record(
else: else:
raise NotImplementedError("Only single-task recording is supported for now") raise NotImplementedError("Only single-task recording is supported for now")
if single_task: # Load pretrained policy
task = single_task
else:
raise NotImplementedError("Only single-task recording is supported for now")
# Load pretrained policy
if pretrained_policy_name_or_path is not None: if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
@@ -263,6 +258,7 @@ def record(
use_videos=video, use_videos=video,
image_writer_processes=num_image_writer_processes, image_writer_processes=num_image_writer_processes,
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras), image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
features=extra_features,
) )
if not robot.is_connected: if not robot.is_connected:

View File

@@ -22,6 +22,7 @@ from pprint import pformat
import hydra import hydra
import torch import torch
import torch.nn as nn import torch.nn as nn
import wandb
from deepdiff import DeepDiff from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from termcolor import colored from termcolor import colored
@@ -30,7 +31,6 @@ from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm from tqdm import tqdm
import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger from lerobot.common.logger import Logger