Dataset v2.0 (#461)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Author: Simon Alibert
Date: 2024-11-29 19:04:00 +01:00
Committed by: GitHub
Parent: 96c7052777
Commit: 32eb0cec8f
71 changed files with 6115 additions and 2235 deletions


@@ -171,9 +171,9 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
     # A sample is an (observation, action) pair, where observation and action
     # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
     num_samples = (step + 1) * cfg.training.batch_size
-    avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
+    avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
     num_episodes = num_samples / avg_samples_per_ep
-    num_epochs = num_samples / dataset.num_samples
+    num_epochs = num_samples / dataset.num_frames
     log_items = [
         f"step:{format_big_number(step)}",
         # number of samples seen during training
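The arithmetic in this hunk converts a training step count into rough episode and epoch counts, with the renamed num_frames standing in for the old num_samples attribute. A quick worked example with made-up numbers (a sketch, not code from the repository):

# Worked example of the logging arithmetic above, with made-up numbers.
batch_size = 8
step = 999                       # zero-indexed, so 1000 optimization steps done
num_frames_in_dataset = 50_000   # dataset.num_frames after the rename
num_episodes_in_dataset = 250    # dataset.num_episodes

num_samples = (step + 1) * batch_size                                  # 8000 samples seen
avg_samples_per_ep = num_frames_in_dataset / num_episodes_in_dataset   # 200.0
num_episodes = num_samples / avg_samples_per_ep                        # 40.0 episodes' worth
num_epochs = num_samples / num_frames_in_dataset                       # 0.16 epochs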
@@ -208,9 +208,9 @@ def log_eval_info(logger, info, step, cfg, dataset, is_online):
     # A sample is an (observation, action) pair, where observation and action
     # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
     num_samples = (step + 1) * cfg.training.batch_size
-    avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
+    avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
     num_episodes = num_samples / avg_samples_per_ep
-    num_epochs = num_samples / dataset.num_samples
+    num_epochs = num_samples / dataset.num_frames
     log_items = [
         f"step:{format_big_number(step)}",
         # number of samples seen during training
@@ -328,7 +328,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     logging.info("make_policy")
     policy = make_policy(
         hydra_cfg=cfg,
-        dataset_stats=offline_dataset.stats if not cfg.resume else None,
+        dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
         pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )
     assert isinstance(policy, nn.Module)
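In the v2.0 layout, dataset-level statistics move off the dataset object onto its metadata, hence offline_dataset.meta.stats. A minimal sketch of the new access pattern, assuming the LeRobotDataset class; the repo id and the "action" key below are only illustrative:

# Sketch: normalization stats now live on the dataset's metadata object.
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")
stats = dataset.meta.stats           # was `dataset.stats` before v2.0
print(stats["action"]["mean"])       # per-feature statistics (assumed key for illustration)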
@@ -349,7 +349,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     logging.info(f"{cfg.env.task=}")
     logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
     logging.info(f"{cfg.training.online_steps=}")
-    logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
+    logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
     logging.info(f"{offline_dataset.num_episodes=}")
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
@@ -573,7 +573,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
         online_sampling_ratio=cfg.training.online_sampling_ratio,
     )
-    sampler.num_samples = len(concat_dataset)
+    sampler.num_frames = len(concat_dataset)
     update_online_buffer_s = time.perf_counter() - start_update_buffer_time
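After new episodes are appended to the online buffer, the sampler has to be told how many frames it may now draw from, which is what the renamed sampler.num_frames assignment does. A minimal sketch of that pattern with a hypothetical sampler class (the repository's real sampler also handles the offline/online sampling ratio and frame dropping):

import torch

# Hypothetical sampler whose length tracks a growing dataset; only the
# num_frames update pattern mirrors the diff above.
class GrowableSampler(torch.utils.data.Sampler[int]):
    def __init__(self, num_frames: int):
        self.num_frames = num_frames

    def __iter__(self):
        # Sample uniformly with replacement from the frames seen so far.
        yield from torch.randint(self.num_frames, (self.num_frames,)).tolist()

    def __len__(self):
        return self.num_frames


sampler = GrowableSampler(num_frames=1_000)
# ... new episodes get appended to the buffer ...
sampler.num_frames = 1_250  # keep the sampler in sync, as in the hunk above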