forked from tangger/lerobot
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
2abbd60a0d
commit
0ea27704f6
@@ -51,18 +51,12 @@ def main():
|
||||
# - dataset stats: for normalization and denormalization of input/outputs
|
||||
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
output_features = {
|
||||
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
|
||||
}
|
||||
input_features = {
|
||||
key: ft for key, ft in features.items() if key not in output_features
|
||||
}
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
# Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
|
||||
# we'll just use the defaults and so no arguments other than input/output features need to be passed.
|
||||
cfg = DiffusionConfig(
|
||||
input_features=input_features, output_features=output_features
|
||||
)
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
|
||||
# We can now instantiate our policy with this config and the dataset stats.
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
|
||||
@@ -72,12 +66,8 @@ def main():
|
||||
# Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames
|
||||
# which can differ for inputs, outputs and rewards (if there are some).
|
||||
delta_timestamps = {
|
||||
"observation.image": [
|
||||
i / dataset_metadata.fps for i in cfg.observation_delta_indices
|
||||
],
|
||||
"observation.state": [
|
||||
i / dataset_metadata.fps for i in cfg.observation_delta_indices
|
||||
],
|
||||
"observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
|
||||
"observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
|
||||
"action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
|
||||
}
|
||||
|
||||
@@ -129,10 +119,7 @@ def main():
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = {
|
||||
k: (v.to(device) if isinstance(v, torch.Tensor) else v)
|
||||
for k, v in batch.items()
|
||||
}
|
||||
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
Reference in New Issue
Block a user