import yaml
import pickle
import os
from urllib.parse import urlparse
import numpy as np
import torch
import cv2
from flask import Flask, request, Response
from PIL import Image


class StarvlaInferenceServer:
    """HTTP policy server that wraps a StarVLA checkpoint.

    Loads the model described by a YAML config and exposes a single
    POST /policy endpoint that accepts a pickled observation dict and
    returns a pickled ``{"action": np.ndarray}`` response.
    """

    def __init__(self, config_path: str):
        """Read the YAML config, resolve the checkpoint path, and load the model.

        Args:
            config_path: Path to a YAML file with ``policy_server`` and
                ``general.root_paths`` sections.
        """
        with open(config_path, "r") as f:
            cfg = yaml.safe_load(f)
        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]

        self.ckpt_source = policy_server_cfg["ckpt_source"]
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)
        self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
        self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")

        print("Loading StarVLA model...")
        self.model = self.load_model()
        print("Model loaded.")

        self.app = Flask(__name__)
        self.register_routes()

    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        """Map a scheme-prefixed checkpoint URL to a local filesystem path.

        A plain path (no URL scheme) is returned unchanged. Otherwise the
        scheme selects a root directory from ``root_paths`` and the rest of
        the URL is joined onto it.

        Raises:
            KeyError: If the URL scheme has no entry in ``root_paths``.
        """
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            # Already a local path — nothing to resolve.
            return ckpt_url
        root = root_paths.get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find the checkpoint root path in root_paths: {root_paths}"
            )
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)

    def load_model(self):
        """Build the StarVLA framework, load checkpoint weights, and move to GPU.

        Also caches ``self.norm_stats`` and ``self.action_norm_stats`` (the
        per-dataset action normalization stats for ``self.unnorm_key``, or
        None if absent).
        """
        from starVLA.model.framework.share_tools import (
            read_mode_config,
            dict_to_namespace,
        )
        from starVLA.model.framework.__init__ import build_framework

        model_config, norm_stats = read_mode_config(self.ckpt_path)
        cfg = dict_to_namespace(model_config)
        # Prevent build_framework from re-loading a pretrained checkpoint;
        # weights are loaded explicitly below.
        cfg.trainer.pretrained_checkpoint = None
        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats

        # NOTE(review): assumes torch.load(self.ckpt_path) yields a raw
        # state_dict (not a wrapper dict with e.g. a "model" key) — confirm
        # against the checkpoint format.
        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

        if self.use_bf16:
            model = model.to(torch.bfloat16)
        model = model.to("cuda").eval()

        self.norm_stats = norm_stats
        self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get(
            "action", None
        )
        return model

    def parse_observation(self, obs):
        """Split a raw observation dict into model inputs.

        Args:
            obs: Dict with keys ``rgb`` (sequence of H x W x 9 arrays: three
                RGB views stacked along channels), ``state`` (sequence of
                10-dim vectors: xyz[0:3], rot6d[3:9], gripper[9]), optional
                ``joint``, and ``prompt`` (instruction string).

        Returns:
            Tuple of (left PIL image, right PIL image, wrist PIL image,
            state vector, prompt string).

        Raises:
            ValueError: If ``state_mode`` is "joint8" but no joint data
                was provided.
        """
        rgb = obs["rgb"][-1]          # latest frame only
        state = obs["state"][-1]      # latest proprioceptive state
        joint = obs.get("joint", None)
        prompt = obs["prompt"]

        # Channel layout: [left | right | wrist], 3 channels each.
        left = rgb[:, :, :3]
        right = rgb[:, :, 3:6]
        wrist = rgb[:, :, 6:9]

        # cv2.resize takes (width, height): resize to 320x180.
        target_size = (320, 180)
        left = cv2.resize(left, target_size)
        right = cv2.resize(right, target_size)
        wrist = cv2.resize(wrist, target_size)

        img_left = Image.fromarray(left)
        img_right = Image.fromarray(right)
        img_wrist = Image.fromarray(wrist)

        if self.state_mode == "joint8":
            if joint is None:
                # Original code would crash with an opaque TypeError here.
                raise ValueError(
                    "state_mode 'joint8' requires 'joint' in the observation"
                )
            joint_last = joint[-1]
            gripper = state[9]
            state_vec = np.concatenate(
                [joint_last, np.array([gripper])], axis=0
            )
        else:
            xyz = state[0:3]
            rot6d = state[3:9]
            gripper = state[9]
            # NOTE(review): only the first 3 of the 6 rotation components
            # are kept (7-dim "ee_pose7" state) — confirm this matches the
            # model's training-time state layout.
            state_vec = np.concatenate(
                [xyz, rot6d[:3], np.array([gripper])], axis=0
            )

        return img_left, img_right, img_wrist, state_vec, prompt

    def inference(self, observation: dict) -> dict:
        """Run one forward pass and return the predicted action chunk.

        Returns:
            Dict with key ``action``: a float32 numpy array of actions
            (leading batch dim squeezed off if present).

        Raises:
            KeyError: If the model output lacks ``normalized_actions``.
        """
        img_left, img_right, img_wrist, state_vec, prompt = \
            self.parse_observation(observation)

        vla_input = {
            "batch_images": [[img_left, img_right, img_wrist]],
            "instructions": [prompt],
            "state": [state_vec],
        }

        with torch.no_grad():
            output = self.model.predict_action(**vla_input)

        actions = output.get("normalized_actions")
        if actions is None:
            # Fail loudly instead of the original's AttributeError on .ndim.
            raise KeyError("model output missing 'normalized_actions'")
        if isinstance(actions, torch.Tensor):
            actions = actions.cpu().numpy()
        if actions.ndim == 3:
            # Drop the batch dimension (batch size is always 1 here).
            actions = actions[0]
        return {"action": actions.astype(np.float32)}

    def register_routes(self):
        """Register the POST /policy endpoint on the Flask app."""

        @self.app.route("/policy", methods=["POST"])
        def policy():
            # SECURITY: pickle.loads on network-supplied bytes allows
            # arbitrary code execution. Only expose this server on a
            # trusted network, or switch the wire format to JSON/msgpack.
            data = pickle.loads(request.data)
            result = self.inference(data)
            body = pickle.dumps(result, protocol=4)
            return Response(body, mimetype="application/octet-stream")

    def run(self):
        """Start the blocking Flask server loop."""
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")
        self.app.run(
            host=self.host,
            port=self.port,
            threaded=True,
        )


if __name__ == "__main__":
    config_path = "./benchmark.yaml"
    server = StarvlaInferenceServer(config_path)
    server.run()