# starvla_benchmark/starvla_inference_server.py

import yaml
import pickle
import argparse
import os
import traceback
from urllib.parse import urlparse
import numpy as np
import torch
import cv2
from flask import Flask, request, Response
from PIL import Image
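

# StarvlaInferenceServer wraps a StarVLA checkpoint behind a small Flask app:
# it reads the policy_server section of the benchmark config, loads the model
# onto the GPU, and serves action predictions over POST /policy/inference,
# with a GET /policy/health liveness check.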
class StarvlaInferenceServer:
    def __init__(self, config_path: str):
        with open(config_path, "r") as f:
            cfg = yaml.safe_load(f)
        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]

        self.ckpt_source = policy_server_cfg["ckpt_source"]
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)
        self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
        self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")

        print("Loading StarVLA model...")
        self.model = self.load_model()
        print("Model loaded.")

        self.app = Flask(__name__)
        self.register_routes()
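
    # _resolve_ckpt_path maps a scheme-prefixed checkpoint URL onto a local root
    # taken from general.root_paths in the config. Hypothetical example (the
    # scheme name and paths below are illustrative, not taken from any real config):
    #   root_paths = {"oss": "/mnt/checkpoints"}
    #   "oss://starvla/run1/model.pt" -> "/mnt/checkpoints/starvla/run1/model.pt"
    # A plain filesystem path (no scheme) is returned unchanged.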
    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            return ckpt_url
        root = root_paths.get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find the checkpoint root path in root_paths: {root_paths}"
            )
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)

    def load_model(self):
        from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
        from starVLA.model.framework import build_framework

        model_config, norm_stats = read_mode_config(self.ckpt_path)
        cfg = dict_to_namespace(model_config)
        cfg.trainer.pretrained_checkpoint = None

        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats

        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

        if self.use_bf16:
            model = model.to(torch.bfloat16)
        model = model.to("cuda").eval()

        self.norm_stats = norm_stats
        self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)
        return model
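
    # parse_observation expects the (pickled) observation dict from the client to
    # provide:
    #   obs["rgb"]["Left_Camera"], obs["rgb"]["Right_Camera"], obs["rgb"]["Hand_Camera"]
    #       -- RGB image arrays (assumed HxWx3 uint8), resized to 320x180
    #          (cv2.resize takes (width, height)) and wrapped as PIL images
    #   obs["state"]  -- proprioceptive state vector (interpreted per state_mode,
    #                    default "ee_pose7")
    #   obs["prompt"] -- natural-language task instruction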
    def parse_observation(self, obs, target_size=(320, 180)):
        left_rgb = obs["rgb"]["Left_Camera"]
        right_rgb = obs["rgb"]["Right_Camera"]
        wrist_rgb = obs["rgb"]["Hand_Camera"]

        img_left = Image.fromarray(cv2.resize(left_rgb, target_size))
        img_right = Image.fromarray(cv2.resize(right_rgb, target_size))
        img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))

        state_vec = obs["state"]
        return img_left, img_right, img_wrist, state_vec, obs["prompt"]
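
    # inference() runs a single forward pass and splits the predicted action chunk
    # (per the inline note below, shape (16, 10) after dropping the batch dim) into
    # 3 end-effector translation deltas, a 6D rotation representation, and 1
    # gripper channel.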
    def inference(self, observation: dict) -> dict:
        img_left, img_right, img_wrist, state_vec, prompt = self.parse_observation(
            observation
        )

        vla_input = {
            "batch_images": [[img_left, img_right, img_wrist]],
            "instructions": [prompt],
            "state": [state_vec],
        }
        with torch.no_grad():
            output = self.model.predict_action(**vla_input)

        actions = output.get("normalized_actions")
        if isinstance(actions, torch.Tensor):
            # Cast to float32 first so .numpy() also works for bf16 outputs.
            actions = actions.float().cpu().numpy()
        if actions.ndim == 3:
            actions = actions[0]  # (16, 10)

        return {
            "ee_delta_position_chunks": actions[:, :3].tolist(),
            "ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
            "gripper_chunks": actions[:, 9:10].tolist(),
        }
    def register_routes(self):
        @self.app.route("/policy/inference", methods=["POST"])
        def policy():
            try:
                data = pickle.loads(request.data)
                result = self.inference(data)
                body = pickle.dumps(result, protocol=4)
                return Response(body, mimetype="application/octet-stream")
            except Exception as e:
                err_obj = {
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                }
                body = pickle.dumps(err_obj, protocol=4)
                return Response(
                    body,
                    mimetype="application/octet-stream",
                    status=500,
                )

        @self.app.route("/policy/health", methods=["GET"])
        def health():
            if self.model is None:
                return Response("Failed to load model", mimetype="text/plain", status=500)
            return Response("OK", mimetype="text/plain")

    def run(self):
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")
        self.app.run(
            host=self.host,
            port=self.port,
            threaded=True,
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="StarVLA inference server")
    parser.add_argument(
        "--config",
        type=str,
        default="./benchmark.yaml",
        help="Path to benchmark.yaml",
    )
    args = parser.parse_args()
    config_path = args.config

    server = StarvlaInferenceServer(config_path)
    server.run()
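

# Launch with:  python starvla_inference_server.py --config ./benchmark.yaml
#
# Example client call (a minimal sketch; the observation contents, image size,
# state dimension, and server address below are illustrative assumptions, not
# part of this file):
#
#   import pickle
#   import numpy as np
#   import requests
#
#   obs = {
#       "rgb": {
#           "Left_Camera": np.zeros((720, 1280, 3), dtype=np.uint8),
#           "Right_Camera": np.zeros((720, 1280, 3), dtype=np.uint8),
#           "Hand_Camera": np.zeros((720, 1280, 3), dtype=np.uint8),
#       },
#       "state": [0.0] * 7,
#       "prompt": "pick up the red block",
#   }
#   resp = requests.post(
#       "http://localhost:5000/policy/inference",
#       data=pickle.dumps(obs, protocol=4),
#   )
#   result = pickle.loads(resp.content)
#   # result holds the action chunks, or {"error", "traceback"} on failure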