From 457c26b8682261e646f55b1574e14199a35973db Mon Sep 17 00:00:00 2001 From: "hufei.hofee" Date: Thu, 19 Mar 2026 20:05:36 +0800 Subject: [PATCH] Update benchmark.yaml and inference server: Adjusted benchmark parameters, modified inference server routes, and enhanced policy handling for chunked actions. --- benchmark.yaml | 9 +++-- starvla_inference_server.py | 13 +++++-- starvla_policy.py | 78 +++++++++++++++++++++++++------------ 3 files changed, 68 insertions(+), 32 deletions(-) diff --git a/benchmark.yaml b/benchmark.yaml index 91d0f5b..d66088f 100644 --- a/benchmark.yaml +++ b/benchmark.yaml @@ -30,8 +30,8 @@ scene: - 0.001 - 0.001 position: - - 0.0 - - -3.79243 + - 0.2 + - -4.15243 - 0.5 quaternion: - -0.304408012043137 @@ -204,7 +204,8 @@ extension: stereotype: starvla robot_name: Franka sensor_names: [Hand_Camera, Left_Camera, Right_Camera] - prompt: pick the cola bottle and place it on the book + prompt: pick up the white plug + run_trunk_size: 16 recorder: enable: false # set to true to record the data @@ -214,7 +215,7 @@ extension: backend_root_path: output://benchmark_record policy_server: - ckpt_path: checkpoints://0309_qwenpi_droid_cola_post/final_model/pytorch_model.pt + ckpt_path: checkpoints://0318_qwenpi_droid_pretrain_8node/checkpoints/steps_30000_pytorch_model.pt ckpt_source: local host: 0.0.0.0 port: 5000 diff --git a/starvla_inference_server.py b/starvla_inference_server.py index f4f674b..0f432de 100644 --- a/starvla_inference_server.py +++ b/starvla_inference_server.py @@ -111,14 +111,14 @@ class StarvlaInferenceServer: actions = actions.cpu().numpy() if actions.ndim == 3: - actions = actions[0] # (8, 7) + actions = actions[0] # (16, 10) return {"ee_delta_position_chunks": actions[:, :3].tolist(), - "ee_delta_euler_xyz_chunks": actions[:, 3:6].tolist(), - "gripper_chunks": actions[:, 6:7].tolist()} + "ee_delta_rot6d_chunks": actions[:, 3:9].tolist(), + "gripper_chunks": actions[:, 9:10].tolist()} def register_routes(self): - @self.app.route("/policy", methods=["POST"]) + @self.app.route("/policy/inference", methods=["POST"]) def policy(): try: data = pickle.loads(request.data) @@ -136,6 +136,11 @@ class StarvlaInferenceServer: mimetype="application/octet-stream", status=500, ) + @self.app.route("/policy/health", methods=["GET"]) + def health(): + if self.model is None: + return Response("Failed to load model", mimetype="application/json", status=500) + return Response("OK", mimetype="application/json") def run(self): diff --git a/starvla_policy.py b/starvla_policy.py index 660e700..3f03a98 100644 --- a/starvla_policy.py +++ b/starvla_policy.py @@ -14,6 +14,7 @@ from joysim.utils.log import Log from joysim.utils.pose import Pose import numpy as np import requests +import time @configclass @stereotype.register_config("starvla") @@ -37,6 +38,12 @@ class StarvlaPolicyConfig(PolicyConfig): comment="task instruction" ) + run_trunk_size: int = field( + default=16, + required=True, + comment="The number of chunks to run in one inference step" + ) + @stereotype.register_model("starvla") class StarvlaPolicy(Policy): @@ -51,33 +58,39 @@ class StarvlaPolicy(Policy): def reset(self) -> None: self.current_ee_position_state = None - self.current_ee_euler_xyz_state = None + self.current_ee_rot6d_state = None self.current_gripper_state = None + self.current_chunk_id = 0 + self.current_chunk_result = None + self.run_trunk_size = self.config.run_trunk_size def warmup(self, benchmark_observation: BenchmarkObservation) -> None: - pass - def needs_observation(self) -> bool: - return True + Log.info(f"Waiting for StarVLA inference server to be ready...") + while True: + try: + if requests.get(f"{self.server_url}/health", timeout=1.0).status_code == 200: + break + except Exception: + time.sleep(1) + Log.success(f"StarVLA inference server is ready.") + def needs_observation(self) -> bool: + return self.current_chunk_id == 0 + def _handle_server_error(self, response: requests.Response) -> None: if response.status_code == 500: err_obj = pickle.loads(response.content) Log.error(f"StarVLA server error: {err_obj['error']}") - Log.error(f"Traceback: {err_obj['traceback']}") - exit(0) + Log.error(f"Traceback: {err_obj['traceback']}", exit=True) elif response.status_code != 200: - Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}") - exit(0) + Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}", exit=True) def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict: robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"] ee_pose_base = robot_obs["ee_pose_base"] - ee_position, ee_euler_xyz = ee_pose_base["position"],ee_pose_base["euler_xyz"] + ee_position, ee_rot6d = ee_pose_base["position"],ee_pose_base["rot6d"] gripper = 0.0 if robot_obs["gripper_state"]["opened"] else 1.0 - state = np.concatenate([ee_position,ee_euler_xyz,np.array([gripper])]) - self.current_ee_position_state = np.array(ee_position).astype(np.float64) - self.current_ee_euler_xyz_state = np.array(ee_euler_xyz).astype(np.float64) - self.current_gripper_state = np.array([gripper]) + state = np.concatenate([ee_position,ee_rot6d,np.array([gripper])]) rgb_data = {} for sensor_name in self.sensor_names: sensor_obs = benchmark_observation.get_sensor_observations(sensor_name) @@ -87,22 +100,36 @@ class StarvlaPolicy(Policy): return obs def compute_action(self, observation: dict) -> dict: - payload = pickle.dumps(observation) - response = requests.post( - self.server_url, - data=payload, - headers={"Content-Type": "application/octet-stream"} - ) - self._handle_server_error(response) - result = pickle.loads(response.content) + if self.current_chunk_result is None: + self.current_ee_position_state = np.array(observation["state"][:3]).astype(np.float64) + self.current_ee_rot6d_state = np.array(observation["state"][3:9]).astype(np.float64) + self.current_gripper_state = np.array([observation["state"][9]]) + payload = pickle.dumps(observation) + response = requests.post( + f"{self.server_url}/inference", + data=payload, + headers={"Content-Type": "application/octet-stream"} + ) + self._handle_server_error(response) + result = pickle.loads(response.content) + max_trunk_size = len(result["ee_delta_position_chunks"]) + if self.run_trunk_size > max_trunk_size: + Log.warning(f"Run trunk size {self.run_trunk_size} is greater than the number of chunks {max_trunk_size}. Set run trunk size to {max_trunk_size}.") + self.run_trunk_size = max_trunk_size + self.run_trunk_size = max_trunk_size + self.current_chunk_result = result + else: + result = self.current_chunk_result + return result def postprocess_action(self, action: dict) -> BenchmarkAction: benchmark_action = BenchmarkAction() # get base frame end-effector pose - delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][0], euler_xyz=action["ee_delta_euler_xyz_chunks"][0]) - curr_state_ee_pose = Pose(position=self.current_ee_position_state, euler_xyz=self.current_ee_euler_xyz_state) + delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action["ee_delta_rot6d_chunks"][self.current_chunk_id]) + curr_state_ee_pose = Pose(position=self.current_ee_position_state, rot6d=self.current_ee_rot6d_state) + Log.debug(f"trunck_id: {self.current_chunk_id}, curr_state_ee_pose: {curr_state_ee_pose}") curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose # action2base = state2base * action2state ik_result = MotionPlanController.solve_ik( robot_name=self.robot_name, @@ -122,5 +149,8 @@ class StarvlaPolicy(Policy): joint_positions=joint_positions ) ) - + self.current_chunk_id += 1 + if self.current_chunk_id == self.run_trunk_size: + self.current_chunk_id = 0 + self.current_chunk_result = None return benchmark_action \ No newline at end of file