[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -167,7 +167,7 @@ from lerobot.common.robot_devices.control_utils import (
|
|||||||
warmup_record,
|
warmup_record,
|
||||||
)
|
)
|
||||||
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
|
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
|
||||||
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
|
from lerobot.common.robot_devices.utils import safe_disconnect
|
||||||
from lerobot.common.utils.utils import has_method, init_logging, log_say
|
from lerobot.common.utils.utils import has_method, init_logging, log_say
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
|
|
||||||
|
|||||||
@@ -346,7 +346,6 @@ class AddCurrentToObservation(gym.ObservationWrapper):
|
|||||||
old_high = self.observation_space["observation.state"].high
|
old_high = self.observation_space["observation.state"].high
|
||||||
old_shape = self.observation_space["observation.state"].shape
|
old_shape = self.observation_space["observation.state"].shape
|
||||||
|
|
||||||
|
|
||||||
new_low = np.concatenate([old_low, np.zeros(num_dof)])
|
new_low = np.concatenate([old_low, np.zeros(num_dof)])
|
||||||
new_high = np.concatenate([old_high, np.ones(num_dof) * max_current])
|
new_high = np.concatenate([old_high, np.ones(num_dof) * max_current])
|
||||||
|
|
||||||
@@ -359,9 +358,10 @@ class AddCurrentToObservation(gym.ObservationWrapper):
|
|||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
present_current = self.unwrapped.robot.follower_arms["main"].read("Present_Current").astype(np.float32)
|
present_current = (
|
||||||
|
self.unwrapped.robot.follower_arms["main"].read("Present_Current").astype(np.float32)
|
||||||
|
)
|
||||||
observation["observation.state"] = torch.cat(
|
observation["observation.state"] = torch.cat(
|
||||||
[observation["observation.state"], torch.from_numpy(present_current)], dim=-1
|
[observation["observation.state"], torch.from_numpy(present_current)], dim=-1
|
||||||
)
|
)
|
||||||
@@ -859,7 +859,10 @@ class GripperActionWrapper(gym.ActionWrapper):
|
|||||||
action, is_intervention = action
|
action, is_intervention = action
|
||||||
|
|
||||||
if self.gripper_sleep > 0.0:
|
if self.gripper_sleep > 0.0:
|
||||||
if self.last_gripper_action is not None and time.perf_counter() - self.last_gripper_action_time < self.gripper_sleep:
|
if (
|
||||||
|
self.last_gripper_action is not None
|
||||||
|
and time.perf_counter() - self.last_gripper_action_time < self.gripper_sleep
|
||||||
|
):
|
||||||
action[-1] = self.last_gripper_action
|
action[-1] = self.last_gripper_action
|
||||||
else:
|
else:
|
||||||
self.last_gripper_action_time = time.perf_counter()
|
self.last_gripper_action_time = time.perf_counter()
|
||||||
@@ -1506,7 +1509,7 @@ def main(cfg: EnvConfig):
|
|||||||
|
|
||||||
dt_s = time.perf_counter() - start_loop_s
|
dt_s = time.perf_counter() - start_loop_s
|
||||||
busy_wait(1 / cfg.fps - dt_s)
|
busy_wait(1 / cfg.fps - dt_s)
|
||||||
|
|
||||||
logging.info(f"Success after 20 steps {sucesses}")
|
logging.info(f"Success after 20 steps {sucesses}")
|
||||||
logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
|
logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user