Update benchmark.yaml and inference server: adjust benchmark parameters, modify inference server routes, and enhance policy handling for chunked actions.

This commit is contained in:
hufei.hofee
2026-03-19 20:05:36 +08:00
parent 852bdc0dd7
commit 457c26b868
3 changed files with 68 additions and 32 deletions

View File

@@ -30,8 +30,8 @@ scene:
- 0.001 - 0.001
- 0.001 - 0.001
position: position:
- 0.0 - 0.2
- -3.79243 - -4.15243
- 0.5 - 0.5
quaternion: quaternion:
- -0.304408012043137 - -0.304408012043137
@@ -204,7 +204,8 @@ extension:
stereotype: starvla stereotype: starvla
robot_name: Franka robot_name: Franka
sensor_names: [Hand_Camera, Left_Camera, Right_Camera] sensor_names: [Hand_Camera, Left_Camera, Right_Camera]
prompt: pick the cola bottle and place it on the book prompt: pick up the white plug
run_trunk_size: 16
recorder: recorder:
enable: false # set to true to record the data enable: false # set to true to record the data
@@ -214,7 +215,7 @@ extension:
backend_root_path: output://benchmark_record backend_root_path: output://benchmark_record
policy_server: policy_server:
ckpt_path: checkpoints://0309_qwenpi_droid_cola_post/final_model/pytorch_model.pt ckpt_path: checkpoints://0318_qwenpi_droid_pretrain_8node/checkpoints/steps_30000_pytorch_model.pt
ckpt_source: local ckpt_source: local
host: 0.0.0.0 host: 0.0.0.0
port: 5000 port: 5000

View File

@@ -111,14 +111,14 @@ class StarvlaInferenceServer:
actions = actions.cpu().numpy() actions = actions.cpu().numpy()
if actions.ndim == 3: if actions.ndim == 3:
actions = actions[0] # (8, 7) actions = actions[0] # (16, 10)
return {"ee_delta_position_chunks": actions[:, :3].tolist(), return {"ee_delta_position_chunks": actions[:, :3].tolist(),
"ee_delta_euler_xyz_chunks": actions[:, 3:6].tolist(), "ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
"gripper_chunks": actions[:, 6:7].tolist()} "gripper_chunks": actions[:, 9:10].tolist()}
def register_routes(self): def register_routes(self):
@self.app.route("/policy", methods=["POST"]) @self.app.route("/policy/inference", methods=["POST"])
def policy(): def policy():
try: try:
data = pickle.loads(request.data) data = pickle.loads(request.data)
@@ -136,6 +136,11 @@ class StarvlaInferenceServer:
mimetype="application/octet-stream", mimetype="application/octet-stream",
status=500, status=500,
) )
@self.app.route("/policy/health", methods=["GET"])
def health():
if self.model is None:
return Response("Failed to load model", mimetype="application/json", status=500)
return Response("OK", mimetype="application/json")
def run(self): def run(self):

View File

@@ -14,6 +14,7 @@ from joysim.utils.log import Log
from joysim.utils.pose import Pose from joysim.utils.pose import Pose
import numpy as np import numpy as np
import requests import requests
import time
@configclass @configclass
@stereotype.register_config("starvla") @stereotype.register_config("starvla")
@@ -37,6 +38,12 @@ class StarvlaPolicyConfig(PolicyConfig):
comment="task instruction" comment="task instruction"
) )
run_trunk_size: int = field(
default=16,
required=True,
comment="The number of chunks to run in one inference step"
)
@stereotype.register_model("starvla") @stereotype.register_model("starvla")
class StarvlaPolicy(Policy): class StarvlaPolicy(Policy):
@@ -51,33 +58,39 @@ class StarvlaPolicy(Policy):
def reset(self) -> None: def reset(self) -> None:
self.current_ee_position_state = None self.current_ee_position_state = None
self.current_ee_euler_xyz_state = None self.current_ee_rot6d_state = None
self.current_gripper_state = None self.current_gripper_state = None
self.current_chunk_id = 0
self.current_chunk_result = None
self.run_trunk_size = self.config.run_trunk_size
def warmup(self, benchmark_observation: BenchmarkObservation) -> None: def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
pass Log.info(f"Waiting for StarVLA inference server to be ready...")
while True:
try:
if requests.get(f"{self.server_url}/health", timeout=1.0).status_code == 200:
break
except Exception:
time.sleep(1)
Log.success(f"StarVLA inference server is ready.")
def needs_observation(self) -> bool: def needs_observation(self) -> bool:
return True return self.current_chunk_id == 0
def _handle_server_error(self, response: requests.Response) -> None: def _handle_server_error(self, response: requests.Response) -> None:
if response.status_code == 500: if response.status_code == 500:
err_obj = pickle.loads(response.content) err_obj = pickle.loads(response.content)
Log.error(f"StarVLA server error: {err_obj['error']}") Log.error(f"StarVLA server error: {err_obj['error']}")
Log.error(f"Traceback: {err_obj['traceback']}") Log.error(f"Traceback: {err_obj['traceback']}", exit=True)
exit(0)
elif response.status_code != 200: elif response.status_code != 200:
Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}") Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}", exit=True)
exit(0)
def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict: def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"] robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
ee_pose_base = robot_obs["ee_pose_base"] ee_pose_base = robot_obs["ee_pose_base"]
ee_position, ee_euler_xyz = ee_pose_base["position"],ee_pose_base["euler_xyz"] ee_position, ee_rot6d = ee_pose_base["position"],ee_pose_base["rot6d"]
gripper = 0.0 if robot_obs["gripper_state"]["opened"] else 1.0 gripper = 0.0 if robot_obs["gripper_state"]["opened"] else 1.0
state = np.concatenate([ee_position,ee_euler_xyz,np.array([gripper])]) state = np.concatenate([ee_position,ee_rot6d,np.array([gripper])])
self.current_ee_position_state = np.array(ee_position).astype(np.float64)
self.current_ee_euler_xyz_state = np.array(ee_euler_xyz).astype(np.float64)
self.current_gripper_state = np.array([gripper])
rgb_data = {} rgb_data = {}
for sensor_name in self.sensor_names: for sensor_name in self.sensor_names:
sensor_obs = benchmark_observation.get_sensor_observations(sensor_name) sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
@@ -87,22 +100,36 @@ class StarvlaPolicy(Policy):
return obs return obs
def compute_action(self, observation: dict) -> dict: def compute_action(self, observation: dict) -> dict:
if self.current_chunk_result is None:
self.current_ee_position_state = np.array(observation["state"][:3]).astype(np.float64)
self.current_ee_rot6d_state = np.array(observation["state"][3:9]).astype(np.float64)
self.current_gripper_state = np.array([observation["state"][9]])
payload = pickle.dumps(observation) payload = pickle.dumps(observation)
response = requests.post( response = requests.post(
self.server_url, f"{self.server_url}/inference",
data=payload, data=payload,
headers={"Content-Type": "application/octet-stream"} headers={"Content-Type": "application/octet-stream"}
) )
self._handle_server_error(response) self._handle_server_error(response)
result = pickle.loads(response.content) result = pickle.loads(response.content)
max_trunk_size = len(result["ee_delta_position_chunks"])
if self.run_trunk_size > max_trunk_size:
Log.warning(f"Run trunk size {self.run_trunk_size} is greater than the number of chunks {max_trunk_size}. Set run trunk size to {max_trunk_size}.")
self.run_trunk_size = max_trunk_size
self.run_trunk_size = max_trunk_size
self.current_chunk_result = result
else:
result = self.current_chunk_result
return result return result
def postprocess_action(self, action: dict) -> BenchmarkAction: def postprocess_action(self, action: dict) -> BenchmarkAction:
benchmark_action = BenchmarkAction() benchmark_action = BenchmarkAction()
# get base frame end-effector pose # get base frame end-effector pose
delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][0], euler_xyz=action["ee_delta_euler_xyz_chunks"][0]) delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action["ee_delta_rot6d_chunks"][self.current_chunk_id])
curr_state_ee_pose = Pose(position=self.current_ee_position_state, euler_xyz=self.current_ee_euler_xyz_state) curr_state_ee_pose = Pose(position=self.current_ee_position_state, rot6d=self.current_ee_rot6d_state)
Log.debug(f"trunck_id: {self.current_chunk_id}, curr_state_ee_pose: {curr_state_ee_pose}")
curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose # action2base = state2base * action2state curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose # action2base = state2base * action2state
ik_result = MotionPlanController.solve_ik( ik_result = MotionPlanController.solve_ik(
robot_name=self.robot_name, robot_name=self.robot_name,
@@ -122,5 +149,8 @@ class StarvlaPolicy(Policy):
joint_positions=joint_positions joint_positions=joint_positions
) )
) )
self.current_chunk_id += 1
if self.current_chunk_id == self.run_trunk_size:
self.current_chunk_id = 0
self.current_chunk_result = None
return benchmark_action return benchmark_action