"""StarVLA policy client for the JoySim benchmark extension.

Observations are pickled and POSTed to a StarVLA policy server; the reply is
unpickled and its first end-effector delta is converted into joint position
targets through inverse kinematics.
"""

import pickle
import sys

from joysim.annotations.config_class import configclass, field
from joysim.annotations.stereotype import stereotype
from joysim.app import JoySim
from joysim.controllers.motion_plan_controller import MotionPlanController
from joysim.core.scene_manager import SceneManager
from joysim.extensions.benchmark.action import RobotAction
from joysim.extensions.benchmark.benchmark import (
    BenchmarkAction,
    BenchmarkObservation,
    ControlMode,
)
from joysim.extensions.benchmark.policy import Policy, PolicyConfig
from joysim.utils.log import Log
from joysim.utils.pose import Pose
import numpy as np
import requests


# Configuration for StarvlaPolicy: which robot and sensors to read, where the
# policy server lives, and the task instruction sent with every observation.
@configclass
@stereotype.register_config("starvla")
class StarvlaPolicyConfig(PolicyConfig):
    robot_name: str = field(default="my_robot", required=True, comment="The name of the robot")
    sensor_names: list[str] = field(
        default=["Hand_Camera", "Left_Camera", "Right_Camera"], required=True, comment="The names of the sensors"
    )
    server_url: str = field(
        default="http://127.0.0.1:5000/policy", required=True, comment="StarVLA policy server url"
    )
    prompt: str = field(
        default="pick the object", required=True, comment="task instruction"
    )


@stereotype.register_model("starvla")
class StarvlaPolicy(Policy):
    """Policy that delegates action inference to a StarVLA HTTP server.

    The latest end-effector position, Euler angles and gripper flag are cached
    by ``preprocess_observation`` so that ``postprocess_action`` can turn the
    deltas returned by the server into an absolute target pose.
    """

    def __init__(self, config: StarvlaPolicyConfig):
        """Copy the relevant settings from ``config`` onto the instance."""
        super().__init__(config)
        self.robot_name = config.robot_name
        self.sensor_names = config.sensor_names
        self.server_url = config.server_url
        self.prompt = config.prompt
        # Created here as well as in reset() so the attributes exist even when
        # the policy is used before the first reset().
        self.current_ee_position_state = None
        self.current_ee_euler_xyz_state = None
        self.current_gripper_state = None

    def reset(self) -> None:
        """Forget the cached end-effector and gripper state."""
        self.current_ee_position_state = None
        self.current_ee_euler_xyz_state = None
        self.current_gripper_state = None

    def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
        """No warm-up is performed by this policy."""
        pass

    def needs_observation(self) -> bool:
        """Return ``True``: every action is computed from a fresh observation."""
        return True

    def _handle_server_error(self, response: requests.Response) -> None:
        """Log the failure and terminate the process on any non-200 response.

        A 500 response is expected to carry a pickled mapping with ``error``
        and ``traceback`` entries. When the body cannot be decoded that way,
        the raw response text is logged instead so the failure stays visible.

        Args:
            response: The HTTP response returned by the policy server.

        Raises:
            SystemExit: With status 0 whenever the status code is not 200.
        """
        if response.status_code == 200:
            return

        generic_message = f"StarVLA server error with status code <{response.status_code}> : {response.text}"
        if response.status_code == 500:
            try:
                # SECURITY: pickle.loads can execute arbitrary code; only point
                # server_url at a trusted StarVLA server.
                err_obj = pickle.loads(response.content)
                error = err_obj["error"]
                traceback_text = err_obj["traceback"]
            except Exception:
                # Unpickling may raise many different exception types, and the
                # payload may lack the expected keys. The process exits right
                # below, so fall back to the raw body rather than masking the
                # server failure with a decoding error.
                Log.error(generic_message)
            else:
                Log.error(f"StarVLA server error: {error}")
                Log.error(f"Traceback: {traceback_text}")
        else:
            Log.error(generic_message)

        # sys.exit instead of the site-provided exit() helper, which is meant
        # for interactive sessions and is not guaranteed to exist.
        # NOTE(review): status 0 is kept for backward compatibility even though
        # this is an error path - confirm whether a non-zero status is wanted.
        sys.exit(0)

    def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
        """Build the request payload from a benchmark observation.

        Also caches the current end-effector position, Euler angles and
        gripper flag for use by ``postprocess_action``.

        Args:
            benchmark_observation: Observation providing the robot and sensor data.

        Returns:
            A dict with ``state`` (position, Euler angles and gripper flag
            concatenated), ``rgb`` (one uint8 array per configured sensor) and
            ``prompt`` (the configured task instruction).
        """
        robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
        ee_pose_base = robot_obs["ee_pose_base"]
        ee_position = ee_pose_base["position"]
        ee_euler_xyz = ee_pose_base["euler_xyz"]
        # The gripper is encoded as 1.0 when opened and 0.0 otherwise.
        gripper = 1.0 if robot_obs["gripper_state"]["opened"] else 0.0
        state = np.concatenate([ee_position, ee_euler_xyz, np.array([gripper])])

        self.current_ee_position_state = np.array(ee_position).astype(np.float64)
        self.current_ee_euler_xyz_state = np.array(ee_euler_xyz).astype(np.float64)
        self.current_gripper_state = np.array([gripper])

        rgb_data = {}
        for sensor_name in self.sensor_names:
            sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
            rgb_data[sensor_name] = sensor_obs["rgb"].data.cpu().numpy().astype(np.uint8)

        return {"state": state, "rgb": rgb_data, "prompt": self.prompt}

    def compute_action(self, observation: dict) -> dict:
        """Send the observation to the policy server and return its reply.

        Args:
            observation: Payload produced by ``preprocess_observation``.

        Returns:
            The unpickled response body.

        Raises:
            SystemExit: If the server answers with a non-200 status code.
        """
        payload = pickle.dumps(observation)
        # NOTE(review): no timeout is passed, so an unresponsive server blocks
        # this call indefinitely - consider a configurable timeout.
        response = requests.post(
            self.server_url,
            data=payload,
            headers={"Content-Type": "application/octet-stream"},
        )
        self._handle_server_error(response)
        # SECURITY: pickle.loads can execute arbitrary code; the server behind
        # server_url must be trusted.
        return pickle.loads(response.content)

    def postprocess_action(self, action: dict) -> BenchmarkAction:
        """Convert the server reply into a joint-position benchmark action.

        Only the first entry of each delta chunk is used. It is added to the
        cached end-effector state and the resulting pose is solved with IK.

        Args:
            action: Reply from ``compute_action`` containing
                ``ee_delta_position_chunks`` and ``ee_delta_euler_xyz_chunks``.

        Returns:
            A ``BenchmarkAction`` holding one position-mode ``RobotAction``,
            or an empty ``BenchmarkAction`` when IK fails.
        """
        benchmark_action = BenchmarkAction()

        # get base frame end-effector pose
        # TODO: Make sure add or multiply the current state
        ee_position = action["ee_delta_position_chunks"][0] + self.current_ee_position_state
        ee_euler_xyz = action["ee_delta_euler_xyz_chunks"][0] + self.current_ee_euler_xyz_state
        ee_pose = Pose(position=ee_position, euler_xyz=ee_euler_xyz)

        ik_result = MotionPlanController.solve_ik(
            robot_name=self.robot_name,
            base_frame_ee_pose=ee_pose,
        ).unwrap()
        if not ik_result["success"]:
            Log.error(f"IK failed: {ik_result['status']}. Ignore this action.")
            return benchmark_action

        joint_names = ik_result["result"]["plannable_joint_names"]
        joint_positions = ik_result["result"]["plannable_joint_positions"][0]
        # NOTE(review): no gripper command is added to the action here -
        # confirm whether that is intended.
        benchmark_action.add_robot_action(
            RobotAction(
                control_mode=ControlMode.POSITION,
                robot_name=self.robot_name,
                joint_names=joint_names,
                joint_positions=joint_positions,
            )
        )
        return benchmark_action


if __name__ == "__main__":
    js = JoySim("/home/ubuntu/projects/benchmark/benchmark.yaml")
    js.start()