Add multithreading for video generation, Speed policy sampling

This commit is contained in:
Cadene
2024-02-24 18:18:39 +00:00
parent 591985c67d
commit aed02dc7c6
4 changed files with 59 additions and 6 deletions

View File

@@ -51,6 +51,11 @@ class TOLD(nn.Module):
"""Predicts next latent state (d) and single-step reward (R)."""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x), self._reward(x)
def next_dynamics(self, z, a):
"""Predicts next latent state (d)."""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x)
def pi(self, z, std=0):
"""Samples an action from the learned policy (pi)."""
@@ -191,7 +196,7 @@ class TDMPC(nn.Module):
_z = z.repeat(num_pi_trajs, 1)
for t in range(horizon):
pi_actions[t] = self.model.pi(_z, self.cfg.min_std)
_z, _ = self.model.next(_z, pi_actions[t])
_z = self.model.next_dynamics(_z, pi_actions[t])
# Initialize state and parameters
z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
@@ -241,6 +246,11 @@ class TDMPC(nn.Module):
mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std
# Outputs
# TODO(rcadene): remove numpy with
# # Convert score tensor to probabilities using softmax
# probabilities = torch.softmax(score, dim=0)
# # Generate a random sample index based on the probabilities
# sample_index = torch.multinomial(probabilities, 1).item()
score = score.squeeze(1).cpu().numpy()
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
self._prev_mean = mean