"""StarVLA policy adapter for the JoySim benchmark.

Each step, the robot/camera observation is pickled and POSTed to a remote
StarVLA policy server; the server's pickled reply is turned into a
joint-position ``BenchmarkAction`` for the configured robot.
"""

import pickle

import numpy as np
import requests

from joysim.annotations.config_class import configclass, field
from joysim.annotations.stereotype import stereotype
from joysim.app import JoySim
from joysim.core.scene_manager import SceneManager
from joysim.extensions.benchmark.action import RobotAction
from joysim.extensions.benchmark.benchmark import (
    BenchmarkAction,
    BenchmarkObservation,
    ControlMode,
)
from joysim.extensions.benchmark.policy import Policy, PolicyConfig


@configclass
@stereotype.register_config("starvla")
class StarvlaPolicyConfig(PolicyConfig):
    """Configuration for :class:`StarvlaPolicy`, registered under ``"starvla"``."""

    # Scene name of the robot whose observations are sent and which receives actions.
    robot_name: str = field(default="my_robot", required=True, comment="The name of the robot")
    # Name of the task object; stored on the policy but not read by this file.
    object_name: str = field(default="target", required=True, comment="The name of the object")
    # HTTP endpoint that accepts a pickled observation dict and returns a pickled result.
    server_url: str = field(
        default="http://127.0.0.1:5000/policy", required=True, comment="StarVLA policy server url"
    )
    # Natural-language instruction forwarded to the server with every observation.
    prompt: str = field(
        default="pick the object", required=True, comment="task instruction"
    )


@stereotype.register_model("starvla")
class StarvlaPolicy(Policy):
    """Policy that delegates inference to a remote StarVLA HTTP server.

    The benchmark drives it through ``preprocess_observation`` ->
    ``compute_action`` -> ``postprocess_action``.
    """

    # (connect, read) timeout in seconds for calls to the policy server.
    # Without a timeout, ``requests`` blocks forever on an unresponsive server
    # and stalls the whole simulation. The read timeout is generous because the
    # first inference (warmup is a no-op here) may be slow. Override on a
    # subclass or instance if needed.
    REQUEST_TIMEOUT_S = (10.0, 120.0)

    def __init__(self, config: StarvlaPolicyConfig):
        """Copy the relevant config values and open a reusable HTTP session."""
        super().__init__(config)
        self.robot_name = config.robot_name
        self.object_name = config.object_name
        self.server_url = config.server_url
        self.prompt = config.prompt
        # One session per policy so keep-alive connections are reused across
        # steps instead of opening a new TCP connection per action.
        self._session = requests.Session()

    def reset(self) -> None:
        """No per-episode state is kept client-side, so there is nothing to reset."""
        pass

    def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
        """No client-side warmup; the first real request reaches the server."""
        pass

    def needs_observation(self) -> bool:
        """Always request an observation; every action is computed from one."""
        return True

    def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
        """Build the request dict sent to the StarVLA server.

        Every array gets a leading batch axis of size 1 via ``np.expand_dims``.

        Returns:
            dict with keys ``"state"``, ``"joint"``, ``"rgb"`` (batched arrays)
            and ``"prompt"`` (the configured instruction string).
        """
        robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
        joint_positions = robot_obs["joint_positions"]
        robot_position = robot_obs["position"]
        robot_quaternion = robot_obs["rotation"]
        # State = base position ++ base rotation ++ one trailing 0.0.
        # NOTE(review): the trailing 0.0 is presumably a placeholder slot (e.g.
        # gripper state) expected by the server's state layout — TODO confirm.
        state = np.concatenate([
            robot_position,
            robot_quaternion,
            np.array([0.0]),
        ])
        camera_obs = benchmark_observation.get_sensor_observations()
        rgb = camera_obs["rgb"]
        obs = {
            "state": np.expand_dims(state, axis=0),
            "joint": np.expand_dims(joint_positions, axis=0),
            "rgb": np.expand_dims(rgb, axis=0),
            "prompt": self.prompt,
        }
        return obs

    def compute_action(self, observation: dict) -> dict:
        """Send ``observation`` to the policy server and return its decoded reply.

        Raises:
            RuntimeError: if the server answers with a status other than 200.
            requests.RequestException: on connection failure or timeout.
        """
        payload = pickle.dumps(observation)
        response = self._session.post(
            self.server_url,
            data=payload,
            headers={"Content-Type": "application/octet-stream"},
            timeout=self.REQUEST_TIMEOUT_S,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"StarVLA server error (HTTP {response.status_code}): {response.text}"
            )
        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # payload. Only point ``server_url`` at a trusted server on a trusted
        # network; a safer wire format (e.g. npz/JSON) would remove this risk.
        result = pickle.loads(response.content)
        return result

    def postprocess_action(self, action: dict) -> BenchmarkAction:
        """Convert the server reply into a position-control ``BenchmarkAction``.

        Uses the first entry of ``action["action"]`` as target joint positions,
        paired with the robot planner's plannable joint names.
        """
        benchmark_action = BenchmarkAction()
        robot = SceneManager.get_robot(self.robot_name)
        joint_names = robot.get_planner().get_plannable_joint_names()
        # NOTE(review): index 0 presumably selects the first step of a predicted
        # action chunk (or the single batch item); its length is assumed to
        # match ``joint_names`` — TODO confirm against the server's output.
        joint_positions = action["action"][0]
        benchmark_action.add_robot_action(
            RobotAction(
                control_mode=ControlMode.POSITION,
                robot_name=self.robot_name,
                joint_names=joint_names,
                joint_positions=joint_positions,
            )
        )
        return benchmark_action


if __name__ == "__main__":
    # The config path is relative, so it resolves against the current working directory.
    js = JoySim("./benchmark.yaml")
    js.start()