multi-node openpi commit
This commit is contained in:
0
policy/openpi-InternData-A1/scripts/__init__.py
Normal file
0
policy/openpi-InternData-A1/scripts/__init__.py
Normal file
218
policy/openpi-InternData-A1/scripts/compute_norm_stats_real.py
Normal file
218
policy/openpi-InternData-A1/scripts/compute_norm_stats_real.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Compute normalization statistics for real-world tasks.
|
||||
|
||||
This script is used to compute the normalization statistics for a given real-world task. It
|
||||
will compute the mean and standard deviation of the data in the dataset and save it
|
||||
to the config directory.
|
||||
"""
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import tyro
|
||||
import json
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.normalize as normalize
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.mixture_dataset as _mixture_dataset
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.transforms as transforms
|
||||
|
||||
### training config ###
|
||||
import openpi.training.weight_loaders as weight_loaders
|
||||
import openpi.models.pi0_config as pi0_config
|
||||
from openpi.training.config import MultiLeRobotReala2dDataConfig, MultiLeRobotRealArxLift2DataConfig, MultiDataConfig, DataConfig, TrainConfig
|
||||
|
||||
from pdb import set_trace
|
||||
|
||||
class RemoveStrings(transforms.DataTransformFn):
    """Filter out every entry whose value has a string dtype.

    Normalization statistics are numeric, so string-valued fields (e.g. the
    task prompt) must be dropped before batches reach the running stats.
    """

    def __call__(self, x: dict) -> dict:
        filtered = {}
        for key, value in x.items():
            if not np.issubdtype(np.asarray(value).dtype, np.str_):
                filtered[key] = value
        return filtered
|
||||
|
||||
|
||||
def create_torch_dataloader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    model_config: _model.BaseModelConfig,
    num_workers: int,
    max_frames: int | None = None,
) -> tuple[_data_loader.Dataset, int]:
    """Build a torch data loader over the mixture dataset for stats computation.

    Returns the loader together with the number of batches it will yield.
    When ``max_frames`` caps the dataset, batches are drawn shuffled so the
    sample stays representative; otherwise the full dataset is scanned in order.
    """
    base = _mixture_dataset.create_mixture_dataset_calculate_norm_stats(data_config, action_horizon, model_config)
    transformed = _mixture_dataset.TransformedDataset(
        base,
        [
            *data_config[0].repack_transforms.inputs,
            *data_config[0].data_transforms.inputs,
            RemoveStrings(),
        ],
    )
    capped = max_frames is not None and max_frames < len(transformed)
    frame_budget = max_frames if capped else len(transformed)
    num_batches = frame_budget // batch_size
    loader = _data_loader.TorchDataLoader(
        transformed,
        local_batch_size=batch_size,
        num_workers=num_workers,
        shuffle=capped,
        num_batches=num_batches,
    )
    return loader, num_batches
|
||||
|
||||
|
||||
def main(dataset_path, robot_name, task_name, save_path):
    """Compute state/action normalization statistics for one real-world task.

    Builds the robot-specific training config, streams the dataset once, and
    writes the stats under ``save_path/robot_name/task_name``.

    Args:
        dataset_path: LeRobot repo directory (or glob) with the task episodes.
        robot_name: Robot identifier; selects which data config to build.
        task_name: Task identifier; only used to build the output path.
        save_path: Root directory for the saved statistics.

    Returns:
        True when a readable ``norm_stats.json`` already exists (computation
        is skipped); otherwise None after writing fresh stats.

    Raises:
        ValueError: If ``robot_name`` is not one of the supported robots.
    """
    if robot_name in ("lift2", "split_aloha", "acone"):
        # Dual-arm ARX Lift2-style robots all share one data config.
        config = TrainConfig(
            name="lift2",
            model=pi0_config.Pi0Config(),
            data=[
                MultiLeRobotRealArxLift2DataConfig(
                    repo_dir=dataset_path,
                    task_id=None,
                    use_gripper_aug=False,
                    stats_dir='',
                    base_config=MultiDataConfig(
                        prompt_from_task=True,
                    ),
                    asset_id=robot_name,
                    robot_name=robot_name,
                    repack_transforms=transforms.Group(
                        inputs=[
                            transforms.RepackTransform(
                                {
                                    "state_dict": {
                                        "left_joint": "states.left_joint.position",
                                        "right_joint": "states.right_joint.position",
                                        "left_gripper": "states.left_gripper.position",
                                        "right_gripper": "states.right_gripper.position",
                                    },
                                    "action_dict": {
                                        "left_joint": "actions.left_joint.position",
                                        "right_joint": "actions.right_joint.position",
                                        "left_gripper": "actions.left_gripper.position",
                                        "right_gripper": "actions.right_gripper.position",
                                    },
                                    "prompt": "task",
                                }
                            )
                        ]
                    ),
                ),
            ],
            # pretrain model path
            weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
            pytorch_weight_path="checkpoints/pytorch/pi0_base",
            num_train_steps=30_000,
            num_workers=4,
            fsdp_devices=4,
            batch_size=8,
        )
    elif robot_name == "genie1":
        config = TrainConfig(
            name="genie1",
            model=pi0_config.Pi0Config(),
            data=[
                MultiLeRobotReala2dDataConfig(
                    repo_dir=dataset_path,
                    task_id=None,
                    use_gripper_aug=False,
                    stats_dir='',
                    base_config=MultiDataConfig(
                        prompt_from_task=True,
                    ),
                    asset_id=robot_name,
                    robot_name=robot_name,
                    repack_transforms=transforms.Group(
                        inputs=[
                            transforms.RepackTransform(
                                {
                                    "state_dict": {
                                        "joint": "observation.states.joint.position",
                                        "gripper": "observation.states.effector.position",
                                    },
                                    "action_dict": {
                                        "joint": "actions.joint.position",
                                        "gripper": "actions.effector.position",
                                    },
                                    "prompt": "task",
                                }
                            )
                        ]
                    ),
                ),
            ],
            # pretrain model path
            weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
            pytorch_weight_path="checkpoints/pytorch/pi0_base",
            num_train_steps=30_000,
            num_workers=4,
            fsdp_devices=4,
            batch_size=8,
        )
    else:
        # BUG FIX: previously an unknown robot fell through and crashed later
        # with UnboundLocalError on `config`; fail fast with a clear message.
        raise ValueError(f"Unsupported robot_name: {robot_name!r}")

    data_config = config.data[0].create(config.model)
    print("done")
    output_path = os.path.join(save_path, robot_name, task_name)
    stats_json_path = os.path.join(output_path, "norm_stats.json")
    if os.path.isfile(stats_json_path):
        # Skip recomputation when a parseable stats file already exists
        # (json.load doubles as a validity check).
        with open(stats_json_path, 'r', encoding='utf-8') as f:
            json.load(f)
        return True

    data_loader, num_batches = create_torch_dataloader(
        data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames=None
    )

    keys = ["state", "actions"]
    stats = {key: normalize.RunningStats() for key in keys}

    step_id = 0
    for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
        step_id += 1
        for key in keys:
            stats[key].update(np.asarray(batch[key]))

    # Renamed the loop variable so `stats` is no longer shadowed inside
    # its own comprehension.
    norm_stats = {key: running.get_statistics() for key, running in stats.items()}

    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats)
|
||||
|
||||
def check_lerobot_repo(repo_dir: str):
    """Return True when ``repo_dir`` looks like a LeRobot repo.

    A valid repo contains the three standard subdirectories:
    ``data``, ``meta``, and ``videos``. The result is also printed.
    """
    required = ("data", "meta", "videos")
    is_repo = all(os.path.isdir(os.path.join(repo_dir, sub)) for sub in required)
    print(repo_dir, "true" if is_repo else "false")
    return is_repo
|
||||
|
||||
if __name__ == "__main__":
    import argparse
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument("--task_path", type=str, default="data/InternData-A1/real/genie1/Put_the_pen_from_the_table_into_the_pen_holder/*")
    parser.add_argument("--robot_name", type=str, default="genie1")
    parser.add_argument("--save_path", type=str, default="stats/real")

    args, unknown = parser.parse_known_args()
    dataset_path = args.task_path
    save_path = args.save_path
    # The robot name is located inside the path; the path component right
    # after it is taken as the task name.
    parts = dataset_path.split("/")
    robot_idx = next((i for i, p in enumerate(parts) if p == args.robot_name), None)
    if robot_idx is None:
        raise ValueError(
            f"Cannot find robot name in path. Expected {args.robot_name}, "
            f"but got path: {dataset_path}"
        )

    if robot_idx + 1 >= len(parts):
        # BUG FIX: this message referenced an undefined `local_path`,
        # which raised NameError instead of the intended ValueError.
        raise ValueError(
            f"Path ends at robot name '{parts[robot_idx]}', cannot determine task_name: {dataset_path}"
        )
    robot_name = parts[robot_idx]
    task_name = parts[robot_idx + 1]
    try:
        main(dataset_path, robot_name, task_name, save_path)
    except Exception:
        # BUG FIX: a bare `except:` also swallowed KeyboardInterrupt/SystemExit
        # and hid the failure cause; log the traceback before reporting the path.
        traceback.print_exc()
        print(dataset_path)
|
||||
314
policy/openpi-InternData-A1/scripts/compute_norm_stats_sim.py
Normal file
314
policy/openpi-InternData-A1/scripts/compute_norm_stats_sim.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""Compute normalization statistics for interndata-a1 sim tasks.
|
||||
|
||||
This script is used to compute the normalization statistics for interndata-a1 sim tasks. It
|
||||
will compute the mean and standard deviation of the data in the dataset and save it
|
||||
to the config assets directory.
|
||||
"""
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import tyro
|
||||
import json
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.normalize as normalize
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.mixture_dataset as _mixture_dataset
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.transforms as transforms
|
||||
|
||||
### training config ###
|
||||
import openpi.training.weight_loaders as weight_loaders
|
||||
import openpi.models.pi0_config as pi0_config
|
||||
from openpi.training.config import MultiSimGenieDataConfig, MultiSimSplitAlohaDataConfig, MultiSimFrankaDataConfig, MultiDataConfig, DataConfig, TrainConfig
|
||||
|
||||
from pdb import set_trace
|
||||
|
||||
class RemoveStrings(transforms.DataTransformFn):
    """Filter out every entry whose value has a string dtype.

    Normalization statistics are numeric, so string-valued fields (e.g. the
    task prompt) must be dropped before batches reach the running stats.
    """

    def __call__(self, x: dict) -> dict:
        filtered = {}
        for key, value in x.items():
            if not np.issubdtype(np.asarray(value).dtype, np.str_):
                filtered[key] = value
        return filtered
|
||||
|
||||
|
||||
def create_torch_dataloader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    model_config: _model.BaseModelConfig,
    num_workers: int,
    max_frames: int | None = None,
) -> tuple[_data_loader.Dataset, int]:
    """Build a torch data loader over the mixture dataset for stats computation.

    Returns the loader together with the number of batches it will yield.
    When ``max_frames`` caps the dataset, batches are drawn shuffled so the
    sample stays representative; otherwise the full dataset is scanned in order.
    """
    base = _mixture_dataset.create_mixture_dataset_calculate_norm_stats(data_config, action_horizon, model_config)
    transformed = _mixture_dataset.TransformedDataset(
        base,
        [
            *data_config[0].repack_transforms.inputs,
            *data_config[0].data_transforms.inputs,
            RemoveStrings(),
        ],
    )
    capped = max_frames is not None and max_frames < len(transformed)
    frame_budget = max_frames if capped else len(transformed)
    num_batches = frame_budget // batch_size
    loader = _data_loader.TorchDataLoader(
        transformed,
        local_batch_size=batch_size,
        num_workers=num_workers,
        shuffle=capped,
        num_batches=num_batches,
    )
    return loader, num_batches
|
||||
|
||||
|
||||
def main(dataset_path, task_category, robot_name, task_name, collect_name, save_path):
    """Compute state/action normalization statistics for one sim task.

    Builds the robot-specific training config, streams the dataset (capped at
    ~10k batches), and writes the stats under
    ``save_path/task_category/robot_name/task_name/collect_name``.

    Args:
        dataset_path: LeRobot repo directory with the task episodes.
        task_category: Task category; only used to build the output path.
        robot_name: Robot identifier; selects which data config to build.
        task_name: Task identifier; only used to build the output path.
        collect_name: Optional collection subfolder ("" when absent).
        save_path: Root directory for the saved statistics.

    Returns:
        True when a readable ``norm_stats.json`` already exists (computation
        is skipped); otherwise None after writing fresh stats.

    Raises:
        ValueError: If ``robot_name`` is not one of the supported robots.
    """
    if robot_name in ("lift2", "split_aloha"):
        config = TrainConfig(
            name="lift2",
            model=pi0_config.Pi0Config(),
            data=[
                MultiSimSplitAlohaDataConfig(
                    repo_dir=dataset_path,
                    task_id=None,
                    use_gripper_aug=True,
                    gripper_aug_config={
                        "gripper_action_keys": ["master_actions.left_gripper.openness", "master_actions.right_gripper.openness"],
                        "gripper_dim": -1,
                        "gripper_threshold_method": "std_multiplier",
                        "gripper_threshold_multiplier": 1.0,
                        "gripper_min_threshold": 0.001,
                        "gripper_max_threshold": 1.0,
                    },
                    stats_dir='',
                    base_config=MultiDataConfig(
                        prompt_from_task=True,
                    ),
                    asset_id=robot_name,
                    robot_name=robot_name,
                    repack_transforms=transforms.Group(
                        inputs=[
                            transforms.RepackTransform(
                                {
                                    "state_dict": {
                                        "left_joint": "states.left_joint.position",
                                        "right_joint": "states.right_joint.position",
                                        "left_gripper": "states.left_gripper.position",
                                        "right_gripper": "states.right_gripper.position",
                                    },
                                    "action_dict": {
                                        "left_joint": "actions.left_joint.position",
                                        "right_joint": "actions.right_joint.position",
                                        "left_gripper": "actions.left_gripper.position",
                                        "right_gripper": "actions.right_gripper.position",
                                        "left_gripper_openness": "master_actions.left_gripper.openness",
                                        "right_gripper_openness": "master_actions.right_gripper.openness",
                                    },
                                    "prompt": "task",
                                }
                            )
                        ]
                    ),
                ),
            ],
            weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
            pytorch_weight_path="checkpoints/pytorch/pi0_base",
            num_train_steps=30_000,
            num_workers=4,
            fsdp_devices=4,
            batch_size=8,
        )
    elif robot_name == "genie1":
        config = TrainConfig(
            name="genie1",
            model=pi0_config.Pi0Config(),
            data=[
                MultiSimGenieDataConfig(
                    repo_dir=dataset_path,
                    task_id=None,
                    use_gripper_aug=True,
                    gripper_aug_config={
                        "gripper_action_keys": ["master_actions.left_gripper.openness", "master_actions.right_gripper.openness"],
                        "gripper_dim": -1,
                        "gripper_threshold_method": "std_multiplier",
                        "gripper_threshold_multiplier": 1.0,
                        "gripper_min_threshold": 0.001,
                        "gripper_max_threshold": 1.0,
                    },
                    stats_dir='',
                    base_config=MultiDataConfig(
                        prompt_from_task=True,
                    ),
                    asset_id=robot_name,
                    robot_name=robot_name,
                    repack_transforms=transforms.Group(
                        inputs=[
                            transforms.RepackTransform(
                                {
                                    "state_dict": {
                                        "left_joint": "states.left_joint.position",
                                        "right_joint": "states.right_joint.position",
                                        "left_gripper": "states.left_gripper.position",
                                        "right_gripper": "states.right_gripper.position",
                                    },
                                    "action_dict": {
                                        "left_joint": "actions.left_joint.position",
                                        "right_joint": "actions.right_joint.position",
                                        "left_gripper": "actions.left_gripper.position",
                                        "right_gripper": "actions.right_gripper.position",
                                        "left_gripper_openness": "master_actions.left_gripper.openness",
                                        "right_gripper_openness": "master_actions.right_gripper.openness",
                                    },
                                    "prompt": "task",
                                }
                            )
                        ]
                    ),
                ),
            ],
            weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
            pytorch_weight_path="checkpoints/pytorch/pi0_base",
            num_train_steps=30_000,
            num_workers=4,
            fsdp_devices=4,
            batch_size=8,
        )
    elif "franka" in robot_name:
        config = TrainConfig(
            name="franka",
            model=pi0_config.Pi0Config(),
            data=[
                MultiSimFrankaDataConfig(
                    repo_dir=dataset_path,
                    task_id=None,
                    use_gripper_aug=True,
                    gripper_aug_config={
                        "gripper_action_keys": ["actions.gripper.openness"],
                        "gripper_dim": -1,
                        "gripper_threshold_method": "std_multiplier",
                        "gripper_threshold_multiplier": 1.0,
                        "gripper_min_threshold": 0.001,
                        "gripper_max_threshold": 1.0,
                    },
                    stats_dir='',
                    base_config=MultiDataConfig(
                        prompt_from_task=True,
                    ),
                    asset_id=robot_name,
                    robot_name=robot_name,
                    repack_transforms=transforms.Group(
                        inputs=[
                            transforms.RepackTransform(
                                {
                                    "state_dict": {
                                        "joint_position": "states.joint.position",
                                        "gripper_pose": "states.gripper.pose",
                                        "gripper_position": "states.gripper.position",
                                    },
                                    "action_dict": {
                                        "gripper_pose": "actions.gripper.pose",
                                        "gripper_position": "actions.gripper.position",
                                        "gripper_openness": "actions.gripper.openness",
                                    },
                                    "prompt": "task",
                                }
                            )
                        ]
                    ),
                ),
            ],
            weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
            pytorch_weight_path="checkpoints/pytorch/pi0_base",
            num_train_steps=30_000,
            num_workers=4,
            fsdp_devices=4,
            batch_size=8,
        )
    else:
        # BUG FIX: previously an unknown robot fell through and crashed later
        # with UnboundLocalError on `config`; fail fast with a clear message.
        raise ValueError(f"Unsupported robot_name: {robot_name!r}")

    data_config = config.data[0].create(config.model)
    print("done")
    output_path = os.path.join(save_path, task_category, robot_name, task_name, collect_name)
    stats_json_path = os.path.join(output_path, "norm_stats.json")
    if os.path.isfile(stats_json_path):
        # Skip recomputation when a parseable stats file already exists
        # (json.load doubles as a validity check).
        with open(stats_json_path, 'r', encoding='utf-8') as f:
            json.load(f)
        return True

    data_loader, num_batches = create_torch_dataloader(
        data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames=None
    )

    keys = ["state", "actions"]
    stats = {key: normalize.RunningStats() for key in keys}

    step_id = 0
    for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
        step_id += 1
        for key in keys:
            stats[key].update(np.asarray(batch[key]))
        # Cap the scan: stats converge well before 10k batches.
        if step_id > 10000:
            break

    # Renamed the loop variable so `stats` is no longer shadowed inside
    # its own comprehension.
    norm_stats = {key: running.get_statistics() for key, running in stats.items()}

    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats)
|
||||
|
||||
def check_lerobot_repo(repo_dir: str):
    """Return True when ``repo_dir`` looks like a LeRobot repo.

    A valid repo contains the three standard subdirectories:
    ``data``, ``meta``, and ``videos``. The result is also printed.
    """
    required = ("data", "meta", "videos")
    is_repo = all(os.path.isdir(os.path.join(repo_dir, sub)) for sub in required)
    print(repo_dir, "true" if is_repo else "false")
    return is_repo
|
||||
|
||||
if __name__ == "__main__":
    import argparse
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument("--root_data_dir", type=str, default="data/InternData-A1/sim")
    parser.add_argument("--task_category", type=str, default="pick_and_place_tasks")
    parser.add_argument("--save_path", type=str, default="stats/sim")
    # start/end ratios let several workers split the repo list between them.
    parser.add_argument("--start_ratio", type=float, default=0.0)
    parser.add_argument("--end_ratio", type=float, default=1)
    args, unknown = parser.parse_known_args()
    root_data_dir = os.path.join(args.root_data_dir, args.task_category)

    # Repos laid out as <category>/<robot>/<task>.
    dataset_paths = sorted(glob.glob(os.path.join(root_data_dir, "*", "*")))
    valid_paths = [p for p in dataset_paths if check_lerobot_repo(p)]

    start_idx = int(len(valid_paths) * args.start_ratio)
    end_idx = int(len(valid_paths) * args.end_ratio) + 1
    valid_paths = valid_paths[start_idx:end_idx]
    for dataset_path in tqdm.tqdm(valid_paths):
        task_category = dataset_path.split('/')[-3]
        robot_name = dataset_path.split('/')[-2]
        task_name = dataset_path.split('/')[-1]
        collect_name = ""
        try:
            main(dataset_path, task_category, robot_name, task_name, collect_name, args.save_path)
        except Exception:
            # BUG FIX: a bare `except:` also swallowed KeyboardInterrupt and
            # hid the failure cause; log the traceback before reporting the path.
            traceback.print_exc()
            print(dataset_path)

    # Repos with an extra collection level: <category>/<robot>/<task>/<collect>.
    dataset_paths_w_subtask = sorted(glob.glob(os.path.join(root_data_dir, "*", "*", "*")))
    valid_paths_w_subtask = [p for p in dataset_paths_w_subtask if check_lerobot_repo(p)]
    start_idx = int(len(valid_paths_w_subtask) * args.start_ratio)
    end_idx = int(len(valid_paths_w_subtask) * args.end_ratio) + 1
    valid_paths_w_subtask = valid_paths_w_subtask[start_idx:end_idx]
    for dataset_path in tqdm.tqdm(valid_paths_w_subtask):
        task_category = dataset_path.split('/')[-4]
        robot_name = dataset_path.split('/')[-3]
        task_name = dataset_path.split('/')[-2]
        collect_name = dataset_path.split('/')[-1]
        try:
            main(dataset_path, task_category, robot_name, task_name, collect_name, args.save_path)
        except Exception:
            # BUG FIX: same bare-except problem as above.
            traceback.print_exc()
            print(dataset_path)
|
||||
@@ -0,0 +1,181 @@
|
||||
"""Compute normalization statistics for real-world tasks.
|
||||
|
||||
This script is used to compute the normalization statistics for a given real-world task. It
|
||||
will compute the mean and standard deviation of the data in the dataset and save it
|
||||
to the config directory.
|
||||
"""
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import tyro
|
||||
import json
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.normalize as normalize
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.mixture_dataset as _mixture_dataset
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.transforms as transforms
|
||||
|
||||
### training config ###
|
||||
import openpi.training.weight_loaders as weight_loaders
|
||||
import openpi.models.pi0_config as pi0_config
|
||||
from openpi.training.config import MultiSim2RealSplitAlohaDataConfig, MultiDataConfig, DataConfig, TrainConfig
|
||||
|
||||
from pdb import set_trace
|
||||
|
||||
class RemoveStrings(transforms.DataTransformFn):
    """Filter out every entry whose value has a string dtype.

    Normalization statistics are numeric, so string-valued fields (e.g. the
    task prompt) must be dropped before batches reach the running stats.
    """

    def __call__(self, x: dict) -> dict:
        filtered = {}
        for key, value in x.items():
            if not np.issubdtype(np.asarray(value).dtype, np.str_):
                filtered[key] = value
        return filtered
|
||||
|
||||
|
||||
def create_torch_dataloader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    model_config: _model.BaseModelConfig,
    num_workers: int,
    max_frames: int | None = None,
) -> tuple[_data_loader.Dataset, int]:
    """Build a torch data loader over the mixture dataset for stats computation.

    Returns the loader together with the number of batches it will yield.
    When ``max_frames`` caps the dataset, batches are drawn shuffled so the
    sample stays representative; otherwise the full dataset is scanned in order.
    """
    base = _mixture_dataset.create_mixture_dataset_calculate_norm_stats(data_config, action_horizon, model_config)
    transformed = _mixture_dataset.TransformedDataset(
        base,
        [
            *data_config[0].repack_transforms.inputs,
            *data_config[0].data_transforms.inputs,
            RemoveStrings(),
        ],
    )
    capped = max_frames is not None and max_frames < len(transformed)
    frame_budget = max_frames if capped else len(transformed)
    num_batches = frame_budget // batch_size
    loader = _data_loader.TorchDataLoader(
        transformed,
        local_batch_size=batch_size,
        num_workers=num_workers,
        shuffle=capped,
        num_batches=num_batches,
    )
    return loader, num_batches
|
||||
|
||||
|
||||
def main(dataset_path, robot_name, task_name, save_path):
    """Compute state/action normalization statistics for one sim2real task.

    Builds the training config (only the Lift2 split-ALOHA robot is supported
    here), streams the dataset (capped at ~10k batches), and writes the stats
    under ``save_path/robot_name/task_name``.

    Args:
        dataset_path: LeRobot repo directory (or glob) with the task episodes.
        robot_name: Robot identifier; must be "lift2".
        task_name: Task identifier; only used to build the output path.
        save_path: Root directory for the saved statistics.

    Returns:
        True when a readable ``norm_stats.json`` already exists (computation
        is skipped); otherwise None after writing fresh stats.

    Raises:
        ValueError: If ``robot_name`` is not supported.
    """
    if robot_name == "lift2":
        config = TrainConfig(
            name="lift2",
            model=pi0_config.Pi0Config(),
            data=[
                MultiSim2RealSplitAlohaDataConfig(
                    repo_dir=dataset_path,
                    task_id=None,
                    use_gripper_aug=False,
                    stats_dir='',
                    base_config=MultiDataConfig(
                        prompt_from_task=True,
                    ),
                    asset_id=robot_name,
                    robot_name=robot_name,
                    repack_transforms=transforms.Group(
                        inputs=[
                            transforms.RepackTransform(
                                {
                                    "state_dict": {
                                        "left_joint": "states.left_joint.position",
                                        "right_joint": "states.right_joint.position",
                                        "left_gripper": "states.left_gripper.position",
                                        "right_gripper": "states.right_gripper.position",
                                    },
                                    "action_dict": {
                                        "left_joint": "actions.left_joint.position",
                                        "right_joint": "actions.right_joint.position",
                                        "left_gripper": "actions.left_gripper.position",
                                        "right_gripper": "actions.right_gripper.position",
                                        "left_gripper_openness": "master_actions.left_gripper.openness",
                                        "right_gripper_openness": "master_actions.right_gripper.openness",
                                    },
                                    "prompt": "task",
                                }
                            )
                        ]
                    ),
                ),
            ],
            # pretrain model path
            weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
            pytorch_weight_path="checkpoints/pytorch/pi0_base",
            num_train_steps=30_000,
            num_workers=4,
            fsdp_devices=4,
            batch_size=8,
        )
    else:
        # BUG FIX: previously any other robot fell through and crashed later
        # with UnboundLocalError on `config`; fail fast with a clear message.
        raise ValueError(f"Unsupported robot_name: {robot_name!r}")

    data_config = config.data[0].create(config.model)
    print("done")
    output_path = os.path.join(save_path, robot_name, task_name)
    stats_json_path = os.path.join(output_path, "norm_stats.json")
    if os.path.isfile(stats_json_path):
        # Skip recomputation when a parseable stats file already exists
        # (json.load doubles as a validity check).
        with open(stats_json_path, 'r', encoding='utf-8') as f:
            json.load(f)
        return True

    data_loader, num_batches = create_torch_dataloader(
        data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames=None
    )

    keys = ["state", "actions"]
    stats = {key: normalize.RunningStats() for key in keys}

    step_id = 0
    for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
        step_id += 1
        for key in keys:
            stats[key].update(np.asarray(batch[key]))
        # Cap the scan: stats converge well before 10k batches.
        if step_id > 10000:
            break

    # Renamed the loop variable so `stats` is no longer shadowed inside
    # its own comprehension.
    norm_stats = {key: running.get_statistics() for key, running in stats.items()}

    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats)
|
||||
|
||||
def check_lerobot_repo(repo_dir: str):
    """Return True when ``repo_dir`` looks like a LeRobot repo.

    A valid repo contains the three standard subdirectories:
    ``data``, ``meta``, and ``videos``. The result is also printed.
    """
    required = ("data", "meta", "videos")
    is_repo = all(os.path.isdir(os.path.join(repo_dir, sub)) for sub in required)
    print(repo_dir, "true" if is_repo else "false")
    return is_repo
|
||||
|
||||
if __name__ == "__main__":
    import argparse
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument("--task_path", type=str, default="data/InternData-A1/sim/long_horizon_tasks/lift2/sort_the_rubbish/*")
    parser.add_argument("--robot_name", type=str, default="lift2")
    parser.add_argument("--save_path", type=str, default="stats/sim2real")

    args, unknown = parser.parse_known_args()
    dataset_path = args.task_path
    save_path = args.save_path
    # The robot name is located inside the path; the path component right
    # after it is taken as the task name.
    parts = dataset_path.split("/")
    robot_idx = next((i for i, p in enumerate(parts) if p == args.robot_name), None)
    if robot_idx is None:
        raise ValueError(
            f"Cannot find robot name in path. Expected {args.robot_name}, "
            f"but got path: {dataset_path}"
        )

    if robot_idx + 1 >= len(parts):
        # BUG FIX: this message referenced an undefined `local_path`,
        # which raised NameError instead of the intended ValueError.
        raise ValueError(
            f"Path ends at robot name '{parts[robot_idx]}', cannot determine task_name: {dataset_path}"
        )
    robot_name = parts[robot_idx]
    task_name = parts[robot_idx + 1]
    try:
        main(dataset_path, robot_name, task_name, save_path)
    except Exception:
        # BUG FIX: a bare `except:` also swallowed KeyboardInterrupt/SystemExit
        # and hid the failure cause; log the traceback before reporting the path.
        traceback.print_exc()
        print(dataset_path)
|
||||
|
||||
29
policy/openpi-InternData-A1/scripts/docker/compose.yml
Normal file
29
policy/openpi-InternData-A1/scripts/docker/compose.yml
Normal file
@@ -0,0 +1,29 @@
|
||||
# Run with:
# docker compose -f scripts/docker/compose.yml up --build
services:
  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    # Host networking so the websocket policy server is reachable directly.
    network_mode: host
    # Populate configured openpi data home to /openpi_assets inside the container.
    # Populate aws credential inside the container.
    # NOTE(review): no AWS credential volume is mounted below — this comment
    # may be stale; confirm before relying on it.
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
|
||||
37
policy/openpi-InternData-A1/scripts/docker/install_docker_ubuntu22.sh
Executable file
37
policy/openpi-InternData-A1/scripts/docker/install_docker_ubuntu22.sh
Executable file
@@ -0,0 +1,37 @@
|
||||
#!/bin/bash

# Installs Docker Engine on Ubuntu 22.04 following the official instructions,
# then performs the standard post-install steps (docker group, boot services).

# Add Docker's official GPG key:
sudo apt-get update
sudo apt-get install -y ca-certificates curl
sudo install -m 0755 -d /etc/apt/keyrings
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
sudo chmod a+r /etc/apt/keyrings/docker.asc

# Add the repository to Apt sources:
echo \
    "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
    $(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
    sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
sudo apt-get update

sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin

# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
# See https://docs.docker.com/engine/install/linux-postinstall/
username=$(whoami)
sudo usermod -aG docker "$username"

# Configure docker to start automatically on system boot.
sudo systemctl enable docker.service
sudo systemctl enable containerd.service

# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
# BUG FIX: `[ ~/.docker/config.json ]` tests a non-empty string and is
# therefore always true; use -f so the sed only runs when the file exists.
if [ -f ~/.docker/config.json ]; then
    sed -i 's/credsStore/credStore/g' ~/.docker/config.json
fi

echo ""
echo "********************************************************************"
echo "**** Restart to allow Docker permission changes to take effect. ****"
echo "********************************************************************"
echo ""
|
||||
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash

# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html

# Register NVIDIA's apt signing key and repository list (the sed rewrites the
# list entries to reference the dearmored keyring).
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list

# NVIDIA's documentation omits 'sudo' in the following command, but it is required.
# (Enables the experimental repository entries by uncommenting them.)
sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt-get update
sudo apt-get install -y nvidia-container-toolkit

# Register the NVIDIA runtime with Docker and reload the daemon.
sudo nvidia-ctk runtime configure --runtime=docker
sudo systemctl restart docker
|
||||
@@ -0,0 +1,38 @@
|
||||
# Dockerfile for serving a PI policy.
# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container

# Build the container:
# docker build . -t openpi_server -f scripts/docker/serve_policy.Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash

# CUDA runtime base image, pinned by digest for reproducible builds.
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
# Bring in the uv/uvx binaries from the official distroless image.
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

WORKDIR /app

# Needed because LeRobot uses git-lfs.
RUN apt-get update && apt-get install -y git git-lfs linux-headers-generic build-essential clang

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Install the project's dependencies using the lockfile and settings.
# Only the lockfile/pyproject files are bind-mounted so this layer is cached
# until dependencies actually change.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=uv.lock,target=uv.lock \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \
    --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \
    GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev

# Copy transformers_replace files while preserving directory structure,
# overlaying them onto the installed `transformers` package.
COPY src/openpi/models_pytorch/transformers_replace/ /tmp/transformers_replace/
RUN /.venv/bin/python -c "import transformers; print(transformers.__file__)" | xargs dirname | xargs -I{} cp -r /tmp/transformers_replace/* {} && rm -rf /tmp/transformers_replace

CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS"
|
||||
27
policy/openpi-InternData-A1/scripts/download_paligemma.py
Normal file
27
policy/openpi-InternData-A1/scripts/download_paligemma.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
def download_from_gcs(gcs_uri: str, local_path: str):
    """Download a single object from GCS to *local_path*.

    Prefers ``gsutil`` when it is on PATH; otherwise falls back to fetching the
    bucket's public HTTPS mirror with ``wget``.

    Args:
        gcs_uri: Source object, e.g. ``gs://bucket/path/file.npz``.
        local_path: Destination file path; parent directories are created.

    Returns:
        The destination as a ``Path``.

    Raises:
        RuntimeError: If the download command exits with a non-zero status.
    """
    dest = Path(local_path)
    dest.parent.mkdir(parents=True, exist_ok=True)

    if os.system("which gsutil > /dev/null 2>&1") == 0:
        cmd = f"gsutil cp {gcs_uri} {dest}"
    else:
        # Public GCS buckets are also reachable over plain HTTPS.
        gcs_http = gcs_uri.replace("gs://", "https://storage.googleapis.com/")
        cmd = f"wget -O {dest} {gcs_http}"

    print(f"⬇️ Executing: {cmd}")
    ret = os.system(cmd)
    if ret != 0:
        # `wget -O` creates the output file before the transfer completes, so a
        # failed run would otherwise leave an empty/truncated file behind that
        # later code could mistake for a valid checkpoint.
        dest.unlink(missing_ok=True)
        raise RuntimeError(f"Download failed: {gcs_uri}")
    print("✅ Download complete:", dest)

    return dest
|
||||
|
||||
if __name__ == "__main__":
    # Default PaliGemma checkpoint expected by openpi's JAX weight loader.
    gcs_uri = "gs://vertex-model-garden-paligemma-us/paligemma/pt_224.npz"
    save_path = "checkpoints/jax/paligemma/pt_224.npz"
    download_from_gcs(gcs_uri, save_path)
||||
122
policy/openpi-InternData-A1/scripts/serve_policy.py
Normal file
122
policy/openpi-InternData-A1/scripts/serve_policy.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import socket
|
||||
|
||||
import tyro
|
||||
|
||||
from openpi.policies import policy as _policy
|
||||
from openpi.policies import policy_config as _policy_config
|
||||
from openpi.serving import websocket_policy_server
|
||||
from openpi.training import config as _config
|
||||
|
||||
|
||||
class EnvMode(enum.Enum):
    """Supported environments."""

    # Values are the CLI-facing names; each member maps to an entry in
    # DEFAULT_CHECKPOINT below.
    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"
|
||||
|
||||
@dataclasses.dataclass
class Checkpoint:
    """Load a policy from a trained checkpoint."""

    # Training config name (e.g., "pi0_aloha_sim").
    config: str
    # Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
    # May also be a remote URI — DEFAULT_CHECKPOINT below uses gs:// paths.
    dir: str
|
||||
|
||||
@dataclasses.dataclass
class Default:
    """Use the default policy for the given environment."""

    # Intentionally empty: this is a marker type selected as an alternative to
    # `Checkpoint` in `Args.policy`.
|
||||
|
||||
@dataclasses.dataclass
class Args:
    """Arguments for the serve_policy script (parsed via tyro)."""

    # Environment to serve the policy for. This is only used when serving default policies.
    env: EnvMode = EnvMode.ALOHA_SIM

    # If provided, will be used in case the "prompt" key is not present in the data,
    # or if the model doesn't have a default prompt.
    default_prompt: str | None = None

    # Port to serve the policy on.
    port: int = 8000
    # Record the policy's behavior for debugging (wraps the policy in PolicyRecorder).
    record: bool = False

    # Specifies how to load the policy. If not provided, the default policy for the environment will be used.
    policy: Checkpoint | Default = dataclasses.field(default_factory=Default)
||||
|
||||
|
||||
# Default checkpoints that should be used for each environment.
# NOTE(review): all `dir` entries are gs:// URIs — presumably downloaded/cached
# by `create_trained_policy`; confirm against openpi.policies.policy_config.
DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
    EnvMode.ALOHA: Checkpoint(
        config="pi05_aloha",
        dir="gs://openpi-assets/checkpoints/pi05_base",
    ),
    EnvMode.ALOHA_SIM: Checkpoint(
        config="pi0_aloha_sim",
        dir="gs://openpi-assets/checkpoints/pi0_aloha_sim",
    ),
    EnvMode.DROID: Checkpoint(
        config="pi05_droid",
        dir="gs://openpi-assets/checkpoints/pi05_droid",
    ),
    EnvMode.LIBERO: Checkpoint(
        config="pi05_libero",
        dir="gs://openpi-assets/checkpoints/pi05_libero",
    ),
}
||||
|
||||
|
||||
def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
    """Build the pre-registered policy for *env*, or raise if none exists."""
    checkpoint = DEFAULT_CHECKPOINT.get(env)
    if checkpoint is None:
        raise ValueError(f"Unsupported environment mode: {env}")
    train_config = _config.get_config(checkpoint.config)
    return _policy_config.create_trained_policy(train_config, checkpoint.dir, default_prompt=default_prompt)
||||
|
||||
|
||||
def create_policy(args: Args) -> _policy.Policy:
    """Build the policy selected by *args*: an explicit checkpoint, or the
    environment's default one."""
    spec = args.policy
    if isinstance(spec, Checkpoint):
        train_config = _config.get_config(spec.config)
        return _policy_config.create_trained_policy(train_config, spec.dir, default_prompt=args.default_prompt)
    if isinstance(spec, Default):
        return create_default_policy(args.env, default_prompt=args.default_prompt)
||||
|
||||
|
||||
def main(args: Args) -> None:
    """Create the policy and serve it over a websocket until interrupted."""
    policy = create_policy(args)
    policy_metadata = policy.metadata

    # Record the policy's behavior (for offline debugging).
    if args.record:
        policy = _policy.PolicyRecorder(policy, "policy_records")

    hostname = socket.gethostname()
    local_ip = socket.gethostbyname(hostname)
    logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)

    # Bind on all interfaces so remote robot clients can connect.
    server = websocket_policy_server.WebsocketPolicyServer(
        policy=policy,
        host="0.0.0.0",
        port=args.port,
        metadata=policy_metadata,
    )
    # Blocks forever.
    server.serve_forever()
||||
|
||||
|
||||
if __name__ == "__main__":
    # `force=True` replaces any handlers already installed by imported libraries.
    logging.basicConfig(level=logging.INFO, force=True)
    main(tyro.cli(Args))
||||
290
policy/openpi-InternData-A1/scripts/train.py
Normal file
290
policy/openpi-InternData-A1/scripts/train.py
Normal file
@@ -0,0 +1,290 @@
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import platform
|
||||
from typing import Any
|
||||
|
||||
import etils.epath as epath
|
||||
import flax.nnx as nnx
|
||||
from flax.training import common_utils
|
||||
import flax.traverse_util as traverse_util
|
||||
import jax
|
||||
import jax.experimental
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
import tqdm_loggable.auto as tqdm
|
||||
import wandb
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.array_typing as at
|
||||
import openpi.shared.nnx_utils as nnx_utils
|
||||
import openpi.training.checkpoints as _checkpoints
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.training.optimizer as _optimizer
|
||||
import openpi.training.sharding as sharding
|
||||
import openpi.training.utils as training_utils
|
||||
import openpi.training.weight_loaders as _weight_loaders
|
||||
from memory_profiler import profile
|
||||
import psutil
|
||||
from openpi.shared.online_compute_norm_stats import compute_norm_stats
|
||||
|
||||
def init_logging():
    """Custom logging format for better readability.

    Installs a compact formatter (single-letter level names) on the root
    logger. Safe to call whether or not a handler is already configured: if the
    root logger has no handler yet, one is added instead of indexing
    ``logger.handlers[0]``, which would raise IndexError. This matches the
    guarded implementation in ``scripts/train_jax_multinode.py``.
    """
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            # Shorten the level name to one letter before delegating.
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        # No handler configured yet (e.g. basicConfig was never called):
        # attach one rather than crashing on logger.handlers[0].
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        logger.addHandler(ch)
    else:
        logger.handlers[0].setFormatter(formatter)
||||
|
||||
|
||||
def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
    """Initialize Weights & Biases logging.

    Args:
        config: Training config; supplies checkpoint_dir, exp_name, project_name.
        resuming: If True, reattach to the run id stored in ``wandb_id.txt``.
        log_code: If True, upload the repository source to the run.
        enabled: If False, initialize wandb in disabled mode and return.

    Raises:
        FileNotFoundError: If the checkpoint directory does not exist.
    """
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
    if resuming:
        # Reuse the run id persisted by the original run so metrics append.
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        # Persist the run id so a later --resume can reattach to this run.
        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)

    if log_code:
        wandb.run.log_code(epath.Path(__file__).parent.parent)
||||
|
||||
|
||||
def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
    """Loads and validates the weights. Returns a loaded subset of the weights."""
    loaded_params = loader.load(params_shape)
    at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)

    # Drop entries that are still jax.ShapeDtypeStruct placeholders, so the
    # result contains only the weights the loader actually provided.
    flat = traverse_util.flatten_dict(loaded_params)
    materialized = {path: leaf for path, leaf in flat.items() if not isinstance(leaf, jax.ShapeDtypeStruct)}
    return traverse_util.unflatten_dict(materialized)
||||
|
||||
|
||||
@at.typecheck
def init_train_state(
    config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
) -> tuple[training_utils.TrainState, Any]:
    """Build the initial TrainState and its FSDP sharding spec.

    Returns a ``(train_state, state_sharding)`` pair. When *resume* is True,
    only the state's shape (from ``jax.eval_shape``) is returned — the caller
    restores the real values from checkpoint — so no weights are loaded here.
    """
    tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)

    def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
        rng, model_rng = jax.random.split(rng)
        # initialize the model (and its parameters).
        model = config.model.create(model_rng)

        # Merge the partial params into the model.
        if partial_params is not None:
            graphdef, state = nnx.split(model)
            # This will produce an error if the partial params are not a subset of the state.
            state.replace_by_pure_dict(partial_params)
            model = nnx.merge(graphdef, state)

        params = nnx.state(model)
        # Convert frozen params to bfloat16 (they receive no gradient updates).
        params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))

        return training_utils.TrainState(
            step=0,
            params=params,
            model_def=nnx.graphdef(model),
            tx=tx,
            # Optimizer state only tracks the trainable subset.
            opt_state=tx.init(params.filter(config.trainable_filter)),
            ema_decay=config.ema_decay,
            ema_params=None if config.ema_decay is None else params,
        )

    # Abstract evaluation: shapes/dtypes only, no device memory allocated.
    train_state_shape = jax.eval_shape(init, init_rng)
    state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)

    if resume:
        return train_state_shape, state_sharding

    partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    # Initialize the train state and mix in the partial params.
    train_state = jax.jit(
        init,
        donate_argnums=(1,),  # donate the partial params buffer.
        in_shardings=replicated_sharding,
        out_shardings=state_sharding,
    )(init_rng, partial_params)

    return train_state, state_sharding
||||
|
||||
|
||||
@at.typecheck
def train_step(
    config: _config.TrainConfig,
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    batch: tuple[_model.Observation, _model.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    """Run one optimizer step.

    Returns the updated TrainState and a dict of scalar metrics
    (loss, grad_norm, param_norm). Intended to be jax.jit-compiled with the
    input state donated.
    """
    model = nnx.merge(state.model_def, state.params)
    model.train()

    @at.typecheck
    def loss_fn(
        model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
    ):
        chunked_loss = model.compute_loss(rng, observation, actions, train=True)
        return jnp.mean(chunked_loss)

    # Fold the step count into the rng so each step sees fresh randomness while
    # the base key stays constant across jitted calls.
    train_rng = jax.random.fold_in(rng, state.step)
    observation, actions = batch

    # Filter out frozen params.
    diff_state = nnx.DiffState(0, config.trainable_filter)
    loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)

    params = state.params.filter(config.trainable_filter)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Update the model in place and return the new full state.
    nnx.update(model, new_params)
    new_params = nnx.state(model)

    new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
    if state.ema_decay is not None:
        # Exponential moving average of the parameters.
        new_state = dataclasses.replace(
            new_state,
            ema_params=jax.tree.map(
                lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
            ),
        )

    # Filter out params that aren't kernels (biases, scales, embeddings, and
    # anything 1-D are excluded from the reported param norm).
    kernel_params = nnx.state(
        model,
        nnx.All(
            nnx.Param,
            nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
            lambda _, x: x.value.ndim > 1,
        ),
    )
    info = {
        "loss": loss,
        "grad_norm": optax.global_norm(grads),
        "param_norm": optax.global_norm(kernel_params),
    }
    return new_state, info
||||
|
||||
|
||||
def main(config: _config.TrainConfig):
    """Single-host training loop: data loading, state init, train steps, checkpoints."""
    init_logging()
    logging.info(f"Running on: {platform.node()}")

    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )

    # Persistent JIT compilation cache to speed up repeated runs.
    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))

    rng = jax.random.key(config.seed)
    train_rng, init_rng = jax.random.split(rng)

    mesh = sharding.make_mesh(config.fsdp_devices)
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
        config.checkpoint_dir,
        keep_period=config.keep_period,
        overwrite=config.overwrite,
        resume=config.resume,
    )
    init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    # Optionally compute normalization statistics online instead of reading
    # precomputed stats from disk.
    if config.online_compute_norm_stats:
        global_norm_stats = compute_norm_stats(config.name)
    else:
        global_norm_stats = None

    data_loader = _data_loader.create_data_loader_multi(
        config,
        sharding=data_sharding,
        shuffle=True,
        global_norm_stats=global_norm_stats,
    )
    data_iter = iter(data_loader)
    batch = next(data_iter)
    logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")
    # NOTE(review): debug leftover — prints resident memory (MiB) to stdout;
    # consider removing or routing through logging.
    print(psutil.Process().memory_info().rss/1024**2)

    # Log images from first batch to sanity check.
    images_to_log = [
        wandb.Image(np.concatenate([np.array(img[i]) for img in batch[0].images.values()], axis=1))
        for i in range(min(5, len(next(iter(batch[0].images.values())))))
    ]
    wandb.log({"camera_views": images_to_log}, step=0)

    train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
    jax.block_until_ready(train_state)
    logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")

    if resuming:
        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)

    ptrain_step = jax.jit(
        functools.partial(train_step, config),
        in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
        out_shardings=(train_state_sharding, replicated_sharding),
        donate_argnums=(1,),  # donate the previous train state's buffers.
    )

    start_step = int(train_state.step)
    pbar = tqdm.tqdm(
        range(start_step, config.num_train_steps),
        initial=start_step,
        total=config.num_train_steps,
        dynamic_ncols=True,
    )

    infos = []
    for step in pbar:
        with sharding.set_mesh(mesh):
            train_state, info = ptrain_step(train_rng, train_state, batch)
        infos.append(info)
        if step % config.log_interval == 0:
            # Average metrics accumulated since the last log point.
            stacked_infos = common_utils.stack_forest(infos)
            reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
            info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
            pbar.write(f"Step {step}: {info_str}")
            wandb.log(reduced_info, step=step)
            infos = []
        batch = next(data_iter)

        if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
            _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)

    logging.info("Waiting for checkpoint manager to finish")
    checkpoint_manager.wait_until_finished()
|
||||
|
||||
if __name__ == "__main__":
    # Parse the training config from the command line and run.
    main(_config.cli())
||||
341
policy/openpi-InternData-A1/scripts/train_jax_multinode.py
Executable file
341
policy/openpi-InternData-A1/scripts/train_jax_multinode.py
Executable file
@@ -0,0 +1,341 @@
|
||||
"""
|
||||
Multi-host training entrypoint (JAX).
|
||||
|
||||
How to run multi-host (example: 2 nodes):
|
||||
# node0
|
||||
export JAX_COORDINATOR_ADDRESS=node0:12345
|
||||
export JAX_PROCESS_COUNT=2
|
||||
export JAX_PROCESS_INDEX=0
|
||||
uv run python scripts/train.py <config_name> --exp_name <exp>
|
||||
|
||||
# node1
|
||||
export JAX_COORDINATOR_ADDRESS=node0:12345
|
||||
export JAX_PROCESS_COUNT=2
|
||||
export JAX_PROCESS_INDEX=1
|
||||
uv run python scripts/train.py <config_name> --exp_name <exp>
|
||||
|
||||
Notes:
|
||||
- Initialize distributed BEFORE any device query.
|
||||
- Only process_index==0 performs side-effects (wandb, checkpoints, progress bar).
|
||||
- Total devices across hosts must be divisible by config.fsdp_devices.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import platform
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import etils.epath as epath
|
||||
import flax.nnx as nnx
|
||||
from flax.training import common_utils
|
||||
import flax.traverse_util as traverse_util
|
||||
import jax
|
||||
import jax.experimental
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
import tqdm_loggable.auto as tqdm
|
||||
import wandb
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.array_typing as at
|
||||
import openpi.shared.nnx_utils as nnx_utils
|
||||
import openpi.training.checkpoints as _checkpoints
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.training.optimizer as _optimizer
|
||||
import openpi.training.sharding as sharding
|
||||
import openpi.training.utils as training_utils
|
||||
import openpi.training.weight_loaders as _weight_loaders
|
||||
from pdb import set_trace
|
||||
|
||||
|
||||
def maybe_initialize_distributed() -> bool:
    """Initialize JAX multi-host mode from environment variables, if requested.

    Reads JAX_COORDINATOR_ADDRESS, JAX_PROCESS_COUNT and JAX_PROCESS_INDEX.
    Returns True when `jax.distributed.initialize` was called, False when the
    environment describes a single-process run.
    """
    env = os.environ
    coordinator = env.get("JAX_COORDINATOR_ADDRESS")
    num_processes = int(env.get("JAX_PROCESS_COUNT", "1"))
    proc_id = int(env.get("JAX_PROCESS_INDEX", "0"))

    if not coordinator or num_processes <= 1:
        return False

    jax.distributed.initialize(
        coordinator_address=coordinator,
        num_processes=num_processes,
        process_id=proc_id,
    )
    return True
||||
|
||||
|
||||
def init_logging():
    """Install a compact log format (single-letter level names) on the root logger."""
    short_levels = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class _ShortLevelFormatter(logging.Formatter):
        def format(self, record):
            # Abbreviate the level name before delegating to the base class.
            record.levelname = short_levels.get(record.levelname, record.levelname)
            return super().format(record)

    compact = _ShortLevelFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )

    root = logging.getLogger()
    root.setLevel(logging.INFO)
    if root.handlers:
        root.handlers[0].setFormatter(compact)
    else:
        # No handler configured yet — attach a stream handler of our own.
        stream = logging.StreamHandler()
        stream.setFormatter(compact)
        root.addHandler(stream)
||||
|
||||
|
||||
def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
    """Initialize Weights & Biases logging.

    Args:
        config: Training config; supplies checkpoint_dir, exp_name, project_name.
        resuming: If True, reattach to the run id stored in ``wandb_id.txt``.
        log_code: If True, upload the repository source to the run.
        enabled: If False, initialize wandb in disabled mode and return.

    Raises:
        FileNotFoundError: If the checkpoint directory does not exist.
    """
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
    if resuming:
        # Reuse the run id persisted by the original run so metrics append.
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        # Persist the run id so a later --resume can reattach to this run.
        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)

    if log_code:
        wandb.run.log_code(epath.Path(__file__).parent.parent)
||||
|
||||
|
||||
def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
    """Loads and validates the weights. Returns a loaded subset of the weights."""
    loaded_params = loader.load(params_shape)
    at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)

    # Keep only leaves the loader actually materialized; anything still a
    # jax.ShapeDtypeStruct placeholder is dropped.
    flat = traverse_util.flatten_dict(loaded_params)
    materialized = {path: leaf for path, leaf in flat.items() if not isinstance(leaf, jax.ShapeDtypeStruct)}
    return traverse_util.unflatten_dict(materialized)
||||
|
||||
|
||||
@at.typecheck
def init_train_state(
    config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
) -> tuple[training_utils.TrainState, Any]:
    """Build the initial TrainState and its FSDP sharding spec.

    Returns a ``(train_state, state_sharding)`` pair. When *resume* is True,
    only the state's shape (from ``jax.eval_shape``) is returned — the caller
    restores the real values from checkpoint — so no weights are loaded here.
    """
    tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)

    def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
        rng, model_rng = jax.random.split(rng)
        # initialize the model (and its parameters).
        model = config.model.create(model_rng)

        # Merge the partial params into the model.
        if partial_params is not None:
            graphdef, state = nnx.split(model)
            # This will produce an error if the partial params are not a subset of the state.
            state.replace_by_pure_dict(partial_params)
            model = nnx.merge(graphdef, state)

        params = nnx.state(model)
        # Convert frozen params to bfloat16 (they receive no gradient updates).
        params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))

        return training_utils.TrainState(
            step=0,
            params=params,
            model_def=nnx.graphdef(model),
            tx=tx,
            # Optimizer state only tracks the trainable subset.
            opt_state=tx.init(params.filter(config.trainable_filter)),
            ema_decay=config.ema_decay,
            ema_params=None if config.ema_decay is None else params,
        )

    # Abstract evaluation: shapes/dtypes only, no device memory allocated.
    train_state_shape = jax.eval_shape(init, init_rng)
    state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)

    if resume:
        return train_state_shape, state_sharding

    partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    # Initialize the train state and mix in the partial params.
    train_state = jax.jit(
        init,
        donate_argnums=(1,),  # donate the partial params buffer.
        in_shardings=replicated_sharding,
        out_shardings=state_sharding,
    )(init_rng, partial_params)

    return train_state, state_sharding
||||
|
||||
|
||||
@at.typecheck
def train_step(
    config: _config.TrainConfig,
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    batch: tuple[_model.Observation, _model.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    """Run one optimizer step.

    Returns the updated TrainState and a dict of scalar metrics
    (loss, grad_norm, param_norm). Intended to be jax.jit-compiled with the
    input state donated.
    """
    model = nnx.merge(state.model_def, state.params)
    model.train()

    @at.typecheck
    def loss_fn(
        model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
    ):
        chunked_loss = model.compute_loss(rng, observation, actions, train=True)
        return jnp.mean(chunked_loss)

    # Fold the step count into the rng so each step sees fresh randomness while
    # the base key stays constant across jitted calls.
    train_rng = jax.random.fold_in(rng, state.step)
    observation, actions = batch

    # Filter out frozen params.
    diff_state = nnx.DiffState(0, config.trainable_filter)
    loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)

    params = state.params.filter(config.trainable_filter)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Update the model in place and return the new full state.
    nnx.update(model, new_params)
    new_params = nnx.state(model)

    new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
    if state.ema_decay is not None:
        # Exponential moving average of the parameters.
        new_state = dataclasses.replace(
            new_state,
            ema_params=jax.tree.map(
                lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
            ),
        )

    # Filter out params that aren't kernels (biases, scales, embeddings, and
    # anything 1-D are excluded from the reported param norm).
    kernel_params = nnx.state(
        model,
        nnx.All(
            nnx.Param,
            nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
            lambda _, x: x.value.ndim > 1,
        ),
    )
    info = {
        "loss": loss,
        "grad_norm": optax.global_norm(grads),
        "param_norm": optax.global_norm(kernel_params),
    }
    return new_state, info
||||
|
||||
|
||||
def main(config: _config.TrainConfig):
    """Multi-host training loop.

    Mirrors ``scripts/train.py`` but initializes JAX distributed mode first and
    restricts side effects (wandb, progress bar) to the main process.
    """
    init_logging()
    logging.info(f"Running on: {platform.node()}")

    # Initialize multi-host distributed if environment variables are set.
    # Must happen before the first device query (jax.device_count below).
    distributed_initialized = maybe_initialize_distributed()
    is_main = jax.process_index() == 0

    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )

    # Persistent JIT compilation cache to speed up repeated runs.
    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))

    rng = jax.random.key(config.seed)
    train_rng, init_rng = jax.random.split(rng)

    mesh = sharding.make_mesh(config.fsdp_devices)
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
        config.checkpoint_dir,
        keep_period=config.keep_period,
        overwrite=config.overwrite,
        resume=config.resume,
    )
    # wandb only on the main process to avoid duplicate runs.
    init_wandb(config, resuming=resuming, enabled=(config.wandb_enabled and is_main))

    data_loader = _data_loader.create_data_loader_multi(
        config,
        sharding=data_sharding,
        shuffle=True,
    )
    data_iter = iter(data_loader)
    batch = next(data_iter)
    logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")

    # Note: Wandb image logging is disabled in multi-node setup to avoid potential hanging issues
    # caused by concurrent access to sharded arrays across processes.

    train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
    jax.block_until_ready(train_state)
    logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")

    if resuming:
        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)

    ptrain_step = jax.jit(
        functools.partial(train_step, config),
        in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
        out_shardings=(train_state_sharding, replicated_sharding),
        donate_argnums=(1,),  # donate the previous train state's buffers.
    )

    start_step = int(train_state.step)
    step_iter = range(start_step, config.num_train_steps)
    # Progress bar only on the main process; other ranks iterate silently.
    pbar = (
        tqdm.tqdm(
            step_iter,
            initial=start_step,
            total=config.num_train_steps,
            dynamic_ncols=True,
        )
        if is_main
        else None
    )

    infos = []
    for step in step_iter:
        with sharding.set_mesh(mesh):
            train_state, info = ptrain_step(train_rng, train_state, batch)
        if is_main and pbar is not None:
            pbar.update(1)
        infos.append(info)
        if step % config.log_interval == 0:
            # Average metrics accumulated since the last log point. The
            # reduction runs on all ranks; only the main rank reports.
            stacked_infos = common_utils.stack_forest(infos)
            reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
            if is_main:
                info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
                if pbar is not None:
                    pbar.write(f"Step {step}: {info_str}")
                else:
                    logging.info(f"Step {step}: {info_str}")
                if config.wandb_enabled:
                    wandb.log(reduced_info, step=step)
            infos = []
        batch = next(data_iter)
        # Checkpoint saving is executed by every process (not just main).
        if ((step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1):
            _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)

    if is_main:
        if pbar is not None:
            pbar.close()
    logging.info("Waiting for checkpoint manager to finish")
    checkpoint_manager.wait_until_finished()

    if distributed_initialized:
        jax.distributed.shutdown()
|
||||
|
||||
if __name__ == "__main__":
    # Parse the training config from the command line and run.
    main(_config.cli())
||||
632
policy/openpi-InternData-A1/scripts/train_pytorch.py
Normal file
632
policy/openpi-InternData-A1/scripts/train_pytorch.py
Normal file
@@ -0,0 +1,632 @@
|
||||
"""
|
||||
PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.
|
||||
This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs
|
||||
entirely in PyTorch using the `PI0Pytorch` model and your existing config/data
|
||||
pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.
|
||||
|
||||
Usage
|
||||
Single GPU:
|
||||
python scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
|
||||
Example:
|
||||
python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test
|
||||
python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint
|
||||
Multi-GPU (single node):
|
||||
torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>
|
||||
Example:
|
||||
torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
|
||||
torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
|
||||
Multi-Node Training:
|
||||
torchrun \
|
||||
--nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<rank_of_node> \
|
||||
--master_addr=<master_ip> --master_port=<port> \
|
||||
scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
|
||||
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import time
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.parallel
|
||||
import tqdm
|
||||
import wandb
|
||||
|
||||
import openpi.models.pi0_config
|
||||
import openpi.models_pytorch.pi0_pytorch
|
||||
import openpi.shared.normalize as _normalize
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.data_loader as _data
|
||||
|
||||
|
||||
def init_logging():
    """Configure the root logger with a compact, single-letter level format."""
    short_levels = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class _ShortLevelFormatter(logging.Formatter):
        # Abbreviate the level name before delegating to the stock formatter.
        def format(self, record):
            record.levelname = short_levels.get(record.levelname, record.levelname)
            return super().format(record)

    fmt = _ShortLevelFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    if root.handlers:
        # Re-use the existing handler so we don't duplicate output.
        root.handlers[0].setFormatter(fmt)
    else:
        handler = logging.StreamHandler()
        handler.setFormatter(fmt)
        root.addHandler(handler)
|
||||
|
||||
|
||||
def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
    """Initialize the wandb run: disabled mode, resume of a stored run, or a fresh run."""
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")

    if resuming:
        # Reattach to the run id recorded when the experiment was first created.
        stored_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=stored_id, resume="must", project=config.project_name)
        return

    wandb.init(
        name=config.exp_name,
        config=dataclasses.asdict(config),
        project=config.project_name,
    )
    # Persist the run id so a later --resume can reattach to the same run.
    (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
|
||||
|
||||
|
||||
def setup_ddp():
    """Initialize torch.distributed when WORLD_SIZE > 1 and select this process's device.

    Returns:
        Tuple of (use_ddp, local_rank, device).
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    use_ddp = world_size > 1
    if use_ddp and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            backend="nccl" if torch.cuda.is_available() else "gloo",
            init_method="env://",
        )

    # Turn on verbose collective debugging unless the caller configured it already.
    os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "INFO")

    # LOCAL_RANK is set by torchrun; fall back to RANK, then to 0.
    local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    return use_ddp, local_rank, device
|
||||
|
||||
|
||||
def cleanup_ddp():
    """Tear down the process group, syncing all ranks first; no-op outside DDP."""
    if not torch.distributed.is_initialized():
        return
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def set_seed(seed: int, local_rank: int):
    """Seed torch/numpy with a per-rank offset so each rank draws a distinct stream."""
    rank_seed = seed + local_rank
    torch.manual_seed(rank_seed)
    np.random.seed(rank_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(rank_seed)
|
||||
|
||||
|
||||
def build_datasets(config: _config.TrainConfig):
    """Create the unified PyTorch data loader and return it with its data config."""
    loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
    return loader, loader.data_config()
|
||||
|
||||
|
||||
def get_model_state_dict(model):
    """Return the model's state dict, unwrapping a DDP container if present."""
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module.state_dict()
    return model.state_dict()
|
||||
|
||||
|
||||
def get_model_parameters(model):
    """Return the model's parameters iterator, unwrapping a DDP container if present."""
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module.parameters()
    return model.parameters()
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
    """Atomically save model/optimizer/metadata (and norm stats) from the main rank.

    Non-main ranks return immediately. The checkpoint is first written to a
    ``tmp_<step>`` directory and then renamed into place, so a crash mid-save
    never leaves a half-written checkpoint under the final name.
    """
    if not is_main:
        return

    # Save on the configured interval, or at the "final" step.
    # NOTE(review): the caller invokes this after incrementing global_step
    # (so values run 1..num_train_steps); the `num_train_steps - 1` check
    # therefore fires one step before the end, and the true last step is
    # only saved when divisible by save_interval - confirm intent.
    if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1:
        # Stage everything in a temp dir, then rename for atomicity.
        final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
        tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"

        # Remove any stale temp directory from a previous crashed save.
        if tmp_ckpt_dir.exists():
            shutil.rmtree(tmp_ckpt_dir)
        tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)

        # Unwrap DDP so parameter names are not prefixed with "module.";
        # safetensors' save_model handles shared/tied tensors.
        model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
        safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors")

        # Optimizer state stays in native PyTorch serialization.
        torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")

        # Training metadata; the config is stored as a plain dict
        # (dataclasses.asdict) rather than as the dataclass object itself.
        metadata = {
            "global_step": global_step,
            "config": dataclasses.asdict(config),
            "timestamp": time.time(),
        }
        torch.save(metadata, tmp_ckpt_dir / "metadata.pt")

        # Bundle normalization statistics with the checkpoint so inference
        # can reproduce the training-time normalization.
        norm_stats = data_config.norm_stats
        if norm_stats is not None and data_config.asset_id is not None:
            _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats)

        # Atomic publish: drop any existing dir for this step, then rename.
        if final_ckpt_dir.exists():
            shutil.rmtree(final_ckpt_dir)
        tmp_ckpt_dir.rename(final_ckpt_dir)

        logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")

        # Record the checkpoint event in wandb.
        if config.wandb_enabled:
            wandb.log({"checkpoint_step": global_step}, step=global_step)
|
||||
|
||||
|
||||
def load_checkpoint(model, optimizer, checkpoint_dir, device):
    """Load the latest checkpoint into *model*/*optimizer* and return its global step.

    Scans *checkpoint_dir* for numeric subdirectories (skipping in-progress
    ``tmp_`` saves), restores the newest one in stages (model, optimizer,
    metadata), and aggressively frees CUDA cache between stages to keep the
    peak memory footprint down for large models.

    Raises:
        FileNotFoundError: if no checkpoints exist or a component file is missing.
        RuntimeError: re-raised with guidance when CUDA runs out of memory.
    """
    # Completed checkpoints are directories named by their step number.
    checkpoint_steps = [
        int(d.name)
        for d in checkpoint_dir.iterdir()
        if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
    ]

    if not checkpoint_steps:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")

    latest_step = max(checkpoint_steps)
    ckpt_dir = checkpoint_dir / f"{latest_step}"

    # Free as much GPU memory as possible before the large allocations below.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "before_loading_checkpoint")

    try:
        # --- Model weights (safetensors). Unwrap DDP so names match. ---
        logging.info("Loading model state...")
        safetensors_path = ckpt_dir / "model.safetensors"

        if safetensors_path.exists():
            model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
            safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
            logging.info("Loaded model state from safetensors format")
        else:
            raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")

        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_model")

        # --- Optimizer state (native torch serialization). ---
        logging.info("Loading optimizer state...")
        optimizer_path = ckpt_dir / "optimizer.pt"

        if optimizer_path.exists():
            # weights_only=False: the optimizer state dict contains arbitrary
            # Python objects, not just tensors.
            optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
            logging.info("Loaded optimizer state from pt format")
        else:
            raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")

        optimizer.load_state_dict(optimizer_state_dict)
        # Drop the loaded dict immediately; it can be as large as the model.
        del optimizer_state_dict
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_optimizer")

        # --- Metadata: recover the step to resume from (fallback: dir name). ---
        logging.info("Loading metadata...")
        metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
        global_step = metadata.get("global_step", latest_step)
        del metadata
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_metadata")

        logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
        return global_step

    except RuntimeError as e:
        if "out of memory" in str(e):
            # Free what we can, log the state, and re-raise with a hint.
            torch.cuda.empty_cache()
            gc.collect()
            logging.error(f"Out of memory error while loading checkpoint: {e!s}")
            log_memory_usage(device, latest_step, "after_oom_error")
            raise RuntimeError(
                "Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
            ) from e
        raise
|
||||
|
||||
|
||||
def get_latest_checkpoint_step(checkpoint_dir):
    """Return the highest completed checkpoint step in *checkpoint_dir*, or None.

    Completed checkpoints are subdirectories named by step number; in-progress
    saves (``tmp_<step>``) are ignored.
    """
    latest = None
    for entry in checkpoint_dir.iterdir():
        name = entry.name
        if not entry.is_dir() or not name.isdigit() or name.startswith("tmp_"):
            continue
        step = int(name)
        if latest is None or step > latest:
            latest = step
    return latest
|
||||
|
||||
|
||||
def log_memory_usage(device, step, phase="unknown"):
    """Log allocated/reserved/peak CUDA memory for *device*; no-op without CUDA."""
    if not torch.cuda.is_available():
        return

    gb = 1e9
    allocated_bytes = torch.cuda.memory_allocated(device)
    reserved_bytes = torch.cuda.memory_reserved(device)
    # "free" here is cached-but-unused memory (reserved minus allocated).
    memory_free = (reserved_bytes - allocated_bytes) / gb
    memory_allocated = allocated_bytes / gb
    memory_reserved = reserved_bytes / gb

    stats = torch.cuda.memory_stats(device)
    max_memory_allocated = stats.get("allocated_bytes.all.peak", 0) / gb
    max_memory_reserved = stats.get("reserved_bytes.all.peak", 0) / gb

    # Append rank/world-size when running under torch.distributed.
    ddp_info = ""
    if dist.is_initialized():
        ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"

    logging.info(
        f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}"
    )
|
||||
|
||||
|
||||
def train_loop(config: _config.TrainConfig):
    """Run the full PyTorch training loop.

    Handles DDP setup, checkpoint-directory management (fresh / overwrite /
    resume), wandb initialization, data loading, model construction, the
    optimization loop with warmup+cosine LR, periodic logging, and atomic
    checkpointing. Blocks until config.num_train_steps optimizer steps done.
    """
    use_ddp, local_rank, device = setup_ddp()
    is_main = (not use_ddp) or (dist.get_rank() == 0)
    set_seed(config.seed, local_rank)

    # --- Checkpoint directory: resume, overwrite, or fresh run. ---
    resuming = False
    if config.resume:
        exp_checkpoint_dir = config.checkpoint_dir
        if exp_checkpoint_dir.exists():
            # Resume only if at least one completed checkpoint exists.
            latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
            if latest_step is not None:
                resuming = True
                logging.info(
                    f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}"
                )
            else:
                raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
        else:
            raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
    elif config.overwrite and config.checkpoint_dir.exists():
        shutil.rmtree(config.checkpoint_dir)
        logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")

    if not resuming:
        exp_checkpoint_dir = config.checkpoint_dir
        exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
    else:
        logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")

    # wandb is initialized on the main rank only.
    if is_main:
        init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    # config.batch_size is the GLOBAL batch size; each rank processes
    # batch_size / world_size samples. The data loader handles the split.
    world_size = torch.distributed.get_world_size() if use_ddp else 1
    effective_batch_size = config.batch_size // world_size
    logging.info(
        f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
    )

    loader, data_config = build_datasets(config)

    # --- One-time sample-image logging for fresh runs (main rank only). ---
    if is_main and config.wandb_enabled and not resuming:
        # Separate loader so the main training iterator is not consumed.
        sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
        sample_batch = next(iter(sample_data_loader))
        observation, actions = sample_batch
        sample_batch = observation.to_dict()
        sample_batch["actions"] = actions

        images_to_log = []
        # Batch size inferred from the first camera's image tensor (NCHW).
        batch_size = next(iter(sample_batch["image"].values())).shape[0]
        for i in range(min(5, batch_size)):
            # Concatenate all camera views horizontally, converting NCHW -> HWC.
            img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1)
            img_concatenated = img_concatenated.cpu().numpy()
            images_to_log.append(wandb.Image(img_concatenated))

        wandb.log({"camera_views": images_to_log}, step=0)

        # Free the sample batch aggressively before model allocation.
        # NOTE(review): if batch_size were 0 the loop never runs and this
        # `del img_concatenated` would raise NameError - confirm unreachable.
        del sample_batch, observation, actions, images_to_log, img_concatenated
        del sample_data_loader
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logging.info("Cleared sample batch and data loader from memory")

    # --- Model construction: normalize config.model into a Pi0Config. ---
    if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        model_cfg = openpi.models.pi0_config.Pi0Config(
            dtype=config.pytorch_training_precision,
            action_dim=config.model.action_dim,
            action_horizon=config.model.action_horizon,
            max_token_len=config.model.max_token_len,
            paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
            action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
            pi05=getattr(config.model, "pi05", False),
        )
    else:
        model_cfg = config.model
        # NOTE(review): this mutates the (frozen) config object in place via
        # object.__setattr__, affecting any other holder of config.model.
        object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)

    model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)

    # Gradient checkpointing trades compute for activation memory.
    if hasattr(model, "gradient_checkpointing_enable"):
        enable_gradient_checkpointing = True
        model.gradient_checkpointing_enable()
        logging.info("Enabled gradient checkpointing for memory optimization")
    else:
        enable_gradient_checkpointing = False
        logging.info("Gradient checkpointing is not supported for this model")

    if is_main and torch.cuda.is_available():
        log_memory_usage(device, 0, "after_model_creation")

    # Extra allocator/TF32 tuning only for large (8+ GPU) runs.
    # NOTE(review): PYTORCH_CUDA_ALLOC_CONF set here may be read too late to
    # affect the already-initialized allocator - confirm.
    if world_size >= 8:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
        logging.info("Enabled memory optimizations for 8+ GPU training")

    if use_ddp:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[device.index] if device.type == "cuda" else None,
            find_unused_parameters=False,  # Disable for memory efficiency
            gradient_as_bucket_view=True,  # Enable for memory efficiency
            static_graph=world_size >= 8,  # Enable for 8+ GPUs
        )

    # NOTE(review): explicit PyTorch weight loading is currently disabled;
    # fine-tuning initialization is presumably handled elsewhere - confirm.
    # if config.pytorch_weight_path is not None:
    #     logging.info(f"Loading weights from: {config.pytorch_weight_path}")
    #     model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
    #     safetensors.torch.load_model(
    #         (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path
    #     )
    #     logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")

    # --- Optimizer and LR schedule parameters from config. ---
    warmup_steps = config.lr_schedule.warmup_steps
    peak_lr = config.lr_schedule.peak_lr
    decay_steps = config.lr_schedule.decay_steps
    end_lr = config.lr_schedule.decay_lr

    optim = torch.optim.AdamW(
        model.parameters(),
        lr=peak_lr,
        betas=(config.optimizer.b1, config.optimizer.b2),
        eps=config.optimizer.eps,
        weight_decay=config.optimizer.weight_decay,
    )

    # Restore model/optimizer/step if resuming.
    global_step = 0
    if resuming:
        global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
        logging.info(f"Resumed training from step {global_step}")

    def lr_schedule(step: int):
        # Linear warmup followed by cosine decay from peak_lr to end_lr.
        if step < warmup_steps:
            # Match JAX behavior: start from peak_lr / (warmup_steps + 1).
            init_lr = peak_lr / (warmup_steps + 1)
            return init_lr + (peak_lr - init_lr) * step / warmup_steps
        progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
        cos = 0.5 * (1 + np.cos(np.pi * progress))
        return end_lr + (peak_lr - end_lr) * cos

    model.train()
    start_time = time.time()
    infos = []  # Per-step stats accumulated over a log interval.
    if is_main:
        logging.info(
            f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
        )
        logging.info(
            f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
        )
        logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
        logging.info(
            f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
        )
        logging.info(
            f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
        )
        logging.info("EMA is not supported for PyTorch training")
        logging.info(f"Training precision: {model_cfg.dtype}")

    # Progress bar on the main rank only.
    pbar = (
        tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
        if is_main
        else None
    )

    # Outer loop re-iterates the dataset until num_train_steps is reached.
    while global_step < config.num_train_steps:
        # Reshuffle per "epoch" for distributed samplers that support it.
        if use_ddp and hasattr(loader, "set_epoch"):
            loader.set_epoch(global_step // len(loader))

        for observation, actions in loader:
            if global_step >= config.num_train_steps:
                break

            # The unified data loader yields an (observation, actions) pair;
            # move both to the training device (actions cast to float32).
            observation = jax.tree.map(lambda x: x.to(device), observation)  # noqa: PLW2901
            actions = actions.to(torch.float32)  # noqa: PLW2901
            actions = actions.to(device)  # noqa: PLW2901

            # Apply the scheduled LR before this step.
            for pg in optim.param_groups:
                pg["lr"] = lr_schedule(global_step)

            # Forward pass; the model returns per-element losses.
            losses = model(observation, actions)
            # Normalize the return type to a tensor before reducing.
            if isinstance(losses, list | tuple):
                losses = torch.stack(losses)
            elif not isinstance(losses, torch.Tensor):
                losses = torch.tensor(losses, device=device, dtype=torch.float32)

            loss = losses.mean()

            loss.backward()

            # Memory diagnostics for the first few steps only.
            if global_step < 5 and is_main and torch.cuda.is_available():
                log_memory_usage(device, global_step, "after_backward")

            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)

            optim.step()
            optim.zero_grad(set_to_none=True)

            # Belt-and-braces gradient release on top of set_to_none above.
            for param in model.parameters():
                if param.grad is not None:
                    param.grad.detach_()
                    param.grad = None

            if is_main:
                infos.append(
                    {
                        "loss": loss.item(),
                        "learning_rate": optim.param_groups[0]["lr"],
                        "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
                    }
                )

            # Periodic logging: average the collected stats, emit, reset.
            if is_main and (global_step % config.log_interval == 0):
                elapsed = time.time() - start_time

                avg_loss = sum(info["loss"] for info in infos) / len(infos)
                avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)

                avg_grad_norm = None
                if any("grad_norm" in info for info in infos):
                    vals = [
                        info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None
                    ]
                    if len(vals) > 0:
                        avg_grad_norm = sum(vals) / len(vals)
                logging.info(
                    f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s"
                    if avg_grad_norm is not None
                    else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s"
                )

                if config.wandb_enabled and len(infos) > 0:
                    log_payload = {
                        "loss": avg_loss,
                        "learning_rate": avg_lr,
                        "step": global_step,
                        "time_per_step": elapsed / config.log_interval,
                    }
                    if avg_grad_norm is not None:
                        log_payload["grad_norm"] = avg_grad_norm
                    wandb.log(log_payload, step=global_step)

                start_time = time.time()
                infos = []

            global_step += 1
            # Checkpointing is gated internally on save_interval / final step.
            save_checkpoint(model, optim, global_step, config, is_main, data_config)

            if pbar is not None:
                pbar.update(1)
                pbar.set_postfix(
                    {"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step}
                )

    if pbar is not None:
        pbar.close()

    if is_main and config.wandb_enabled:
        wandb.finish()

    cleanup_ddp()
|
||||
|
||||
|
||||
def main():
    """CLI entry point: configure logging, parse the training config, and train."""
    init_logging()
    train_loop(_config.cli())


if __name__ == "__main__":
    main()
|
||||
30
policy/openpi-InternData-A1/scripts/train_test.py
Normal file
30
policy/openpi-InternData-A1/scripts/train_test.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import dataclasses
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ["JAX_PLATFORMS"] = "cpu"
|
||||
|
||||
from openpi.training import config as _config
|
||||
|
||||
from . import train
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
    """Smoke-test training for two steps, then resume and continue to four."""
    base = _config._CONFIGS_DICT[config_name]  # noqa: SLF001
    config = dataclasses.replace(
        base,
        batch_size=2,
        checkpoint_base_dir=str(tmp_path / "checkpoint"),
        exp_name="test",
        overwrite=False,
        resume=False,
        num_train_steps=2,
        log_interval=1,
    )
    train.main(config)

    # Resuming should pick up at step 2 and run two more steps.
    resumed = dataclasses.replace(config, resume=True, num_train_steps=4)
    train.main(resumed)
|
||||
209
policy/openpi-InternData-A1/scripts/training_scripts/multi_node.sh
Executable file
209
policy/openpi-InternData-A1/scripts/training_scripts/multi_node.sh
Executable file
@@ -0,0 +1,209 @@
|
||||
#!/usr/bin/env bash
|
||||
set -ex
|
||||
|
||||
cd YOUR_PATH/openpi
|
||||
|
||||
export USE_TF=0
|
||||
export USE_TORCH=0
|
||||
export USE_JAX=1
|
||||
export IMAGEIO_FFMPEG_EXE=ffmpeg
|
||||
# JAX GPU memory fraction
|
||||
export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.9}"
|
||||
|
||||
# ============================================================================
|
||||
# NCCL Configuration
|
||||
# ============================================================================
|
||||
export NCCL_ASYNC_ERROR_HANDLING=1
|
||||
export NCCL_TIMEOUT=3600
|
||||
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
|
||||
|
||||
# ============================================================================
|
||||
# Platform-Injected Configuration
|
||||
# ============================================================================
|
||||
# The platform automatically injects these when DISTRIBUTED_JOB=true:
|
||||
# - NCCL_IB_HCA, NCCL_IB_GID_INDEX, NCCL_SOCKET_IFNAME
|
||||
# - NODE_RANK, NODE_COUNT, MASTER_ADDR, PROC_PER_NODE
|
||||
# - CUDA_VISIBLE_DEVICES
|
||||
# We trust and use these platform configurations directly.
|
||||
# ============================================================================
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Platform Configuration"
|
||||
echo "=========================================="
|
||||
echo "NODE_RANK: ${NODE_RANK:-<not set>}"
|
||||
echo "NODE_COUNT: ${NODE_COUNT:-<not set>}"
|
||||
echo "MASTER_ADDR: ${MASTER_ADDR:-<not set>}"
|
||||
echo "NCCL_IB_HCA: ${NCCL_IB_HCA:-<not set>}"
|
||||
echo "NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX:-<not set>}"
|
||||
echo "NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME:-<not set>}"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# ============================================================================
|
||||
# NCCL Transport Configuration
|
||||
# ============================================================================
|
||||
# Use platform-injected configuration if available, otherwise fallback
|
||||
# ============================================================================
|
||||
|
||||
if [ -n "${NCCL_IB_HCA:-}" ]; then
|
||||
# Platform has configured InfiniBand
|
||||
echo "[NCCL] ✓ Using platform-injected InfiniBand configuration"
|
||||
|
||||
# Only set NCCL_NET if not already set
|
||||
if [ -z "${NCCL_NET:-}" ]; then
|
||||
export NCCL_NET="IB"
|
||||
fi
|
||||
|
||||
# Set IB timeout if not already set
|
||||
if [ -z "${NCCL_IB_TIMEOUT:-}" ]; then
|
||||
export NCCL_IB_TIMEOUT=23
|
||||
fi
|
||||
|
||||
echo "[NCCL] NCCL_NET: ${NCCL_NET}"
|
||||
echo "[NCCL] NCCL_IB_HCA: ${NCCL_IB_HCA}"
|
||||
echo "[NCCL] NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX}"
|
||||
echo "[NCCL] NCCL_IB_TIMEOUT: ${NCCL_IB_TIMEOUT}"
|
||||
|
||||
elif [ -n "${NCCL_SOCKET_IFNAME:-}" ]; then
|
||||
# Platform has configured Socket
|
||||
echo "[NCCL] ✓ Using platform-injected Socket configuration"
|
||||
|
||||
if [ -z "${NCCL_NET:-}" ]; then
|
||||
export NCCL_NET="Socket"
|
||||
fi
|
||||
|
||||
echo "[NCCL] NCCL_NET: ${NCCL_NET}"
|
||||
echo "[NCCL] NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME}"
|
||||
|
||||
else
|
||||
# No platform injection - use OPENPI_NCCL_NET preference
|
||||
echo "[NCCL] ⚠️ No platform-injected NCCL configuration"
|
||||
|
||||
if [ "${OPENPI_NCCL_NET:-IB}" = "IB" ]; then
|
||||
echo "[NCCL] ✗ InfiniBand requested but not configured by platform"
|
||||
echo "[NCCL] ✗ Falling back to Socket transport"
|
||||
export NCCL_NET="Socket"
|
||||
export NCCL_IB_DISABLE=1
|
||||
else
|
||||
export NCCL_NET="Socket"
|
||||
export NCCL_IB_DISABLE=1
|
||||
echo "[NCCL] Using Socket transport"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo ""
|
||||
|
||||
# ============================================================================
|
||||
# JAX Distributed Configuration
|
||||
# ============================================================================
|
||||
# Map platform variables to JAX variables
|
||||
# ============================================================================
|
||||
|
||||
echo "=========================================="
|
||||
echo "JAX Distributed Configuration"
|
||||
echo "=========================================="
|
||||
|
||||
JAX_COORDINATOR_PORT="${JAX_COORDINATOR_PORT:-12345}"
|
||||
|
||||
# Set JAX coordinator address
|
||||
if [ -z "${JAX_COORDINATOR_ADDRESS:-}" ] && [ -n "${MASTER_ADDR:-}" ]; then
|
||||
export JAX_COORDINATOR_ADDRESS="${MASTER_ADDR}:${JAX_COORDINATOR_PORT}"
|
||||
echo "[JAX] ✓ Coordinator: ${JAX_COORDINATOR_ADDRESS} (from MASTER_ADDR)"
|
||||
elif [ -n "${JAX_COORDINATOR_ADDRESS:-}" ]; then
|
||||
echo "[JAX] ✓ Coordinator: ${JAX_COORDINATOR_ADDRESS}"
|
||||
else
|
||||
echo "[JAX] ✗ WARNING: No coordinator address set!"
|
||||
fi
|
||||
|
||||
# Set JAX process count
|
||||
if [ -z "${JAX_PROCESS_COUNT:-}" ] && [ -n "${NODE_COUNT:-}" ]; then
|
||||
export JAX_PROCESS_COUNT="${NODE_COUNT}"
|
||||
echo "[JAX] ✓ Process count: ${JAX_PROCESS_COUNT} (from NODE_COUNT)"
|
||||
elif [ -n "${JAX_PROCESS_COUNT:-}" ]; then
|
||||
echo "[JAX] ✓ Process count: ${JAX_PROCESS_COUNT}"
|
||||
fi
|
||||
|
||||
# Set JAX process index
|
||||
if [ -z "${JAX_PROCESS_INDEX:-}" ] && [ -n "${NODE_RANK:-}" ]; then
|
||||
export JAX_PROCESS_INDEX="${NODE_RANK}"
|
||||
echo "[JAX] ✓ Process index: ${JAX_PROCESS_INDEX} (from NODE_RANK)"
|
||||
elif [ -n "${JAX_PROCESS_INDEX:-}" ]; then
|
||||
echo "[JAX] ✓ Process index: ${JAX_PROCESS_INDEX}"
|
||||
fi
|
||||
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# ============================================================================
# Python Environment
# ============================================================================
# NOTE(review): YOUR_PATH is a placeholder — replace with the real checkout
# location before running.
# Use ${PYTHONPATH:-} so this line does not abort under `set -u` when
# PYTHONPATH is unset in the launcher environment.
export PYTHONPATH=YOUR_PATH/openpi/src:YOUR_PATH/openpi/packages/openpi-client/src:YOUR_PATH/openpi/third_party/lerobot:${PYTHONPATH:-}

# `conda activate` fails in non-interactive shells unless conda's shell hook
# has been sourced; source it first when available.
if command -v conda >/dev/null 2>&1; then
    _conda_base="$(conda info --base)"
    if [ -f "${_conda_base}/etc/profile.d/conda.sh" ]; then
        # shellcheck disable=SC1091
        . "${_conda_base}/etc/profile.d/conda.sh"
    fi
fi
conda activate pi0
||||
# ============================================================================
# Configuration Summary
# ============================================================================

# Emit one banner with every relevant NCCL/JAX variable; unset or empty
# variables are shown as "<not set>". printf '%s\n' prints each argument on
# its own line, matching the original per-line echo output exactly.
printf '%s\n' \
    "==========================================" \
    "Configuration Summary" \
    "==========================================" \
    "NCCL_NET: ${NCCL_NET:-<not set>}" \
    "NCCL_IB_HCA: ${NCCL_IB_HCA:-<not set>}" \
    "NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX:-<not set>}" \
    "NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME:-<not set>}" \
    "JAX_COORDINATOR: ${JAX_COORDINATOR_ADDRESS:-<not set>}" \
    "JAX_PROCESS_COUNT: ${JAX_PROCESS_COUNT:-<not set>}" \
    "JAX_PROCESS_INDEX: ${JAX_PROCESS_INDEX:-<not set>}" \
    "==========================================" \
    ""
||||
# ============================================================================
# Display Host Information
# ============================================================================

# One-shot python snippet: report hostname, locally visible JAX devices, and
# the JAX_* distributed variables this process will initialize with.
python - <<'EOF'
import os
import socket

import jax

local = jax.local_devices()
print(f"[JAX] host={socket.gethostname()}, devices={len(local)}xgpu, ids={[d.id for d in local]}")
for key in ("JAX_COORDINATOR_ADDRESS", "JAX_PROCESS_COUNT", "JAX_PROCESS_INDEX"):
    print(f"[JAX] {key}={os.environ.get(key, '<not set>')}")
EOF
|
||||
# ============================================================================
# Launch Training
# ============================================================================

# Determine experiment name based on transport. Defaults to the previously
# hard-coded name so an unset EXP_NAME keeps the same checkpoint/run name.
if [ "${OPENPI_DEBUG_SINGLE_GPU:-0}" = "1" ]; then
    EXP_NAME="${EXP_NAME:-pretrain-interndata-a1}"
    echo "[DEBUG] Running in single-GPU mode"
else
    EXP_NAME="${EXP_NAME:-pretrain-interndata-a1}"
fi

echo ""
echo "=========================================="
echo "Starting Training"
echo "=========================================="
echo "Experiment: $EXP_NAME"
echo "=========================================="
echo ""

# Raise the open-file limit: many data-loader workers exhaust the default.
ulimit -n 1000000

# BUG FIX: EXP_NAME was computed and echoed above but the launch hard-coded
# --exp-name=pretrain-interndata-a1, so EXP_NAME overrides were silently
# ignored and the "Experiment:" banner could lie. Pass ${EXP_NAME} through.
python scripts/train_jax_multinode.py \
    pretrain-interndata-a1 \
    --exp-name="${EXP_NAME}" \
    --num_workers=12 \
    --fsdp_devices=8 \
    --batch_size=512 \
    --num_train_steps=2000000 \
    --save_interval=5000
||||
@@ -0,0 +1,13 @@
|
||||
# Single-node training wrapper: sets up the environment and runs
# scripts/train.py with the config name given as the first argument.
set -ex

export IMAGEIO_FFMPEG_EXE=ffmpeg
export OMP_NUM_THREADS=128

# NOTE(review): YOUR_PATH is a placeholder — replace with the real checkout
# location before running. ${PYTHONPATH:-} keeps this safe under `set -u`.
export PYTHONPATH=YOUR_PATH/openpi/src:YOUR_PATH/openpi/packages/openpi-client/src:YOUR_PATH/openpi/third_party/lerobot:${PYTHONPATH:-}
conda activate pi0

cd YOUR_PATH/openpi
# Raise the open-file limit for data-loader workers.
ulimit -n 1000000

# BUG FIX: require the config name; previously a missing $1 silently ran
# train.py with an empty config and exp-name. ${1:?...} aborts with a usage
# message instead.
config_name="${1:?usage: $0 <config_name>}"
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 python scripts/train.py "${config_name}" \
    --exp-name="${config_name}"
||||
Reference in New Issue
Block a user