forked from tangger/lerobot
Merge remote-tracking branch 'upstream/main' into fix_pusht_diffusion
This commit is contained in:
@@ -18,7 +18,7 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import log_output_dir
 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import init_logging, set_seed
+from lerobot.common.utils import get_safe_torch_device, init_logging, set_seed
 
 
 def write_video(video_path, stacked_frames, fps):
@@ -35,7 +35,8 @@ def eval_policy(
     fps: int = 15,
     return_first_video: bool = False,
 ):
-    policy.eval()
+    if policy is not None:
+        policy.eval()
     start = time.time()
     sum_rewards = []
     max_rewards = []
@@ -56,7 +57,8 @@ def eval_policy(
         with torch.inference_mode():
             # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
             # envs are done the first time. But we only use the first rollout. This is a waste of compute.
-            policy.clear_action_queue()
+            if policy is not None:
+                policy.clear_action_queue()
             rollout = env.rollout(
                 max_steps=max_steps,
                 policy=policy,
@@ -129,10 +131,8 @@ def eval(cfg: dict, out_dir=None):
 
     init_logging()
 
-    if cfg.device == "cuda":
-        assert torch.cuda.is_available()
-    else:
-        logging.warning("Using CPU, this will be slow.")
+    # Check device is available
+    get_safe_torch_device(cfg.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -12,7 +12,7 @@ from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import format_big_number, init_logging, set_seed
+from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_seed
 from lerobot.scripts.eval import eval_policy
 
 
@@ -112,13 +112,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
         raise NotImplementedError()
     if job_name is None:
         raise NotImplementedError()
+    if cfg.online_steps > 0:
+        assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps"
 
     init_logging()
 
-    if cfg.device == "cuda":
-        assert torch.cuda.is_available()
-    else:
-        logging.warning("Using CPU, this will be slow.")
+    # Check device is available
+    get_safe_torch_device(cfg.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -220,11 +220,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
         # TODO: add configurable number of rollout? (default=1)
         with torch.no_grad():
             rollout = env.rollout(
-                max_steps=cfg.env.episode_length // cfg.n_action_steps,
+                max_steps=cfg.env.episode_length,
                 policy=td_policy,
                 auto_cast_to_device=True,
             )
-        assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps
+        assert len(rollout) <= cfg.env.episode_length
         # set same episode index for all time steps contained in this rollout
         rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
         online_buffer.extend(rollout)
Reference in New Issue
Block a user