import json
import pickle
import time

import numpy as np
import requests
from scipy.spatial.transform import Rotation as R

from fastsim.annotations.config_class import configclass, field
from fastsim.annotations.stereotype import stereotype
from fastsim.controllers.spawnable_controller import SpawnableController
from fastsim.controllers.visualize_controller import VisualizeController
from fastsim.unisim.robots.models.modular_robot import ModularRobot
from fastsim.utils.namespace import PoseVisualType, SimulatorType
from fastsim.unisim.robots.actuator_configs.grippers import GripperDriveJointConfig
from fastsim.extensions.benchmark.action import RobotAction
from fastsim.extensions.benchmark.benchmark import (
    BenchmarkAction,
    BenchmarkObservation,
    ControlMode,
)
from fastsim.extensions.benchmark.policy import Policy, PolicyConfig
from fastsim.utils.log import Log
from fastsim.utils.pose import Pose


@configclass
@stereotype.register_config("starvla")
class StarvlaPolicyConfig(PolicyConfig):
    """Configuration of the ``starvla`` policy (a client of a StarVLA inference server)."""

    robot_name: str = field(default="None", required=True, comment="The name of the robot")
    visualize_action_ee_pose: bool = field(
        default=False, required=True, comment="Whether to visualize the action end effector pose"
    )
    visualize_state_ee_pose: bool = field(
        default=False, required=True, comment="Whether to visualize the state end effector pose"
    )
    visualize_bounding_box_targets: list[str] = field(
        default_factory=list,
        required=False,
        comment="Spawnable object names to draw bounding boxes for (empty = off)",
    )
    # A factory (like visualize_bounding_box_targets above) rather than a list
    # literal, so config instances never share one mutable default list.
    sensor_names: list[str] = field(
        default_factory=lambda: ["Hand_Camera", "Left_Camera", "Right_Camera"],
        required=True,
        comment="The names of the sensors",
    )
    server_url: str = field(
        default="http://127.0.0.1:5000/policy", required=True, comment="StarVLA policy server url"
    )
    prompt: str = field(default="pick the object", required=True, comment="task instruction")
    # "trunk" means "chunk": how many steps of each returned action chunk are
    # executed before the next observation/inference request.
    run_trunk_size: int = field(
        default=16, required=True, comment="The number of chunks to run in one inference step"
    )
@stereotype.register_model("starvla")
class StarvlaPolicy(Policy):
    """Policy that queries a remote StarVLA server for per-arm action chunks.

    One inference request returns, per arm, chunks of end-effector deltas
    (``ee_delta_position_chunks`` / ``ee_delta_rot6d_chunks``) and finger deltas
    (``finger_chunks``). ``run_trunk_size`` consecutive steps of a chunk are
    executed before a new observation is requested (see :meth:`needs_observation`).

    Debug replay (off by default): if the config object carries a non-empty
    ``replay_parquet_path`` attribute, the recorded ``observation.state`` rows of
    episode 0 in that parquet file are commanded instead of the policy output.
    An optional ``replay_step_delay`` attribute (seconds, default 1.0) throttles it.
    """

    # Per-arm width of the state / action vector: ee position (3) + rot6d (6) + finger qpos (22).
    _ARM_DIM = 31
    # Column offsets of each arm inside one recorded replay frame:
    # [position (3), xyz Euler angles (3), finger qpos (22)] per arm.
    _REPLAY_ARM_OFFSETS = {"left_arm": 0, "right_arm": 28}

    def __init__(self, config: StarvlaPolicyConfig):
        super().__init__(config)
        self.robot_name = config.robot_name
        self.sensor_names = config.sensor_names
        self.server_url = config.server_url
        self.prompt = config.prompt
        self.visualize_action_ee_pose = config.visualize_action_ee_pose
        self.visualize_state_ee_pose = config.visualize_state_ee_pose
        self.visualize_bounding_box_targets = list(config.visualize_bounding_box_targets or [])

        # Optional debug replay; nothing is loaded unless the config supplies a path.
        self.replay_parquet_path = getattr(config, "replay_parquet_path", None) or None
        self.replay_step_delay = float(getattr(config, "replay_step_delay", 1.0))
        self.dummy_data = None
        self.dummy_data_idx = 0
        if self.replay_parquet_path:
            self.dummy_data = self._load_replay_states(self.replay_parquet_path)

    @staticmethod
    def _load_replay_states(parquet_path: str, episode_index: int = 0) -> np.ndarray:
        """Return the ``observation.state`` rows of one episode as a 2-D array (frames x dims)."""
        # Kept as a local import: pandas is only needed for the debug replay.
        import pandas as pd

        df_data = pd.read_parquet(parquet_path)
        episodes = df_data.groupby("episode_index")["observation.state"].apply(list).to_dict()
        return np.array(episodes[episode_index])

    def reset(self) -> None:
        """Clear chunk state and re-query the robot handle and hand joint names."""
        self.current_state = {}
        self.current_chunk_id = 0
        self.current_chunk_result = None
        self.run_trunk_size = self.config.run_trunk_size
        self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap()
        self.left_hand_joints = SpawnableController.control_robot(
            self.robot_name,
            "get_actuator_joint_names",
            parameters={"actuator_name": "left_hand"},
        ).unwrap()
        self.right_hand_joints = SpawnableController.control_robot(
            self.robot_name,
            "get_actuator_joint_names",
            parameters={"actuator_name": "right_hand"},
        ).unwrap()

    def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
        """Block until ``{server_url}/health`` answers HTTP 200, retrying about once per second."""
        Log.info("Waiting for StarVLA inference server to be ready...")
        while True:
            try:
                if requests.get(f"{self.server_url}/health", timeout=1.0).status_code == 200:
                    break
            except requests.RequestException:
                pass
            # Back off after every failed attempt, including non-200 replies,
            # so an unhealthy server is not hammered in a busy loop.
            time.sleep(1)
        Log.success("StarVLA inference server is ready.")

    def needs_observation(self) -> bool:
        """A fresh observation is needed only at the start of a chunk."""
        return self.current_chunk_id == 0

    def _handle_server_error(self, response: requests.Response) -> None:
        """Log (with ``exit=True``) any non-200 inference response.

        A 500 body is expected to be a pickled ``{"error", "traceback"}`` dict;
        if it cannot be decoded, the raw text is logged instead.
        """
        if response.status_code == 500:
            # NOTE(security): unpickling a network response can execute arbitrary
            # code -- only point server_url at a trusted server.
            try:
                err_obj = pickle.loads(response.content)
                error, traceback_text = err_obj["error"], err_obj["traceback"]
            except Exception:
                Log.error(
                    f"StarVLA server error with status code <{response.status_code}> : {response.text}",
                    exit=True,
                )
                return
            Log.error(f"StarVLA server error: {error}")
            Log.error(f"Traceback: {traceback_text}", exit=True)
        elif response.status_code != 200:
            Log.error(
                f"StarVLA server error with status code <{response.status_code}> : {response.text}",
                exit=True,
            )

    def split_joints(self, state_or_action, keys=None) -> dict[str, dict[str, np.ndarray]]:
        """Split a concatenated per-arm vector into ``ee_pos``/``ee_rot6d``/``finger_qpos``.

        Args:
            state_or_action: array whose last axis holds ``len(keys)`` blocks of
                ``_ARM_DIM`` values, each laid out as position (3), rot6d (6), fingers (22).
            keys: names of the blocks in order; defaults to ``["left_arm", "right_arm"]``.

        Returns:
            ``{key: {"ee_pos", "ee_rot6d", "finger_qpos"}}`` with array slices.
        """
        if keys is None:
            keys = ["left_arm", "right_arm"]
        total_dim = self._ARM_DIM * len(keys)
        assert state_or_action.shape[-1] == total_dim, f"Expected last dimension to be {total_dim}, got {state_or_action.shape[-1]}"
        if not keys:
            return {}
        return_dict = {}
        # Equal split into one block per key (works for any number of arms).
        for key, joints in zip(keys, np.split(state_or_action, len(keys), axis=-1)):
            ee_pos, ee_rot6d, finger_qpos = np.split(joints, [3, 9], axis=-1)
            return_dict[key] = {"ee_pos": ee_pos, "ee_rot6d": ee_rot6d, "finger_qpos": finger_qpos}
        return return_dict

    def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
        """Build the server payload: proprio ``state``, per-sensor ``rgb`` arrays and the ``prompt``.

        ``state`` is [left pos, left rot6d, joint_positions[:22], right pos,
        right rot6d, joint_positions[22:]] with poses in the robot base frame.
        """
        robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
        left_ee_pose_base = robot_obs["ee_pose"]["left_arm"]["base_frame"]
        left_ee_position, left_ee_rot6d = left_ee_pose_base["position"], left_ee_pose_base["rot6d"]
        right_ee_pose_base = robot_obs["ee_pose"]["right_arm"]["base_frame"]
        right_ee_position, right_ee_rot6d = right_ee_pose_base["position"], right_ee_pose_base["rot6d"]
        # NOTE(review): assumes joint_positions holds exactly the 44 finger joints
        # (22 per hand), giving a (62,) state -- confirm against the robot model.
        finger_positions = robot_obs["joint_positions"]
        state = np.concatenate(
            [
                left_ee_position, left_ee_rot6d, finger_positions[:22],
                right_ee_position, right_ee_rot6d, finger_positions[22:],
            ],
            axis=-1,
        )
        rgb_data = {}
        for sensor_name in self.sensor_names:
            sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
            rgb_data[sensor_name] = sensor_obs["rgb"].data.cpu().numpy().astype(np.uint8)
        return {"state": state, "rgb": rgb_data, "prompt": self.prompt}

    def compute_action(self, observation: dict) -> dict:
        """Return the current action chunk, querying the server only when none is cached."""
        if self.current_chunk_result is not None:
            return self.current_chunk_result

        self.current_state.update(self.split_joints(observation["state"]))
        payload = pickle.dumps(observation)
        response = requests.post(
            f"{self.server_url}/inference",
            data=payload,
            headers={"Content-Type": "application/octet-stream"},
        )
        self.test_obs = observation["state"]  # TODO: debug leftover
        self._handle_server_error(response)
        # NOTE(security): unpickling a server response; the server must be trusted.
        result = pickle.loads(response.content)
        max_trunk_size = len(result["right_arm"]["ee_delta_position_chunks"])
        if self.run_trunk_size > max_trunk_size:
            Log.warning(f"Run trunk size {self.run_trunk_size} is greater than the number of chunks {max_trunk_size}. Set run trunk size to {max_trunk_size}.")
            self.run_trunk_size = max_trunk_size
        self.current_chunk_result = result
        return result

    @staticmethod
    def _euler_xyz_to_rot6d(rpy) -> np.ndarray:
        """Convert scipy ``"xyz"`` Euler angles to rot6d (first two matrix columns, concatenated).

        ``...`` indexing keeps the column selection correct for both a single
        angle triple and a batch of them.
        """
        # NOTE(review): column-major rot6d is assumed to match Pose(rot6d=...) -- confirm.
        rot = R.from_euler("xyz", rpy).as_matrix()
        return np.concatenate([rot[..., :, 0], rot[..., :, 1]], axis=-1)

    def _next_replay_state(self) -> dict:
        """Return the next recorded frame as per-arm targets; exit once the episode is exhausted."""
        if self.dummy_data_idx >= self.dummy_data.shape[0]:
            Log.info("StarVLA replay finished.")
            raise SystemExit(0)
        frame = self.dummy_data[self.dummy_data_idx]
        self.dummy_data_idx += 1
        if self.replay_step_delay > 0:
            time.sleep(self.replay_step_delay)
        return {
            arm: {
                "ee_pos": frame[off:off + 3].tolist(),
                "ee_rot6d": self._euler_xyz_to_rot6d(frame[off + 3:off + 6]).tolist(),
                "finger_qpos": frame[off + 6:off + 28].tolist(),
            }
            for arm, off in self._REPLAY_ARM_OFFSETS.items()
        }

    def postprocess_action(self, action: dict) -> BenchmarkAction:
        """Turn the current chunk step into finger-position and EE-pose robot actions.

        The EE target is ``state_pose * delta_pose`` (both in the base frame) and
        the finger target is ``finger delta + current finger qpos``; in debug
        replay mode the recorded frame is commanded instead.
        """
        benchmark_action = BenchmarkAction()
        replay_state = self._next_replay_state() if self.dummy_data is not None else None
        step = self.current_chunk_id

        for arm_key in self.robot['arms'].keys():
            action_arm = action[arm_key]
            arm_state = self.current_state[arm_key]
            delta_ee_pose = Pose(
                position=action_arm["ee_delta_position_chunks"][step],
                rot6d=action_arm["ee_delta_rot6d_chunks"][step],
            )
            curr_state_ee_pose = Pose(position=arm_state["ee_pos"], rot6d=arm_state["ee_rot6d"])
            curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose  # action2base = state2base * action2state
            finger_joint_qpos = action_arm["finger_chunks"][step] + arm_state["finger_qpos"]

            if replay_state is None:
                target_fingers = finger_joint_qpos
                target_ee_pose = curr_action_ee_pose
            else:
                recorded = replay_state[arm_key]
                target_fingers = recorded["finger_qpos"]
                target_ee_pose = Pose(position=recorded["ee_pos"], rot6d=recorded["ee_rot6d"])

            joint_names = self.left_hand_joints if arm_key == "left_arm" else self.right_hand_joints
            benchmark_action.add_robot_action(
                RobotAction(
                    control_mode=ControlMode.POSITION,
                    robot_name=self.robot_name,
                    joint_names=joint_names,
                    joint_positions=target_fingers,
                )
            )
            benchmark_action.add_robot_action(
                RobotAction(
                    control_mode=ControlMode.EE_POSE,
                    robot_name=self.robot_name,
                    ee_pose=target_ee_pose,
                    arm_name=arm_key,
                )
            )
            self._visualize_base_frame_ee_poses(curr_state_ee_pose, curr_action_ee_pose, arm_name=arm_key)

        self._visualize_bounding_boxes()

        self.current_chunk_id += 1
        # ">=" so a non-positive run_trunk_size still falls back to one step per inference.
        if self.current_chunk_id >= self.run_trunk_size:
            self.current_chunk_id = 0
            self.current_chunk_result = None
        return benchmark_action

    # ------------------- Visualization -------------------
    def _visualize_base_frame_ee_poses(
        self, pose_state_base: Pose, pose_action_base: Pose, arm_name: str = ""
    ) -> None:
        """Draw base-frame state/action EE poses in the world frame (if enabled).

        ``arm_name`` (optional) is added to the marker names so each arm gets its
        own markers instead of overwriting a shared one.
        """
        if not (self.visualize_action_ee_pose or self.visualize_state_ee_pose):
            return
        prefix = f"{self.robot_name}/{arm_name}" if arm_name else self.robot_name
        robot_base_world = SpawnableController.control_robot(self.robot_name, "get_pose").unwrap()
        if self.visualize_state_ee_pose:
            VisualizeController.create_pose_visualization(
                robot_base_world * pose_state_base,
                name=f"{prefix}/starvla_state_ee",
                simulator=SimulatorType.ISAACLAB,
                pose_type=PoseVisualType.COORDINATE,
                extra_params={"axis_length": 0.08, "thickness": 0.006},
            ).unwrap()
        if self.visualize_action_ee_pose:
            VisualizeController.create_pose_visualization(
                robot_base_world * pose_action_base,
                name=f"{prefix}/starvla_action_ee",
                simulator=SimulatorType.ISAACLAB,
                pose_type=PoseVisualType.COORDINATE,
                extra_params={"axis_length": 0.1, "thickness": 0.006},
            ).unwrap()

    def _visualize_bounding_boxes(self) -> None:
        """Draw a bounding box for every configured target object (no-op when the list is empty)."""
        for target_name in self.visualize_bounding_box_targets:
            VisualizeController.visualize_target_bounding_box(
                target_name, simulator=SimulatorType.ISAACLAB
            ).unwrap()