Fix done in pusht, Fix --time in sbatch
This commit is contained in:
@@ -88,10 +88,6 @@ def add_tee(
|
||||
|
||||
class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
# available_datasets = [
|
||||
# "xarm_lift_medium",
|
||||
# ]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id,
|
||||
@@ -233,6 +229,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
block_angle = state[:, 4]
|
||||
|
||||
reward = torch.zeros(num_frames, 1)
|
||||
success = torch.zeros(num_frames, 1, dtype=torch.bool)
|
||||
done = torch.zeros(num_frames, 1, dtype=torch.bool)
|
||||
for i in range(num_frames):
|
||||
space = pymunk.Space()
|
||||
@@ -257,7 +254,10 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
goal_area = goal_geom.area
|
||||
coverage = intersection_area / goal_area
|
||||
reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
|
||||
done[i] = coverage > SUCCESS_THRESHOLD
|
||||
success[i] = coverage > SUCCESS_THRESHOLD
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
episode = TensorDict(
|
||||
{
|
||||
@@ -271,6 +271,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
# TODO: verify that reward and done are aligned with image and agent_pos
|
||||
("next", "reward"): reward[1:],
|
||||
("next", "done"): done[1:],
|
||||
("next", "success"): success[1:],
|
||||
},
|
||||
batch_size=num_frames - 1,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user