finish benchmark debug

This commit is contained in:
hufei.hofee
2026-03-19 01:43:02 +08:00
parent cc9815f3b8
commit 1b30c3f96a
5 changed files with 257 additions and 189 deletions

View File

@@ -1,7 +1,9 @@
import yaml
import pickle
import argparse
import os
import traceback
from urllib.parse import urlparse
import numpy as np
import torch
@@ -77,49 +79,17 @@ class StarvlaInferenceServer:
return model
def parse_observation(self, obs):
def parse_observation(self, obs, target_size=(320, 180)):
rgb = obs["rgb"][-1]
state = obs["state"][-1]
joint = obs.get("joint", None)
prompt = obs["prompt"]
left_rgb, right_rgb, wrist_rgb = obs["rgb"]["Left_Camera"], obs["rgb"]["Right_Camera"], obs["rgb"]["Hand_Camera"]
left = rgb[:, :, :3]
right = rgb[:, :, 3:6]
wrist = rgb[:, :, 6:9]
img_left = Image.fromarray(cv2.resize(left_rgb, target_size))
img_right = Image.fromarray(cv2.resize(right_rgb, target_size))
img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
target_size = (320, 180)
state_vec = obs["state"]
left = cv2.resize(left, target_size)
right = cv2.resize(right, target_size)
wrist = cv2.resize(wrist, target_size)
img_left = Image.fromarray(left)
img_right = Image.fromarray(right)
img_wrist = Image.fromarray(wrist)
if self.state_mode == "joint8":
joint_last = joint[-1]
gripper = state[9]
state_vec = np.concatenate(
[joint_last, np.array([gripper])],
axis=0
)
else:
xyz = state[0:3]
rot6d = state[3:9]
gripper = state[9]
state_vec = np.concatenate(
[xyz, rot6d[:3], np.array([gripper])],
axis=0
)
return img_left, img_right, img_wrist, state_vec, prompt
return img_left, img_right, img_wrist, state_vec, obs["prompt"]
def inference(self, observation: dict) -> dict:
@@ -141,18 +111,31 @@ class StarvlaInferenceServer:
actions = actions.cpu().numpy()
if actions.ndim == 3:
actions = actions[0]
return {"action": actions.astype(np.float32)}
actions = actions[0] # (8, 7)
return {"ee_delta_position_chunks": actions[:, :3].tolist(),
"ee_delta_euler_xyz_chunks": actions[:, 3:6].tolist(),
"gripper_chunks": actions[:, 6:7].tolist()}
def register_routes(self):
@self.app.route("/policy", methods=["POST"])
def policy():
data = pickle.loads(request.data)
result = self.inference(data)
body = pickle.dumps(result, protocol=4)
return Response(body, mimetype="application/octet-stream")
try:
data = pickle.loads(request.data)
result = self.inference(data)
body = pickle.dumps(result, protocol=4)
return Response(body, mimetype="application/octet-stream")
except Exception as e:
err_obj = {
"error": str(e),
"traceback": traceback.format_exc(),
}
body = pickle.dumps(err_obj, protocol=4)
return Response(
body,
mimetype="application/octet-stream",
status=500,
)
def run(self):
@@ -168,6 +151,15 @@ class StarvlaInferenceServer:
if __name__ == "__main__":
config_path = "./benchmark.yaml"
parser = argparse.ArgumentParser(description="StarVLA inference server")
parser.add_argument(
"--config",
type=str,
default="./benchmark.yaml",
help="Path to benchmark.yaml",
)
args = parser.parse_args()
config_path = args.config
server = StarvlaInferenceServer(config_path)
server.run()