Add review feedback
This commit is contained in:
@@ -46,6 +46,7 @@ repos:
|
|||||||
rev: v3.19.1
|
rev: v3.19.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
|
# Exclude generated protobuf files
|
||||||
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
|
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.11.5
|
rev: v0.11.5
|
||||||
|
|||||||
@@ -167,13 +167,6 @@ class VideoRecordConfig:
|
|||||||
trajectory_name: str = "trajectory"
|
trajectory_name: str = "trajectory"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class WrapperConfig:
|
|
||||||
"""Configuration for environment wrappers."""
|
|
||||||
|
|
||||||
joint_masking_action_space: list[bool] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EEActionSpaceConfig:
|
class EEActionSpaceConfig:
|
||||||
"""Configuration parameters for end-effector action space."""
|
"""Configuration parameters for end-effector action space."""
|
||||||
@@ -186,7 +179,7 @@ class EEActionSpaceConfig:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EnvWrapperConfig:
|
class EnvTransformConfig:
|
||||||
"""Configuration for environment wrappers."""
|
"""Configuration for environment wrappers."""
|
||||||
|
|
||||||
ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
|
ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
|
||||||
@@ -211,7 +204,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
|||||||
"""Configuration for the HILSerlRobotEnv environment."""
|
"""Configuration for the HILSerlRobotEnv environment."""
|
||||||
|
|
||||||
robot: Optional[RobotConfig] = None
|
robot: Optional[RobotConfig] = None
|
||||||
wrapper: Optional[EnvWrapperConfig] = None
|
wrapper: Optional[EnvTransformConfig] = None
|
||||||
fps: int = 10
|
fps: int = 10
|
||||||
name: str = "real_robot"
|
name: str = "real_robot"
|
||||||
mode: str = None # Either "record", "replay", None
|
mode: str = None # Either "record", "replay", None
|
||||||
@@ -224,9 +217,8 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
|||||||
push_to_hub: bool = True
|
push_to_hub: bool = True
|
||||||
pretrained_policy_name_or_path: Optional[str] = None
|
pretrained_policy_name_or_path: Optional[str] = None
|
||||||
reward_classifier_pretrained_path: Optional[str] = None
|
reward_classifier_pretrained_path: Optional[str] = None
|
||||||
number_of_steps_after_success: int = (
|
# For the reward classifier, to record more positive examples after a success
|
||||||
0 # For the reward classifier, to record more positive examples after a success
|
number_of_steps_after_success: int = 0
|
||||||
)
|
|
||||||
|
|
||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
return {}
|
return {}
|
||||||
@@ -265,7 +257,7 @@ class HILEnvConfig(EnvConfig):
|
|||||||
################# args from hilserlrobotenv
|
################# args from hilserlrobotenv
|
||||||
reward_classifier_pretrained_path: Optional[str] = None
|
reward_classifier_pretrained_path: Optional[str] = None
|
||||||
robot: Optional[RobotConfig] = None
|
robot: Optional[RobotConfig] = None
|
||||||
wrapper: Optional[EnvWrapperConfig] = None
|
wrapper: Optional[EnvTransformConfig] = None
|
||||||
mode: str = None # Either "record", "replay", None
|
mode: str = None # Either "record", "replay", None
|
||||||
repo_id: Optional[str] = None
|
repo_id: Optional[str] = None
|
||||||
dataset_root: Optional[str] = None
|
dataset_root: Optional[str] = None
|
||||||
|
|||||||
@@ -151,6 +151,7 @@ class Normalize(nn.Module):
|
|||||||
# TODO(rcadene): should we remove torch.no_grad?
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
|
# TODO: Remove this shallow copy
|
||||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
if key not in batch:
|
if key not in batch:
|
||||||
@@ -337,7 +338,6 @@ class NormalizeBuffer(nn.Module):
|
|||||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
batch = dict(batch)
|
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
if key not in batch:
|
if key not in batch:
|
||||||
continue
|
continue
|
||||||
@@ -386,7 +386,6 @@ class UnnormalizeBuffer(nn.Module):
|
|||||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
batch = dict(batch)
|
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
if key not in batch:
|
if key not in batch:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -115,8 +115,11 @@ def init_logging(log_file: Path | None = None, display_pid: bool = False):
|
|||||||
fnameline = f"{record.pathname}:{record.lineno}"
|
fnameline = f"{record.pathname}:{record.lineno}"
|
||||||
|
|
||||||
# NOTE: Display PID is useful for multi-process logging.
|
# NOTE: Display PID is useful for multi-process logging.
|
||||||
pid_str = f"[PID: {os.getpid()}]" if display_pid else ""
|
if display_pid:
|
||||||
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
|
pid_str = f"[PID: {os.getpid()}]"
|
||||||
|
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||||
|
else:
|
||||||
|
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||||
return message
|
return message
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -131,7 +134,7 @@ def init_logging(log_file: Path | None = None, display_pid: bool = False):
|
|||||||
logging.getLogger().addHandler(console_handler)
|
logging.getLogger().addHandler(console_handler)
|
||||||
|
|
||||||
if log_file is not None:
|
if log_file is not None:
|
||||||
# File handler
|
# Additionally write logs to file
|
||||||
file_handler = logging.FileHandler(log_file)
|
file_handler = logging.FileHandler(log_file)
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
logging.getLogger().addHandler(file_handler)
|
logging.getLogger().addHandler(file_handler)
|
||||||
@@ -247,11 +250,23 @@ class TimerManager:
|
|||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
>>> timer = TimerManager("Policy", log=False)
|
```python
|
||||||
>>> for _ in range(3):
|
# Example 1: Using context manager
|
||||||
... with timer:
|
timer = TimerManager("Policy", log=False)
|
||||||
... time.sleep(0.01)
|
for _ in range(3):
|
||||||
>>> print(timer.last, timer.fps_avg, timer.percentile(90))
|
with timer:
|
||||||
|
time.sleep(0.01)
|
||||||
|
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Example 2: Using start/stop methods
|
||||||
|
timer = TimerManager("Policy", log=False)
|
||||||
|
timer.start()
|
||||||
|
time.sleep(0.01)
|
||||||
|
timer.stop()
|
||||||
|
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -149,12 +149,8 @@ class WandBLogger:
|
|||||||
|
|
||||||
if custom_step_key is not None:
|
if custom_step_key is not None:
|
||||||
value_custom_step = d[custom_step_key]
|
value_custom_step = d[custom_step_key]
|
||||||
self._wandb.log(
|
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
|
||||||
{
|
self._wandb.log(data)
|
||||||
f"{mode}/{k}": v,
|
|
||||||
f"{mode}/{custom_step_key}": value_custom_step,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ TRAIN_CONFIG_NAME = "train_config.json"
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainPipelineConfig(HubMixin):
|
class TrainPipelineConfig(HubMixin):
|
||||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need a dataset
|
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
|
||||||
env: envs.EnvConfig | None = None
|
env: envs.EnvConfig | None = None
|
||||||
policy: PreTrainedConfig | None = None
|
policy: PreTrainedConfig | None = None
|
||||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session with the same value for `dir`, its contents will be overwritten unless you set `resume` to true.
|
# Set `dir` to where you would like to save all of the run outputs. If you run another training session with the same value for `dir`, its contents will be overwritten unless you set `resume` to true.
|
||||||
@@ -124,10 +124,7 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
return draccus.encode(self)
|
return draccus.encode(self)
|
||||||
|
|
||||||
def _save_pretrained(self, save_directory: Path) -> None:
|
def _save_pretrained(self, save_directory: Path) -> None:
|
||||||
with (
|
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||||
open(save_directory / TRAIN_CONFIG_NAME, "w") as f,
|
|
||||||
draccus.config_type("json"),
|
|
||||||
):
|
|
||||||
draccus.dump(self, f, indent=4)
|
draccus.dump(self, f, indent=4)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -720,7 +720,7 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from lerobot.common.envs.configs import EEActionSpaceConfig, EnvWrapperConfig, HILSerlRobotEnvConfig
|
from lerobot.common.envs.configs import EEActionSpaceConfig, EnvTransformConfig, HILSerlRobotEnvConfig
|
||||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||||
@@ -781,7 +781,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
|
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
|
||||||
# Gym environment control modes
|
# Gym environment control modes
|
||||||
cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvWrapperConfig())
|
cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvTransformConfig())
|
||||||
cfg.wrapper.ee_action_space_params = EEActionSpaceConfig(
|
cfg.wrapper.ee_action_space_params = EEActionSpaceConfig(
|
||||||
x_step_size=0.03, y_step_size=0.03, z_step_size=0.03, bounds=bounds
|
x_step_size=0.03, y_step_size=0.03, z_step_size=0.03, bounds=bounds
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user