fix tests
This commit is contained in:
@@ -123,12 +123,8 @@ def make_offline_buffer(
|
|||||||
|
|
||||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||||
stats["observation", "state", "min"] = torch.tensor(
|
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||||
[13.456424, 32.938293], dtype=torch.float32
|
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||||
)
|
|
||||||
stats["observation", "state", "max"] = torch.tensor(
|
|
||||||
[496.14618, 510.9579], dtype=torch.float32
|
|
||||||
)
|
|
||||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import lerobot
|
|||||||
|
|
||||||
from lerobot.common.envs.aloha.env import AlohaEnv
|
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||||
from lerobot.common.envs.pusht.env import PushtEnv
|
from lerobot.common.envs.pusht.env import PushtEnv
|
||||||
from lerobot.common.envs.simxarm import SimxarmEnv
|
from lerobot.common.envs.simxarm.env import SimxarmEnv
|
||||||
|
|
||||||
from lerobot.common.datasets.simxarm import SimxarmDataset
|
from lerobot.common.datasets.simxarm import SimxarmDataset
|
||||||
from lerobot.common.datasets.aloha import AlohaDataset
|
from lerobot.common.datasets.aloha import AlohaDataset
|
||||||
|
|||||||
Reference in New Issue
Block a user