forked from tangger/lerobot
- Added JointMaskingActionSpace wrapper in gym_manipulator in order to select which joints will be controlled. For example, we can disable the gripper actions for some tasks.
- Added Nan detection mechanisms in the actor, learner and gym_manipulator for the case where we encounter nans in the loop. - changed the non-blocking in the `.to(device)` functions to only work for the case of cuda because they were causing nans when running the policy on mps - Added some joint clipping and limits in the env, robot and policy configs. TODO clean this part and make the limits in one config file only. Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
committed by
AdilZouitine
parent
3cb43f801c
commit
c623824139
@@ -43,16 +43,27 @@ class BatchTransition(TypedDict):
|
||||
|
||||
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
|
||||
# Move state tensors to CPU
|
||||
transition["state"] = {key: val.to(device, non_blocking=True) for key, val in transition["state"].items()}
|
||||
device = torch.device(device)
|
||||
transition["state"] = {
|
||||
key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
|
||||
}
|
||||
|
||||
# Move action to CPU
|
||||
transition["action"] = transition["action"].to(device, non_blocking=True)
|
||||
transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
|
||||
|
||||
# No need to move reward or done, as they are float and bool
|
||||
|
||||
# No need to move reward or done, as they are float and bool
|
||||
if isinstance(transition["reward"], torch.Tensor):
|
||||
transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
|
||||
|
||||
if isinstance(transition["done"], torch.Tensor):
|
||||
transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
|
||||
|
||||
# Move next_state tensors to CPU
|
||||
transition["next_state"] = {
|
||||
key: val.to(device, non_blocking=True) for key, val in transition["next_state"].items()
|
||||
key: val.to(device, non_blocking=device.type == "cuda")
|
||||
for key, val in transition["next_state"].items()
|
||||
}
|
||||
|
||||
# If complementary_info is present, move its tensors to CPU
|
||||
|
||||
Reference in New Issue
Block a user