forked from tangger/lerobot
Make sure policies don't mutate the batch (#323)
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
@@ -161,8 +162,13 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||
|
||||
# Test updating the policy
|
||||
# Test updating the policy (and test that it does not mutate the batch)
|
||||
batch_ = deepcopy(batch)
|
||||
policy.forward(batch)
|
||||
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
|
||||
assert all(
|
||||
torch.equal(batch[k], batch_[k]) for k in batch
|
||||
), "Batch values are not the same after a forward pass."
|
||||
|
||||
# reset the policy and environment
|
||||
policy.reset()
|
||||
@@ -174,9 +180,16 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
|
||||
|
||||
# get the next action for the environment
|
||||
# get the next action for the environment (also check that the observation batch is not modified)
|
||||
observation_ = deepcopy(observation)
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation).cpu().numpy()
|
||||
assert set(observation) == set(
|
||||
observation_
|
||||
), "Observation batch keys are not the same after a forward pass."
|
||||
assert all(
|
||||
torch.equal(observation[k], observation_[k]) for k in observation
|
||||
), "Observation batch values are not the same after a forward pass."
|
||||
|
||||
# Test step through policy
|
||||
env.step(action)
|
||||
|
||||
Reference in New Issue
Block a user