nit in control_robot.py
This commit is contained in:
@@ -28,7 +28,7 @@ def safe_stop_image_writer(func):
|
|||||||
try:
|
try:
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
dataset = kwargs.get("dataset", None)
|
dataset = kwargs.get("dataset")
|
||||||
image_writer = getattr(dataset, "image_writer", None) if dataset else None
|
image_writer = getattr(dataset, "image_writer", None) if dataset else None
|
||||||
if image_writer is not None:
|
if image_writer is not None:
|
||||||
print("Waiting for image writer to terminate...")
|
print("Waiting for image writer to terminate...")
|
||||||
|
|||||||
@@ -25,13 +25,13 @@ from glob import glob
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import wandb
|
||||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
||||||
import wandb
|
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -15,11 +15,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F # noqa: N812
|
|
||||||
from huggingface_hub import PyTorchModelHubMixin
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
|
|
||||||
|
|
||||||
class HILSerlPolicy(
|
class HILSerlPolicy(
|
||||||
nn.Module,
|
nn.Module,
|
||||||
PyTorchModelHubMixin,
|
PyTorchModelHubMixin,
|
||||||
|
|||||||
@@ -1,23 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SACConfig:
|
|
||||||
pass
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F # noqa: N812
|
|
||||||
from huggingface_hub import PyTorchModelHubMixin
|
|
||||||
|
|
||||||
class SACPolicy(
|
|
||||||
nn.Module,
|
|
||||||
PyTorchModelHubMixin,
|
|
||||||
library_name="lerobot",
|
|
||||||
repo_url="https://github.com/huggingface/lerobot",
|
|
||||||
tags=["robotics", "SAC"],
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
@@ -224,12 +224,7 @@ def record(
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError("Only single-task recording is supported for now")
|
raise NotImplementedError("Only single-task recording is supported for now")
|
||||||
|
|
||||||
if single_task:
|
# Load pretrained policy
|
||||||
task = single_task
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Only single-task recording is supported for now")
|
|
||||||
|
|
||||||
# Load pretrained policy
|
|
||||||
if pretrained_policy_name_or_path is not None:
|
if pretrained_policy_name_or_path is not None:
|
||||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||||
|
|
||||||
@@ -263,6 +258,7 @@ def record(
|
|||||||
use_videos=video,
|
use_videos=video,
|
||||||
image_writer_processes=num_image_writer_processes,
|
image_writer_processes=num_image_writer_processes,
|
||||||
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||||
|
features=extra_features,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from pprint import pformat
|
|||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import wandb
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
@@ -30,7 +31,6 @@ from torch.cuda.amp import GradScaler
|
|||||||
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import wandb
|
|
||||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.logger import Logger
|
from lerobot.common.logger import Logger
|
||||||
|
|||||||
Reference in New Issue
Block a user