"""StarVLA policy inference server.

Loads a StarVLA checkpoint described by a benchmark.yaml config and serves
action chunks over HTTP: clients POST a pickled observation dict to /policy
and receive a pickled action dict back.
"""

import argparse
import os
import pickle
import traceback
from urllib.parse import urlparse

import cv2
import torch
import yaml
from flask import Flask, Response, request
from PIL import Image


class StarvlaInferenceServer:
    def __init__(self, config_path: str):
        with open(config_path, "r") as f:
            cfg = yaml.safe_load(f)
        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]

        self.ckpt_source = policy_server_cfg["ckpt_source"]
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)
        self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
        self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")

        print("Loading StarVLA model...")
        self.model = self.load_model()
        print("Model loaded.")

        self.app = Flask(__name__)
        self.register_routes()

    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        """Resolve a checkpoint reference to a local path.

        Plain filesystem paths are returned unchanged; URLs such as
        "ckpts://some/rel/path" are resolved against root_paths[scheme].
        """
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            return ckpt_url
        root = root_paths.get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find checkpoint root for scheme {parsed.scheme!r} "
                f"in root_paths: {root_paths}"
            )
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)

    def load_model(self):
        # Imported lazily so the module can be parsed without starVLA installed.
        from starVLA.model.framework import build_framework
        from starVLA.model.framework.share_tools import (
            dict_to_namespace,
            read_mode_config,
        )

        model_config, norm_stats = read_mode_config(self.ckpt_path)
        cfg = dict_to_namespace(model_config)
        cfg.trainer.pretrained_checkpoint = None  # weights are loaded below
        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats

        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

        if self.use_bf16:
            model = model.to(torch.bfloat16)
        model = model.to("cuda").eval()

        self.norm_stats = norm_stats
        self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)
        return model

    def parse_observation(self, obs, target_size=(320, 180)):
        """Split an observation dict into model inputs.

        target_size is (width, height), matching cv2.resize's convention.
        """
        left_rgb = obs["rgb"]["Left_Camera"]
        right_rgb = obs["rgb"]["Right_Camera"]
        wrist_rgb = obs["rgb"]["Hand_Camera"]
        img_left = Image.fromarray(cv2.resize(left_rgb, target_size))
        img_right = Image.fromarray(cv2.resize(right_rgb, target_size))
        img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
        state_vec = obs["state"]
        return img_left, img_right, img_wrist, state_vec, obs["prompt"]

    def inference(self, observation: dict) -> dict:
        img_left, img_right, img_wrist, state_vec, prompt = self.parse_observation(
            observation
        )
        vla_input = {
            "batch_images": [[img_left, img_right, img_wrist]],
            "instructions": [prompt],
            "state": [state_vec],
        }
        with torch.no_grad():
            output = self.model.predict_action(**vla_input)

        actions = output.get("normalized_actions")
        if actions is None:
            raise KeyError("predict_action output has no 'normalized_actions'")
        if isinstance(actions, torch.Tensor):
            # Cast to float32 first: numpy cannot represent bfloat16 tensors.
            actions = actions.float().cpu().numpy()
        if actions.ndim == 3:
            actions = actions[0]  # drop batch dim -> (chunk_len, 7), e.g. (8, 7)

        # Each row is [dx, dy, dz, droll, dpitch, dyaw, gripper].
        return {
            "ee_delta_position_chunks": actions[:, :3].tolist(),
            "ee_delta_euler_xyz_chunks": actions[:, 3:6].tolist(),
            "gripper_chunks": actions[:, 6:7].tolist(),
        }

    def register_routes(self):
        @self.app.route("/policy", methods=["POST"])
        def policy():
            # NOTE: pickle deserialization is only safe with trusted clients.
            try:
                data = pickle.loads(request.data)
                result = self.inference(data)
                body = pickle.dumps(result, protocol=4)
                return Response(body, mimetype="application/octet-stream")
            except Exception as e:
                err_obj = {
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                }
                body = pickle.dumps(err_obj, protocol=4)
                return Response(
                    body,
                    mimetype="application/octet-stream",
                    status=500,
                )

    def run(self):
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")
        self.app.run(host=self.host, port=self.port, threaded=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="StarVLA inference server")
    parser.add_argument(
        "--config",
        type=str,
        default="./benchmark.yaml",
        help="Path to benchmark.yaml",
    )
    args = parser.parse_args()
    server = StarvlaInferenceServer(args.config)
    server.run()