# Files: starvla_benchmark/starvla_policy.py
# 2026-06-15 14:58:28 +08:00
# 272 lines, 12 KiB, Python

import pickle
import time
import json
import numpy as np
from scipy.spatial.transform import Rotation as R
import requests
from fastsim.annotations.config_class import configclass, field
from fastsim.annotations.stereotype import stereotype
from fastsim.controllers.spawnable_controller import SpawnableController
from fastsim.controllers.visualize_controller import VisualizeController
from fastsim.unisim.robots.models.modular_robot import ModularRobot
from fastsim.utils.namespace import PoseVisualType, SimulatorType
from fastsim.unisim.robots.actuator_configs.grippers import GripperDriveJointConfig
from fastsim.extensions.benchmark.action import RobotAction
from fastsim.extensions.benchmark.benchmark import (
BenchmarkAction,
BenchmarkObservation,
ControlMode,
)
from fastsim.extensions.benchmark.policy import Policy, PolicyConfig
from fastsim.utils.log import Log
from fastsim.utils.pose import Pose
def _default_sensor_names() -> list[str]:
    """Return a fresh list of the default camera sensor names."""
    return ["Hand_Camera", "Left_Camera", "Right_Camera"]


@configclass
@stereotype.register_config("starvla")
class StarvlaPolicyConfig(PolicyConfig):
    """Configuration for the StarVLA policy client.

    Registered under the ``"starvla"`` stereotype name. The fields name the
    robot to control, the camera sensors whose RGB frames are sent to the
    server, the server endpoint, the task prompt, how many steps of each
    returned action chunk are executed, and optional visualisation toggles.
    """

    # Name of the robot this policy drives.
    robot_name: str = field(default="None", required=True, comment="The name of the robot")
    # Toggles for drawing the commanded / observed end-effector frames.
    visualize_action_ee_pose: bool = field(
        default=False,
        required=True,
        comment="Whether to visualize the action end effector pose",
    )
    visualize_state_ee_pose: bool = field(
        default=False,
        required=True,
        comment="Whether to visualize the state end effector pose",
    )
    # Spawnable objects whose bounding boxes are drawn every step.
    visualize_bounding_box_targets: list[str] = field(
        default_factory=list,
        required=False,
        comment="Spawnable object names to draw bounding boxes for (empty = off)",
    )
    # A factory (not a shared list literal) so that configs never alias the
    # same default list; consistent with visualize_bounding_box_targets.
    sensor_names: list[str] = field(
        default_factory=_default_sensor_names,
        required=True,
        comment="The names of the sensors",
    )
    # Base URL of the policy server; "/health" and "/inference" are appended.
    server_url: str = field(
        default="http://127.0.0.1:5000/policy",
        required=True,
        comment="StarVLA policy server url",
    )
    # Language instruction sent with every observation.
    prompt: str = field(
        default="pick the object",
        required=True,
        comment="task instruction",
    )
    # Number of steps of a returned action chunk to execute before querying
    # the server again.
    run_trunk_size: int = field(
        default=16,
        required=True,
        comment="The number of chunks to run in one inference step",
    )
@stereotype.register_model("starvla")
class StarvlaPolicy(Policy):
    """Benchmark policy that talks to a StarVLA inference server over HTTP.

    Observations are pickled and POSTed to ``{server_url}/inference``; the
    server answers with per-arm action chunks that are stepped through one
    element per call to :meth:`postprocess_action`.

    NOTE(review): in its current state the class replays end-effector poses
    and finger targets recorded in a parquet dataset (``DUMMY_DATA_PATH``)
    instead of sending the model's own actions to the robot. The model output
    is still requested and is only used for the optional pose visualisation.
    """

    # Width of one arm's slice of the state vector: 3 position values,
    # 6 rot6d values and the remaining finger joint values.
    _ARM_STATE_DIM = 31

    # Recorded dataset replayed by postprocess_action (episode 0 only).
    DUMMY_DATA_PATH = "/home/zhiyuan/zhujuan/datasets/add_remove_lid_15fps_10epi/data/chunk-000/file-000.parquet"

    # Start column of each arm inside one replayed frame. Relative to that
    # start: [0:3] position, [3:6] xyz Euler angles, [6:28] finger targets.
    _REPLAY_ARM_OFFSETS = {"left_arm": 0, "right_arm": 28}

    def __init__(self, config: StarvlaPolicyConfig):
        """Copy settings from ``config`` and load the replay dataset."""
        super().__init__(config)
        self.robot_name = config.robot_name
        self.sensor_names = config.sensor_names
        self.server_url = config.server_url
        self.prompt = config.prompt
        self.visualize_action_ee_pose = config.visualize_action_ee_pose
        self.visualize_state_ee_pose = config.visualize_state_ee_pose
        self.visualize_bounding_box_targets = list(config.visualize_bounding_box_targets or [])

        # Chunk bookkeeping is (re)initialised by reset(); it is also set here
        # so that needs_observation() does not fail if queried before the
        # first reset().
        self.current_state = {}
        self.current_chunk_id = 0
        self.current_chunk_result = None
        self.run_trunk_size = config.run_trunk_size

        # prevent circular import
        import pandas as pd

        df_data = pd.read_parquet(self.DUMMY_DATA_PATH)
        states_by_episode = df_data.groupby('episode_index')['observation.state'].apply(list).to_dict()
        # Only episode 0 is replayed.
        self.dummy_data = np.array(states_by_episode[0])
        self.dummy_data_idx = 0

    def reset(self) -> None:
        """Clear chunk state and look up the robot and its hand joint names."""
        self.current_state = {}
        self.current_chunk_id = 0
        self.current_chunk_result = None
        self.run_trunk_size = self.config.run_trunk_size
        self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap()
        self.left_hand_joints = SpawnableController.control_robot(
            self.robot_name,
            "get_actuator_joint_names",
            parameters={"actuator_name": "left_hand"},
        ).unwrap()
        self.right_hand_joints = SpawnableController.control_robot(
            self.robot_name,
            "get_actuator_joint_names",
            parameters={"actuator_name": "right_hand"},
        ).unwrap()

    def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
        """Block until ``{server_url}/health`` answers with HTTP 200.

        ``benchmark_observation`` is not used. The endpoint is polled roughly
        once per second, both when the request fails and when the server
        answers with a non-200 status, so an unhealthy server is not hammered
        in a tight loop.
        """
        Log.info("Waiting for StarVLA inference server to be ready...")
        health_url = f"{self.server_url}/health"
        while True:
            try:
                if requests.get(health_url, timeout=1.0).status_code == 200:
                    break
            except requests.RequestException:
                # Server not reachable yet; fall through to the back-off.
                pass
            time.sleep(1)
        Log.success("StarVLA inference server is ready.")

    def needs_observation(self) -> bool:
        """Return True when the next step is the first one of a chunk."""
        return self.current_chunk_id == 0

    def _handle_server_error(self, response: requests.Response) -> None:
        """Log a non-200 response from the server with ``exit=True``.

        A 500 response body is expected to be a pickled dict with ``"error"``
        and ``"traceback"`` entries.
        """
        if response.status_code == 500:
            # SECURITY: pickle.loads executes arbitrary code from the payload;
            # only point server_url at a trusted server.
            err_obj = pickle.loads(response.content)
            Log.error(f"StarVLA server error: {err_obj['error']}")
            Log.error(f"Traceback: {err_obj['traceback']}", exit=True)
        elif response.status_code != 200:
            Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}", exit=True)

    def split_joints(self, state_or_action, keys=None) -> dict[str, dict[str, np.ndarray]]:
        """Split a concatenated state/action vector into per-arm components.

        Args:
            state_or_action: Array whose last axis holds ``31 * len(keys)``
                values, one 31-wide slice per key, in key order.
            keys: Names for the slices; defaults to
                ``["left_arm", "right_arm"]``.

        Returns:
            Mapping from key to a dict with ``"ee_pos"`` (first 3 values),
            ``"ee_rot6d"`` (next 6) and ``"finger_qpos"`` (the rest).
        """
        if keys is None:
            keys = ["left_arm", "right_arm"]
        total_dim = self._ARM_STATE_DIM * len(keys)
        assert state_or_action.shape[-1] == total_dim, f"Expected last dimension to be {total_dim}, got {state_or_action.shape[-1]}"
        # Equal-width split so any number of keys works, not only two.
        joints_all = np.split(state_or_action, len(keys), axis=-1)
        return_dict = {}
        for key, joints in zip(keys, joints_all):
            ee_pos, ee_rot6d, finger_qpos = np.split(joints, [3, 9], axis=-1)
            return_dict[key] = {
                "ee_pos": ee_pos,
                "ee_rot6d": ee_rot6d,
                "finger_qpos": finger_qpos,
            }
        return return_dict

    def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
        """Build the observation dict that is sent to the server.

        Returns:
            ``{"state": ..., "rgb": {sensor_name: uint8 array}, "prompt": str}``
            where ``state`` concatenates, per arm (left then right), the
            base-frame end-effector position, its rot6d and that arm's half
            of the joint positions.
        """
        robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
        left_ee_pose_base = robot_obs["ee_pose"]["left_arm"]["base_frame"]
        right_ee_pose_base = robot_obs["ee_pose"]["right_arm"]["base_frame"]
        # use finger joints(44) only: the first 22 entries go with the left
        # arm, the remainder with the right arm.
        finger_positions = robot_obs["joint_positions"]
        state = np.concatenate(
            [
                left_ee_pose_base["position"],
                left_ee_pose_base["rot6d"],
                finger_positions[:22],
                right_ee_pose_base["position"],
                right_ee_pose_base["rot6d"],
                finger_positions[22:],
            ],
            axis=-1,
        )  # (62,)

        rgb_data = {}
        for sensor_name in self.sensor_names:
            sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
            rgb_data[sensor_name] = sensor_obs["rgb"].data.cpu().numpy().astype(np.uint8)
        return {"state": state, "rgb": rgb_data, "prompt": self.prompt}

    def compute_action(self, observation: dict) -> dict:
        """Return the current action chunk, querying the server when needed.

        While a chunk is cached it is returned unchanged. Otherwise the
        observation is pickled and POSTed to ``{server_url}/inference``, the
        state it was computed from is remembered in ``self.current_state``,
        and the unpickled response is cached.

        ``run_trunk_size`` is only ever clamped down to the chunk length the
        server returned; a configured value that already fits is kept.
        """
        if self.current_chunk_result is not None:
            return self.current_chunk_result

        self.current_state.update(self.split_joints(observation["state"]))
        payload = pickle.dumps(observation)
        response = requests.post(
            f"{self.server_url}/inference",
            data=payload,
            headers={"Content-Type": "application/octet-stream"},
        )
        self.test_obs = observation["state"]  # TODO
        self._handle_server_error(response)
        # SECURITY: pickle.loads executes arbitrary code from the payload;
        # only point server_url at a trusted server.
        result = pickle.loads(response.content)

        max_trunk_size = len(result["right_arm"]["ee_delta_position_chunks"])
        if self.run_trunk_size > max_trunk_size:
            Log.warning(f"Run trunk size {self.run_trunk_size} is greater than the number of chunks {max_trunk_size}. Set run trunk size to {max_trunk_size}.")
            self.run_trunk_size = max_trunk_size
        self.current_chunk_result = result
        return result

    def _replay_targets(self, frames: np.ndarray) -> dict:
        """Convert recorded frames into per-arm pose and finger targets.

        ``frames`` is indexed as ``frames[:, columns]``; the column layout is
        given by ``_REPLAY_ARM_OFFSETS``.
        """
        targets = {}
        for arm_key, start in self._REPLAY_ARM_OFFSETS.items():
            rot = R.from_euler('xyz', frames[:, start + 3:start + 6]).as_matrix()
            # NOTE(review): as_matrix() is batched here (N, 3, 3), so
            # rot[:, 0] and rot[:, 1] are the first two ROWS of each matrix,
            # not its first two columns. Confirm this matches the rot6d
            # convention that Pose expects.
            rot6d = np.concatenate([rot[:, 0], rot[:, 1]], axis=-1)
            targets[arm_key] = {
                "ee_position_chunks": frames[:, start:start + 3].tolist(),
                "ee_rot6d_chunks": rot6d.tolist(),
                "finger_chunks": frames[:, start + 6:start + 28].tolist(),
            }
        return targets

    def postprocess_action(self, action: dict) -> BenchmarkAction:
        """Turn the current chunk step into a ``BenchmarkAction``.

        The robot targets come from the replay dataset, one recorded frame
        per call. The model ``action`` is combined with the cached state only
        to obtain the pose passed to the visualisation helper. Once every
        recorded frame has been sent, the process exits with status 0.

        After building the action the chunk index advances; when it reaches
        ``run_trunk_size`` the cached chunk is dropped so that the next step
        requests a new observation and inference.
        """
        benchmark_action = BenchmarkAction()

        read_chunk_size = 1
        # Stop only after the final recorded frame has actually been sent.
        if self.dummy_data_idx >= self.dummy_data.shape[0]:
            raise SystemExit(0)
        dummy_action = self.dummy_data[self.dummy_data_idx:(self.dummy_data_idx + read_chunk_size)]
        self.dummy_data_idx += read_chunk_size
        read_chunk_id = 0
        print(f'{self.current_chunk_id=}, {self.dummy_data_idx = }, {read_chunk_id=}')
        # Debug pacing: one replayed frame per second.
        time.sleep(1.0)
        read_state = self._replay_targets(dummy_action)

        for arm_key in self.robot['arms'].keys():
            action_arm = action[arm_key]
            delta_ee_pose = Pose(
                position=action_arm["ee_delta_position_chunks"][self.current_chunk_id],
                rot6d=action_arm["ee_delta_rot6d_chunks"][self.current_chunk_id],
            )
            curr_state_ee_pose = Pose(
                position=self.current_state[arm_key]["ee_pos"],
                rot6d=self.current_state[arm_key]["ee_rot6d"],
            )
            # action2base = state2base * action2state
            curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose
            joint_names = self.left_hand_joints if arm_key == "left_arm" else self.right_hand_joints
            state_arm = read_state[arm_key]

            # NOTE: the model-driven targets would be
            #   joint_positions = action_arm["finger_chunks"][chunk_id]
            #                     + self.current_state[arm_key]["finger_qpos"]
            #   ee_pose         = curr_action_ee_pose
            # They are currently replaced by the replayed dataset frame.
            benchmark_action.add_robot_action(
                RobotAction(
                    control_mode=ControlMode.POSITION,
                    robot_name=self.robot_name,
                    joint_names=joint_names,
                    joint_positions=state_arm["finger_chunks"][read_chunk_id],
                )
            )
            benchmark_action.add_robot_action(
                RobotAction(
                    control_mode=ControlMode.EE_POSE,
                    robot_name=self.robot_name,
                    ee_pose=Pose(
                        position=state_arm["ee_position_chunks"][read_chunk_id],
                        rot6d=state_arm["ee_rot6d_chunks"][read_chunk_id],
                    ),
                    arm_name=arm_key,
                )
            )
            self._visualize_base_frame_ee_poses(curr_state_ee_pose, curr_action_ee_pose)

        self._visualize_bounding_boxes()

        self.current_chunk_id += 1
        if self.current_chunk_id == self.run_trunk_size:
            self.current_chunk_id = 0
            self.current_chunk_result = None
        return benchmark_action

    # ------------------- Visualization -------------------
    def _visualize_base_frame_ee_poses(
        self, pose_state_base: Pose, pose_action_base: Pose
    ) -> None:
        """Draw the state and/or action end-effector frames, if enabled.

        Both poses are given in the robot base frame and are left-multiplied
        by the robot pose returned by ``get_pose`` before being drawn.
        """
        if not (self.visualize_action_ee_pose or self.visualize_state_ee_pose):
            return
        robot_base_world = SpawnableController.control_robot(
            self.robot_name, "get_pose"
        ).unwrap()
        if self.visualize_state_ee_pose:
            VisualizeController.create_pose_visualization(
                robot_base_world * pose_state_base,
                name=f"{self.robot_name}/starvla_state_ee",
                simulator=SimulatorType.ISAACLAB,
                pose_type=PoseVisualType.COORDINATE,
                extra_params={"axis_length": 0.08, "thickness": 0.006},
            ).unwrap()
        if self.visualize_action_ee_pose:
            VisualizeController.create_pose_visualization(
                robot_base_world * pose_action_base,
                name=f"{self.robot_name}/starvla_action_ee",
                simulator=SimulatorType.ISAACLAB,
                pose_type=PoseVisualType.COORDINATE,
                extra_params={"axis_length": 0.1, "thickness": 0.006},
            ).unwrap()

    def _visualize_bounding_boxes(self) -> None:
        """Draw a bounding box for every configured target (none = no-op)."""
        if not self.visualize_bounding_box_targets:
            return
        for target_name in self.visualize_bounding_box_targets:
            VisualizeController.visualize_target_bounding_box(
                target_name, simulator=SimulatorType.ISAACLAB
            ).unwrap()