finish benchmark debug
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
|
||||
import yaml
|
||||
import pickle
|
||||
import argparse
|
||||
import os
|
||||
import traceback
|
||||
from urllib.parse import urlparse
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -77,49 +79,17 @@ class StarvlaInferenceServer:
|
||||
|
||||
return model
|
||||
|
||||
def parse_observation(self, obs):
|
||||
def parse_observation(self, obs, target_size=(320, 180)):
|
||||
|
||||
rgb = obs["rgb"][-1]
|
||||
state = obs["state"][-1]
|
||||
joint = obs.get("joint", None)
|
||||
prompt = obs["prompt"]
|
||||
left_rgb, right_rgb, wrist_rgb = obs["rgb"]["Left_Camera"], obs["rgb"]["Right_Camera"], obs["rgb"]["Hand_Camera"]
|
||||
|
||||
left = rgb[:, :, :3]
|
||||
right = rgb[:, :, 3:6]
|
||||
wrist = rgb[:, :, 6:9]
|
||||
img_left = Image.fromarray(cv2.resize(left_rgb, target_size))
|
||||
img_right = Image.fromarray(cv2.resize(right_rgb, target_size))
|
||||
img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
|
||||
|
||||
target_size = (320, 180)
|
||||
state_vec = obs["state"]
|
||||
|
||||
left = cv2.resize(left, target_size)
|
||||
right = cv2.resize(right, target_size)
|
||||
wrist = cv2.resize(wrist, target_size)
|
||||
|
||||
img_left = Image.fromarray(left)
|
||||
img_right = Image.fromarray(right)
|
||||
img_wrist = Image.fromarray(wrist)
|
||||
|
||||
if self.state_mode == "joint8":
|
||||
|
||||
joint_last = joint[-1]
|
||||
gripper = state[9]
|
||||
|
||||
state_vec = np.concatenate(
|
||||
[joint_last, np.array([gripper])],
|
||||
axis=0
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
xyz = state[0:3]
|
||||
rot6d = state[3:9]
|
||||
gripper = state[9]
|
||||
|
||||
state_vec = np.concatenate(
|
||||
[xyz, rot6d[:3], np.array([gripper])],
|
||||
axis=0
|
||||
)
|
||||
|
||||
return img_left, img_right, img_wrist, state_vec, prompt
|
||||
return img_left, img_right, img_wrist, state_vec, obs["prompt"]
|
||||
|
||||
def inference(self, observation: dict) -> dict:
|
||||
|
||||
@@ -141,18 +111,31 @@ class StarvlaInferenceServer:
|
||||
actions = actions.cpu().numpy()
|
||||
|
||||
if actions.ndim == 3:
|
||||
actions = actions[0]
|
||||
|
||||
return {"action": actions.astype(np.float32)}
|
||||
actions = actions[0] # (8, 7)
|
||||
return {"ee_delta_position_chunks": actions[:, :3].tolist(),
|
||||
"ee_delta_euler_xyz_chunks": actions[:, 3:6].tolist(),
|
||||
"gripper_chunks": actions[:, 6:7].tolist()}
|
||||
|
||||
def register_routes(self):
|
||||
|
||||
@self.app.route("/policy", methods=["POST"])
|
||||
def policy():
|
||||
data = pickle.loads(request.data)
|
||||
result = self.inference(data)
|
||||
body = pickle.dumps(result, protocol=4)
|
||||
return Response(body, mimetype="application/octet-stream")
|
||||
try:
|
||||
data = pickle.loads(request.data)
|
||||
result = self.inference(data)
|
||||
body = pickle.dumps(result, protocol=4)
|
||||
return Response(body, mimetype="application/octet-stream")
|
||||
except Exception as e:
|
||||
err_obj = {
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc(),
|
||||
}
|
||||
body = pickle.dumps(err_obj, protocol=4)
|
||||
return Response(
|
||||
body,
|
||||
mimetype="application/octet-stream",
|
||||
status=500,
|
||||
)
|
||||
|
||||
def run(self):
|
||||
|
||||
@@ -168,6 +151,15 @@ class StarvlaInferenceServer:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config_path = "./benchmark.yaml"
|
||||
parser = argparse.ArgumentParser(description="StarVLA inference server")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default="./benchmark.yaml",
|
||||
help="Path to benchmark.yaml",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config_path = args.config
|
||||
server = StarvlaInferenceServer(config_path)
|
||||
server.run()
|
||||
Reference in New Issue
Block a user