Initial commit
examples/droid/README.md
# Run DROID

This example shows how to run the fine-tuned $\pi_0$-FAST-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). We also offer a $\pi_0$-DROID model that is fine-tuned from $\pi_0$ and uses flow action decoding. You can use it by replacing `pi0_fast_droid` with `pi0_droid` in the commands below. In practice, we find that, out of the box, the $\pi_0$-FAST-DROID model is better at following language commands, so we recommend it as the default checkpoint for DROID evaluation. If you want to fine-tune on a DROID task that requires fast inference, you may still want to consider the $\pi_0$-DROID model, since it decodes actions faster. For more details, please see the [FAST paper](https://pi.website/research/fast).

## Step 1: Start a policy server

Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.

1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
2. Start the OpenPI server via the following command:

```bash
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=s3://openpi-assets/checkpoints/pi0_fast_droid
```

You can also run the equivalent command below:

```bash
uv run scripts/serve_policy.py --env=DROID
```
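
Once the server is running, you can optionally sanity-check it from any machine that can reach it (for example, the DROID control laptop) before involving the robot. Below is a minimal sketch of such a check using the `openpi-client` websocket interface; the observation keys, image size, and `infer` call mirror what `main.py` in this directory sends, while the host, port, and the zero-filled arrays are placeholders you would replace with your own values.

```python
# Minimal policy-server sanity check (sketch): sends one dummy observation and prints the
# shape of the returned action chunk. Host/port and the zero arrays are placeholders.
import numpy as np

from openpi_client import image_tools
from openpi_client import websocket_client_policy

client = websocket_client_policy.WebsocketClientPolicy("192.168.1.100", 8000)  # your server's IP and port

dummy_image = np.zeros((720, 1280, 3), dtype=np.uint8)  # stand-in for a real camera frame
request_data = {
    "observation/exterior_image_1_left": image_tools.resize_with_pad(dummy_image, 224, 224),
    "observation/wrist_image_left": image_tools.resize_with_pad(dummy_image, 224, 224),
    "observation/joint_position": np.zeros(7),    # placeholder 7-DoF joint angles
    "observation/gripper_position": np.zeros(1),  # placeholder gripper position
    "prompt": "do nothing",
}

actions = client.infer(request_data)["actions"]
print("Received action chunk with shape:", np.asarray(actions).shape)  # main.py expects (10, 8)
```

If this runs without errors and prints an action chunk shape, the DROID laptop will be able to query the server during rollouts.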

## Step 2: Run the DROID robot

1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
2. On the control laptop, activate your DROID conda environment.
3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (it has very few dependencies and should install quickly): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (see the sketch after the command below). You can find the camera IDs by running `ZED_Explore` in the command line, which opens a tool that shows all connected cameras and their IDs; you can also use it to make sure that the cameras are well positioned to see the scene you want the robot to interact with.
7. Run the `main.py` file, pointing `--remote_host` and `--remote_port` at the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also specify which external camera to feed to the policy (we only input one external camera), choosing either `left` or `right`:

```bash
python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left"
```
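
For step 6, the camera IDs live in the `Args` dataclass at the top of `main.py`. The sketch below shows just that slice of the dataclass; the ID strings are the example values from the comments in `main.py` and should be replaced with the IDs `ZED_Explore` reports for your rig.

```python
import dataclasses


@dataclasses.dataclass
class Args:
    # Replace these with the IDs that ZED_Explore reports for your cameras
    # (the values below are just the example IDs from the comments in main.py).
    left_camera_id: str = "24259877"
    right_camera_id: str = "24514023"
    wrist_camera_id: str = "13062452"
```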

The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!

# Troubleshooting

| Issue | Solution |
|-------|----------|
| Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |
| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explore` in the command line. |
| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
| Policy does not perform the task well | In our experiments, the policy could perform simple tabletop manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, try modifying the scene or object placement to make the task easier. Also make sure that the external camera view you pass to the policy can see all relevant objects in the scene (the policy is conditioned on only a single external camera plus the wrist camera); you can verify this with `ZED_Explore`. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |
examples/droid/main.py
# ruff: noqa
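"""Run evaluation rollouts of an openpi DROID policy on the DROID robot.

Connects to a remote policy server over a websocket, streams camera and proprioceptive
observations, executes the returned action chunks on the robot, saves a video of each
rollout, and logs per-rollout success ratings to a CSV file under ./results.
"""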

import contextlib
import dataclasses
import datetime
import faulthandler
import os
import signal

from moviepy.editor import ImageSequenceClip
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy
import pandas as pd
from PIL import Image
from droid.robot_env import RobotEnv
import tqdm
import tyro

# Dump Python tracebacks on fatal errors (e.g., segfaults) for easier debugging.
faulthandler.enable()


@dataclasses.dataclass
class Args:
    # Hardware parameters
    left_camera_id: str = "<your_camera_id>"  # e.g., "24259877"
    right_camera_id: str = "<your_camera_id>"  # e.g., "24514023"
    wrist_camera_id: str = "<your_camera_id>"  # e.g., "13062452"

    # Policy parameters
    external_camera: str | None = (
        None  # which external camera should be fed to the policy, choose from ["left", "right"]
    )

    # Rollout parameters
    max_timesteps: int = 600
    # How many actions to execute from a predicted action chunk before querying policy server again
    # 8 is usually a good default (equals 0.5 seconds of action execution).
    open_loop_horizon: int = 8

    # Remote server parameters
    remote_host: str = "0.0.0.0"  # point this to the IP address of the policy server, e.g., "192.168.1.100"
    remote_port: int = (
        8000  # point this to the port of the policy server, default server port for openpi servers is 8000
    )


# We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
# waiting for a new action chunk, it will raise an exception and the server connection dies.
# This context manager temporarily blocks Ctrl+C and delays it until after the server call is complete.
@contextlib.contextmanager
def prevent_keyboard_interrupt():
    """Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
    interrupted = False
    original_handler = signal.getsignal(signal.SIGINT)

    def handler(signum, frame):
        nonlocal interrupted
        interrupted = True

    signal.signal(signal.SIGINT, handler)
    try:
        yield
    finally:
        signal.signal(signal.SIGINT, original_handler)
        if interrupted:
            raise KeyboardInterrupt


def main(args: Args):
    # Make sure external camera is specified by user -- we only use one external camera for the policy
    assert (
        args.external_camera is not None and args.external_camera in ["left", "right"]
    ), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"

    # Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.
    env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
    print("Created the droid env!")

    # Connect to the policy server
    policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)

    # Per-rollout results (success rating, duration, video filename) are accumulated here and saved to CSV at the end.
    df = pd.DataFrame(columns=["success", "duration", "video_filename"])
    while True:
        instruction = input("Enter instruction: ")

        # Rollout parameters
        actions_from_chunk_completed = 0
        pred_action_chunk = None

        # Prepare to save video of rollout
        timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
        video = []
        bar = tqdm.tqdm(range(args.max_timesteps))
        print("Running rollout... press Ctrl+C to stop early.")
        for t_step in bar:
            try:
                # Get the current observation
                curr_obs = _extract_observation(
                    args,
                    env.get_observation(),
                    # Save the first observation to disk
                    save_to_disk=t_step == 0,
                )

                video.append(curr_obs[f"{args.external_camera}_image"])

                # Send websocket request to policy server if it's time to predict a new chunk
                if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
                    actions_from_chunk_completed = 0

                    # We resize images on the robot laptop to minimize the amount of data sent to the policy server
                    # and improve latency.
                    request_data = {
                        "observation/exterior_image_1_left": image_tools.resize_with_pad(
                            curr_obs[f"{args.external_camera}_image"], 224, 224
                        ),
                        "observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
                        "observation/joint_position": curr_obs["joint_position"],
                        "observation/gripper_position": curr_obs["gripper_position"],
                        "prompt": instruction,
                    }

                    # Wrap the server call in a context manager to prevent Ctrl+C from interrupting it;
                    # Ctrl+C will be handled after the server call is complete.
                    with prevent_keyboard_interrupt():
                        # The server returns a [10, 8] action chunk: 10 steps of 7 joint velocities + 1 gripper position.
                        pred_action_chunk = policy_client.infer(request_data)["actions"]
                        assert pred_action_chunk.shape == (10, 8)

                # Select current action to execute from chunk
                action = pred_action_chunk[actions_from_chunk_completed]
                actions_from_chunk_completed += 1

                # Binarize gripper action
                if action[-1].item() > 0.5:
                    # action[-1] = 1.0
                    action = np.concatenate([action[:-1], np.ones((1,))])
                else:
                    # action[-1] = 0.0
                    action = np.concatenate([action[:-1], np.zeros((1,))])

                # Clip all dimensions of action to [-1, 1]
                action = np.clip(action, -1, 1)

                env.step(action)
            except KeyboardInterrupt:
                break

        video = np.stack(video)
        save_filename = "video_" + timestamp
        ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")

        # Ask the user to rate the rollout: y = 100%, n = 0%, or a number in [0, 100].
        success: float | None = None
        while success is None:
            response = input(
                "Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec "
            )
            if response == "y":
                success = 1.0
            elif response == "n":
                success = 0.0
            else:
                try:
                    value = float(response) / 100
                except ValueError:
                    print(f"Success must be y, n, or a number in [0, 100] but got: {response}")
                    continue
                if 0 <= value <= 1:
                    success = value
                else:
                    print(f"Success must be a number in [0, 100] but got: {response}")

        # DataFrame.append was removed in pandas 2.x; use pd.concat instead.
        df = pd.concat(
            [df, pd.DataFrame([{"success": success, "duration": t_step, "video_filename": save_filename}])],
            ignore_index=True,
        )

        if input("Do one more eval? (enter y or n) ").lower() != "y":
            break
        env.reset()

    os.makedirs("results", exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
    csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
    df.to_csv(csv_filename)
    print(f"Results saved to {csv_filename}")


def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
    image_observations = obs_dict["image"]
    left_image, right_image, wrist_image = None, None, None
    for key in image_observations:
        # Note the "left" below refers to the left camera in the stereo pair.
        # The model is only trained on left stereo cams, so we only feed those.
        if args.left_camera_id in key and "left" in key:
            left_image = image_observations[key]
        elif args.right_camera_id in key and "left" in key:
            right_image = image_observations[key]
        elif args.wrist_camera_id in key and "left" in key:
            wrist_image = image_observations[key]

    assert (
        left_image is not None and right_image is not None and wrist_image is not None
    ), f"Could not find all three camera streams -- check the camera IDs in Args. Got image keys: {list(image_observations)}"

    # Drop the alpha dimension
    left_image = left_image[..., :3]
    right_image = right_image[..., :3]
    wrist_image = wrist_image[..., :3]

    # Convert to RGB
    left_image = left_image[..., ::-1]
    right_image = right_image[..., ::-1]
    wrist_image = wrist_image[..., ::-1]

    # In addition to image observations, also capture the proprioceptive state
    robot_state = obs_dict["robot_state"]
    cartesian_position = np.array(robot_state["cartesian_position"])
    joint_position = np.array(robot_state["joint_positions"])
    gripper_position = np.array([robot_state["gripper_position"]])

    # Save the images to disk so that they can be viewed live while the robot is running
    # Create one combined image to make live viewing easy
    if save_to_disk:
        combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
        combined_image = Image.fromarray(combined_image)
        combined_image.save("robot_camera_views.png")

    return {
        "left_image": left_image,
        "right_image": right_image,
        "wrist_image": wrist_image,
        "cartesian_position": cartesian_position,
        "joint_position": joint_position,
        "gripper_position": gripper_position,
    }


if __name__ == "__main__":
    args: Args = tyro.cli(Args)
    main(args)