Merge remote-tracking branch 'upstream/main' into fix_pusht_diffusion

This commit is contained in:
Alexander Soare
2024-03-21 10:20:52 +00:00
15 changed files with 276 additions and 46 deletions

View File

@@ -12,7 +12,7 @@ from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import format_big_number, init_logging, set_seed
from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_seed
from lerobot.scripts.eval import eval_policy
@@ -112,13 +112,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
raise NotImplementedError()
if job_name is None:
raise NotImplementedError()
if cfg.online_steps > 0:
assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps"
init_logging()
if cfg.device == "cuda":
assert torch.cuda.is_available()
else:
logging.warning("Using CPU, this will be slow.")
# Check device is available
get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -220,11 +220,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
# TODO: add configurable number of rollout? (default=1)
with torch.no_grad():
rollout = env.rollout(
max_steps=cfg.env.episode_length // cfg.n_action_steps,
max_steps=cfg.env.episode_length,
policy=td_policy,
auto_cast_to_device=True,
)
assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps
assert len(rollout) <= cfg.env.episode_length
# set same episode index for all time steps contained in this rollout
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
online_buffer.extend(rollout)