"""StarVLA policy: bridges joysim benchmark observations/actions and a remote StarVLA server."""

import json
import pickle
import time

import numpy as np
import requests

from joysim.annotations.config_class import configclass, field
from joysim.annotations.stereotype import stereotype
from joysim.controllers.spawnable_controller import SpawnableController
from joysim.controllers.visualize_controller import VisualizeController
from joysim.unisim.robots.models.modular_robot import ModularRobot
from joysim.utils.namespace import PoseVisualType, SimulatorType
from joysim.unisim.robots.actuator_configs.grippers import GripperDriveJointConfig
from joysim.extensions.benchmark.action import RobotAction
from joysim.extensions.benchmark.benchmark import (
    BenchmarkAction,
    BenchmarkObservation,
    ControlMode,
)
from joysim.extensions.benchmark.policy import Policy, PolicyConfig
from joysim.utils.log import Log
from joysim.utils.pose import Pose

# Effort limit applied to every gripper drive joint in ``reset``.
_DRIVE_JOINT_EFFORT_LIMIT = 5000
# Drive joint positions are clamped to [0, this value] before the table lookup.
_MAX_DRIVE_JOINT_POSITION = 0.8
# Normalized gripper widths strictly above this value command the close position.
_GRIPPER_CLOSE_THRESHOLD = 0.5


@configclass
@stereotype.register_config("starvla")
class StarvlaPolicyConfig(PolicyConfig):
    """Configuration of :class:`StarvlaPolicy`."""

    robot_name: str = field(default="None", required=True, comment="The name of the robot")
    arm_name: str = field(default="main_arm", required=True, comment="The name of the arm module to control")
    drive_name: str = field(
        default="robotiq_85_left_knuckle_joint",
        required=True,
        comment="The name of the drive module to control",
    )
    gripper_width_mapper_file: str = field(
        default="", required=True, comment="The file path to the gripper width mapper"
    )
    visualize_action_ee_pose: bool = field(
        default=False, required=True, comment="Whether to visualize the action end effector pose"
    )
    visualize_state_ee_pose: bool = field(
        default=False, required=True, comment="Whether to visualize the state end effector pose"
    )
    visualize_bounding_box_targets: list[str] = field(
        default_factory=list,
        required=False,
        comment="Spawnable object names to draw bounding boxes for (empty = off)",
    )
    sensor_names: list[str] = field(
        default=["Hand_Camera", "Left_Camera", "Right_Camera"], required=True, comment="The names of the sensors"
    )
    server_url: str = field(
        default="http://127.0.0.1:5000/policy", required=True, comment="StarVLA policy server url"
    )
    prompt: str = field(
        default="pick the object", required=True, comment="task instruction"
    )
    run_trunk_size: int = field(
        default=16, required=True, comment="The number of chunks to run in one inference step"
    )


@stereotype.register_model("starvla")
class StarvlaPolicy(Policy):
    """Policy that queries a StarVLA server for action chunks and replays them step by step."""

    def __init__(self, config: StarvlaPolicyConfig):
        """Copy the settings from ``config`` and load the gripper width mapper JSON file."""
        super().__init__(config)
        self.robot_name = config.robot_name
        self.arm_name = config.arm_name
        self.drive_name = config.drive_name
        # Copy so later mutation of the policy's list cannot leak into the config object.
        self.sensor_names = list(config.sensor_names)
        self.server_url = config.server_url
        self.prompt = config.prompt
        # Context manager guarantees the file handle is closed after loading.
        with open(config.gripper_width_mapper_file, "r", encoding="utf-8") as mapper_file:
            # The rest of the class iterates this as a sequence of entries
            # carrying "angel" and "width" keys.
            self.gripper_width_mapper = json.load(mapper_file)
        self.visualize_action_ee_pose = config.visualize_action_ee_pose
        self.visualize_state_ee_pose = config.visualize_state_ee_pose
        self.visualize_bounding_box_targets = list(config.visualize_bounding_box_targets or [])

    def reset(self) -> None:
        """Clear per-episode state, re-resolve the robot and configure its gripper drive joints."""
        self.current_ee_position_state = None
        self.current_ee_rot6d_state = None
        self.current_gripper_width = None
        self.current_chunk_id = 0
        self.current_chunk_result = None
        # Last state vector sent to the server; only used for debug logging.
        self.test_obs = None
        self.run_trunk_size = self.config.run_trunk_size

        self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap()
        self.drive_joints: dict[str, GripperDriveJointConfig] = (
            self.robot.get_arm(self.arm_name).get_ee().get_drive_joints()
        )
        self.robot_drive_name = list(self.drive_joints.keys())[0]

        for joint_name, joint_config in self.drive_joints.items():
            SpawnableController.control_robot(
                self.robot_name,
                "set_joint_stiffness",
                parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness},
            ).unwrap()
            SpawnableController.control_robot(
                self.robot_name,
                "set_joint_damping",
                parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping},
            ).unwrap()
            SpawnableController.control_robot(
                self.robot_name,
                "set_joint_effort_limit",
                parameters={"joint_names": [joint_name], "effort_limit": _DRIVE_JOINT_EFFORT_LIMIT},
            ).unwrap()
        SpawnableController.control_robot(
            self.robot_name,
            "set_joint_effort_limit",
            parameters={"joint_names": [self.robot_drive_name], "effort_limit": _DRIVE_JOINT_EFFORT_LIMIT},
        ).unwrap()

        # Width range of the mapper table, used to normalize widths.
        widths = [entry["width"] for entry in self.gripper_width_mapper]
        self.max_width = max(widths, default=float("-inf"))
        self.min_width = min(widths, default=float("inf"))

    def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
        """Block until ``GET {server_url}/health`` answers with HTTP 200."""
        Log.info("Waiting for StarVLA inference server to be ready...")
        health_url = f"{self.server_url}/health"
        while True:
            try:
                if requests.get(health_url, timeout=1.0).status_code == 200:
                    break
            except requests.RequestException:
                # Server not reachable yet; fall through to the retry delay.
                pass
            # Sleep after every failed probe (including non-200 replies) so the
            # server is never polled in a tight loop.
            time.sleep(1)
        Log.success("StarVLA inference server is ready.")

    def needs_observation(self) -> bool:
        """Return True when the next step starts a new chunk (chunk index is 0)."""
        return self.current_chunk_id == 0

    def _handle_server_error(self, response: requests.Response) -> None:
        """Log the server's error payload (with ``exit=True``) for any non-200 response."""
        if response.status_code == 500:
            # SECURITY: pickle.loads executes arbitrary code from the payload;
            # only point server_url at a trusted StarVLA server.
            err_obj = pickle.loads(response.content)
            Log.error(f"StarVLA server error: {err_obj['error']}")
            Log.error(f"Traceback: {err_obj['traceback']}", exit=True)
        elif response.status_code != 200:
            Log.error(
                f"StarVLA server error with status code <{response.status_code}> : {response.text}",
                exit=True,
            )

    def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
        """Build the server request from robot state and camera images.

        Returns:
            A dict with keys ``"state"`` (concatenation of EE position, EE
            rot6d, normalized gripper width, ten zeros and the first seven
            joint positions), ``"rgb"`` (uint8 image per sensor name) and
            ``"prompt"``.
        """
        robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
        ee_pose_base = robot_obs["ee_pose"][self.arm_name]["base_frame"]
        ee_position = ee_pose_base["position"]
        ee_rot6d = ee_pose_base["rot6d"]
        arm_joint_positions = robot_obs["joint_positions"][:7]
        # A drive position was temporarily appended to the joint vector, so the
        # last joint value read here is the gripper drive.
        drive_joint_positions = robot_obs["joint_positions"][-1]
        normalized_gripper_width = self.__map_joint_position_to_normalized_width(drive_joint_positions)
        Log.debug(f"input normalized_gripper_width state: {round(normalized_gripper_width, 2)}")

        state = np.concatenate(
            [
                ee_position,
                ee_rot6d,
                np.array([normalized_gripper_width]),
                [0] * 10,
                np.array(arm_joint_positions),
            ]
        )

        rgb_data = {}
        for sensor_name in self.sensor_names:
            sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
            rgb_data[sensor_name] = sensor_obs["rgb"].data.cpu().numpy().astype(np.uint8)

        return {"state": state, "rgb": rgb_data, "prompt": self.prompt}

    def compute_action(self, observation: dict) -> dict:
        """Return the current action chunk, querying the server when none is cached.

        On a fresh query the EE position (``state[:3]``), rot6d (``state[3:9]``)
        and gripper width (``state[9]``) are stored as the reference state that
        ``postprocess_action`` applies the chunk deltas to.
        """
        if self.current_chunk_result is not None:
            return self.current_chunk_result

        state = observation["state"]
        self.current_ee_position_state = np.array(state[:3]).astype(np.float64)
        self.current_ee_rot6d_state = np.array(state[3:9]).astype(np.float64)
        self.current_gripper_width = np.array([state[9]])

        response = requests.post(
            f"{self.server_url}/inference",
            data=pickle.dumps(observation),
            headers={"Content-Type": "application/octet-stream"},
        )
        self.test_obs = state  # TODO: debug-only copy of the last sent state
        self._handle_server_error(response)
        # SECURITY: pickle.loads on network data; server_url must be trusted.
        result = pickle.loads(response.content)

        max_trunk_size = len(result["ee_delta_position_chunks"])
        # Clamp only when the configured size exceeds what the server returned,
        # so a smaller configured run_trunk_size is honoured.
        if self.run_trunk_size > max_trunk_size:
            Log.warning(
                f"Run trunk size {self.run_trunk_size} is greater than the number of chunks {max_trunk_size}. "
                f"Set run trunk size to {max_trunk_size}."
            )
            self.run_trunk_size = max_trunk_size

        self.current_chunk_result = result
        return result

    def __map_joint_position_to_normalized_width(self, joint_position: float) -> float:
        """Map a drive joint position to ``1 - (width - min_width) / (max_width - min_width)``.

        The position is clamped to ``[0, 0.8]``. The first mapper entry whose
        angle matches at two decimals is used; if none matches, the entry with
        the closest angle is used instead of returning ``None``.

        Raises:
            ValueError: If the gripper width mapper has no entries.
        """
        if not self.gripper_width_mapper:
            raise ValueError("Gripper width mapper is empty; cannot map joint position to width.")

        if joint_position < 0:
            joint_position = 0
        if joint_position > _MAX_DRIVE_JOINT_POSITION:
            joint_position = _MAX_DRIVE_JOINT_POSITION

        target_angle = round(joint_position, 2)
        # "angel" (sic) is the key used by the mapper file.
        matched = next(
            (entry for entry in self.gripper_width_mapper if round(entry["angel"], 2) == target_angle),
            None,
        )
        if matched is None:
            # The table has no row for this angle: use the nearest one.
            matched = min(self.gripper_width_mapper, key=lambda entry: abs(entry["angel"] - joint_position))

        return 1 - (matched["width"] - self.min_width) / (self.max_width - self.min_width)

    def __map_gripper_joint_position(self, normalized_gripper_width: float) -> tuple[list[float], list[str]]:
        """Return ``(joint_positions, joint_names)`` for a binary open/close gripper command.

        Widths above 0.5 select every drive joint's ``close_position``;
        otherwise its ``open_position`` is used.
        """
        should_close = normalized_gripper_width > _GRIPPER_CLOSE_THRESHOLD
        joint_names = list(self.drive_joints)
        joint_positions = [
            joint_config.close_position if should_close else joint_config.open_position
            for joint_config in self.drive_joints.values()
        ]
        return joint_positions, joint_names

    def postprocess_action(self, action: dict) -> BenchmarkAction:
        """Turn the current step of the cached chunk into gripper and EE-pose robot actions.

        Advances the chunk index and clears the cached chunk once
        ``run_trunk_size`` steps have been emitted, so the next call to
        ``compute_action`` queries the server again.
        """
        benchmark_action = BenchmarkAction()
        Log.debug(f"observation: {self.test_obs}")
        step = self.current_chunk_id

        # Base-frame end-effector pose: action2base = state2base * action2state.
        delta_ee_pose = Pose(
            position=action["ee_delta_position_chunks"][step],
            rot6d=action["ee_delta_rot6d_chunks"][step],
        )
        curr_state_ee_pose = Pose(position=self.current_ee_position_state, rot6d=self.current_ee_rot6d_state)
        curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose

        curr_action_gripper_width = action["gripper_width_chunks"][step]
        gripper_joint_positions, gripper_joint_names = self.__map_gripper_joint_position(
            curr_action_gripper_width[0]
        )
        Log.debug(f"action_gripper_joint_positions: {gripper_joint_positions}, action_normalized_gripper_width: {round(curr_action_gripper_width[0], 2)}")

        benchmark_action.add_robot_action(
            RobotAction(
                control_mode=ControlMode.POSITION,
                robot_name=self.robot_name,
                joint_names=gripper_joint_names,
                joint_positions=gripper_joint_positions,
            )
        )
        benchmark_action.add_robot_action(
            RobotAction(
                control_mode=ControlMode.EE_POSE,
                robot_name=self.robot_name,
                ee_pose=curr_action_ee_pose,
            )
        )

        self._visualize_base_frame_ee_poses(curr_state_ee_pose, curr_action_ee_pose)
        self._visualize_bounding_boxes()

        self.current_chunk_id += 1
        # ">=" rather than "==" so the index can never run past the chunk size.
        if self.current_chunk_id >= self.run_trunk_size:
            self.current_chunk_id = 0
            self.current_chunk_result = None
        return benchmark_action

    # ------------------- Visualization -------------------
    def _visualize_base_frame_ee_poses(
        self, pose_state_base: Pose, pose_action_base: Pose
    ) -> None:
        """Draw the state and/or action EE poses, composed with the robot pose from ``get_pose``."""
        if not (self.visualize_action_ee_pose or self.visualize_state_ee_pose):
            return

        robot_base_world = SpawnableController.control_robot(
            self.robot_name, "get_pose"
        ).unwrap()

        if self.visualize_state_ee_pose:
            VisualizeController.create_pose_visualization(
                robot_base_world * pose_state_base,
                name=f"{self.robot_name}/starvla_state_ee",
                simulator=SimulatorType.ISAACLAB,
                pose_type=PoseVisualType.COORDINATE,
                extra_params={"axis_length": 0.08, "thickness": 0.006},
            ).unwrap()

        if self.visualize_action_ee_pose:
            VisualizeController.create_pose_visualization(
                robot_base_world * pose_action_base,
                name=f"{self.robot_name}/starvla_action_ee",
                simulator=SimulatorType.ISAACLAB,
                pose_type=PoseVisualType.COORDINATE,
                extra_params={"axis_length": 0.1, "thickness": 0.006},
            ).unwrap()

    def _visualize_bounding_boxes(self) -> None:
        """Draw a bounding box for every configured target name (no-op when the list is empty)."""
        for target_name in self.visualize_bounding_box_targets:
            VisualizeController.visualize_target_bounding_box(
                target_name, simulator=SimulatorType.ISAACLAB
            ).unwrap()