Add multithreading for video generation, speed up policy sampling

This commit is contained in:
Cadene
2024-02-24 18:18:39 +00:00
parent 591985c67d
commit aed02dc7c6
4 changed files with 59 additions and 6 deletions

View File

@@ -56,6 +56,35 @@ python lerobot/scripts/eval.py \
- [ ] add diffusion - [ ] add diffusion
- [ ] add aloha 2 - [ ] add aloha 2
## Profile
**Example**
```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function


def trace_handler(prof):
    # Registered below as `on_trace_ready`; writes a Chrome trace whose file
    # name carries the profiler's current step number.
    prof.export_chrome_trace(f"tmp/trace_schedule_{prof.step_num}.json")


with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    # 2 + 2 + 3 = 7 profiler steps per cycle, matching `eval_episodes=7` in the
    # command shown after this snippet.
    schedule=torch.profiler.schedule(
        wait=2,
        warmup=2,
        active=3,
    ),
    on_trace_ready=trace_handler,
) as prof:
    with record_function("eval_policy"):
        # `num_episodes` is expected to come from the surrounding evaluation code.
        for i in range(num_episodes):
            # ... run one evaluation episode here ...
            prof.step()
```
```bash
python lerobot/scripts/eval.py \
pretrained_model_path=/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt \
eval_episodes=7
```
## Contribute ## Contribute
**style** **style**

View File

@@ -51,6 +51,11 @@ class TOLD(nn.Module):
"""Predicts next latent state (d) and single-step reward (R).""" """Predicts next latent state (d) and single-step reward (R)."""
x = torch.cat([z, a], dim=-1) x = torch.cat([z, a], dim=-1)
return self._dynamics(x), self._reward(x) return self._dynamics(x), self._reward(x)
def next_dynamics(self, z, a):
    """Predicts next latent state (d) only.

    Concatenates `z` and `a` along the last dimension and feeds the result to
    `self._dynamics`, without calling `self._reward`.
    """
    x = torch.cat([z, a], dim=-1)
    return self._dynamics(x)
def pi(self, z, std=0): def pi(self, z, std=0):
"""Samples an action from the learned policy (pi).""" """Samples an action from the learned policy (pi)."""
@@ -191,7 +196,7 @@ class TDMPC(nn.Module):
_z = z.repeat(num_pi_trajs, 1) _z = z.repeat(num_pi_trajs, 1)
for t in range(horizon): for t in range(horizon):
pi_actions[t] = self.model.pi(_z, self.cfg.min_std) pi_actions[t] = self.model.pi(_z, self.cfg.min_std)
_z, _ = self.model.next(_z, pi_actions[t]) _z = self.model.next_dynamics(_z, pi_actions[t])
# Initialize state and parameters # Initialize state and parameters
z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1) z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
@@ -241,6 +246,11 @@ class TDMPC(nn.Module):
mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std
# Outputs # Outputs
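# TODO(rcadene): remove the numpy round-trip below by sampling directly in torch, e.g.:
# # Convert score tensor to probabilities using softmax
# probabilities = torch.softmax(score, dim=0)
# # Generate a random sample index based on the probabilities
# sample_index = torch.multinomial(probabilities, 1).item()
# NOTE(review): `score` is passed as `p=` to np.random.choice below, so it looks already
# normalized; a softmax on top would change the sampling distribution. torch.multinomial
# would also need the squeezed 1-D `score` — confirm before swapping.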
score = score.squeeze(1).cpu().numpy() score = score.squeeze(1).cpu().numpy()
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)] actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
self._prev_mean = mean self._prev_mean = mean

View File

@@ -11,7 +11,10 @@ from torchrl.envs import EnvBase
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.tdmpc import TDMPC from lerobot.common.tdmpc import TDMPC
from lerobot.common.utils import set_seed from lerobot.common.utils import set_seed
import threading
def write_video(video_path, stacked_frames, fps):
    """Save `stacked_frames` as a video at `video_path` using `imageio.mimsave`.

    Kept as a module-level function so it can be used as a
    `threading.Thread` target.
    """
    imageio.mimsave(video_path, stacked_frames, fps=fps)
def eval_policy( def eval_policy(
env: EnvBase, env: EnvBase,
@@ -29,6 +32,7 @@ def eval_policy(
sum_rewards = [] sum_rewards = []
max_rewards = [] max_rewards = []
successes = [] successes = []
threads = []
for i in range(num_episodes): for i in range(num_episodes):
ep_frames = [] ep_frames = []
@@ -63,7 +67,12 @@ def eval_policy(
if save_video: if save_video:
video_dir.mkdir(parents=True, exist_ok=True) video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{i}.mp4" video_path = video_dir / f"eval_episode_{i}.mp4"
imageio.mimsave(video_path, stacked_frames, fps=fps) thread = threading.Thread(
target=write_video,
args=(str(video_path), stacked_frames, fps),
)
thread.start()
threads.append(thread)
first_episode = i == 0 first_episode = i == 0
if wandb and first_episode: if wandb and first_episode:
@@ -72,6 +81,9 @@ def eval_policy(
) )
wandb.log({"eval_video": eval_video}, step=env_step) wandb.log({"eval_video": eval_video}, step=env_step)
for thread in threads:
thread.join()
metrics = { metrics = {
"avg_sum_reward": np.nanmean(sum_rewards), "avg_sum_reward": np.nanmean(sum_rewards),
"avg_max_reward": np.nanmean(max_rewards), "avg_max_reward": np.nanmean(max_rewards),
@@ -90,6 +102,7 @@ def eval(cfg: dict, out_dir=None):
raise NotImplementedError() raise NotImplementedError()
assert torch.cuda.is_available() assert torch.cuda.is_available()
torch.backends.cudnn.benchmark = True
set_seed(cfg.seed) set_seed(cfg.seed)
print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir) print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
@@ -98,9 +111,9 @@ def eval(cfg: dict, out_dir=None):
if cfg.pretrained_model_path: if cfg.pretrained_model_path:
policy = TDMPC(cfg) policy = TDMPC(cfg)
if "offline" in cfg.pretrained_model_path: if "offline" in cfg.pretrained_model_path:
policy.step = 25000 policy.step[0] = 25000
elif "final" in cfg.pretrained_model_path: elif "final" in cfg.pretrained_model_path:
policy.step = 100000 policy.step[0] = 100000
else: else:
raise NotImplementedError() raise NotImplementedError()
policy.load(cfg.pretrained_model_path) policy.load(cfg.pretrained_model_path)

View File

@@ -46,6 +46,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
raise NotImplementedError() raise NotImplementedError()
assert torch.cuda.is_available() assert torch.cuda.is_available()
torch.backends.cudnn.benchmark = True
set_seed(cfg.seed) set_seed(cfg.seed)
print(colored("Work dir:", "yellow", attrs=["bold"]), out_dir) print(colored("Work dir:", "yellow", attrs=["bold"]), out_dir)
@@ -55,9 +56,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
# TODO(rcadene): hack for old pretrained models from fowm # TODO(rcadene): hack for old pretrained models from fowm
if "fowm" in cfg.pretrained_model_path: if "fowm" in cfg.pretrained_model_path:
if "offline" in cfg.pretrained_model_path: if "offline" in cfg.pretrained_model_path:
policy.step = 25000 policy.step[0] = 25000
elif "final" in cfg.pretrained_model_path: elif "final" in cfg.pretrained_model_path:
policy.step = 100000 policy.step[0] = 100000
else: else:
raise NotImplementedError() raise NotImplementedError()
policy.load(cfg.pretrained_model_path) policy.load(cfg.pretrained_model_path)