import argparse
import os
import pickle
import traceback
from urllib.parse import urlparse

import cv2
import numpy as np
import torch
import yaml
from flask import Flask, Response, request
from PIL import Image


def pad_to_dim(
    x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0
) -> np.ndarray:
    """Pad ``x`` along ``axis`` so it has ``target_dim`` entries on that axis.

    Padding is appended at the end of ``axis`` and filled with ``value``
    (0.0 by default). Arrays that already have ``target_dim`` or more entries
    along ``axis`` are returned unchanged (they are not truncated).
    """
    current_dim = x.shape[axis]
    if current_dim < target_dim:
        pad_width = [(0, 0)] * x.ndim
        pad_width[axis] = (0, target_dim - current_dim)
        return np.pad(x, pad_width, constant_values=value)
    return x


def normalize_states(states, statistics):
    """Map raw states into [-1, 1] using per-dimension q01/q99 quantiles.

    Args:
        states: Array-like whose last axis is matched element-wise against
            the ``q01``/``q99`` vectors.
        statistics: Normalization statistics; quantiles are read from
            ``statistics["new_embodiment"]["state"]["q01"]`` and ``["q99"]``.

    Returns:
        A floating-point array with the same shape as ``states``. Dimensions
        with ``q01 != q99`` are mapped via ``2 * (x - q01) / (q99 - q01) - 1``.
        Dimensions with ``q01 == q99`` (normalization undefined) keep the raw
        value. The result is then clipped to ``[-1, 1]``.
    """
    states = np.asarray(states)
    # The output buffer takes states' dtype; an integer/bool buffer would
    # silently truncate the normalized values, so work in float instead.
    if not np.issubdtype(states.dtype, np.floating):
        states = states.astype(np.float32)

    stats = statistics["new_embodiment"]["state"]
    q01 = np.asarray(stats["q01"]).astype(states.dtype)
    q99 = np.asarray(stats["q99"]).astype(states.dtype)

    # Only dimensions with a non-degenerate range can be normalized.
    mask = q01 != q99
    normalized = np.zeros_like(states)
    normalized[..., mask] = (
        2 * (states[..., mask] - q01[..., mask]) / (q99[..., mask] - q01[..., mask])
        - 1
    )
    # Degenerate dimensions pass through unchanged (before clipping).
    normalized[..., ~mask] = states[..., ~mask]
    return np.clip(normalized, -1, 1)


def unnormalize_actions(normalized_actions, statistics):
    """Invert the [-1, 1] quantile normalization for actions.

    Computes ``(a + 1) / 2 * (q99 - q01) + q01`` with quantiles read from
    ``statistics["new_embodiment"]["action"]``. Inputs are not clipped, so
    values outside [-1, 1] extrapolate linearly.
    """
    normalized_actions = np.asarray(normalized_actions)
    stats = statistics["new_embodiment"]["action"]
    q01 = np.asarray(stats["q01"]).astype(normalized_actions.dtype)
    q99 = np.asarray(stats["q99"]).astype(normalized_actions.dtype)
    return (normalized_actions + 1) / 2 * (q99 - q01) + q01


class StarvlaInferenceServer:
    """Flask server exposing a StarVLA policy over HTTP.

    Routes:
        POST /policy/inference: pickled observation dict in; pickled action
            dict out (or a pickled ``{"error", "traceback"}`` dict with 500).
        GET /policy/health: "OK" when the model attribute is set.
    """

    # Per-arm layout of one action row:
    # 3 position deltas + 6 rot6d deltas + 22 finger values.
    _ARM_DIM = 31
    _ACTION_DIM = 2 * _ARM_DIM

    def __init__(self, config_path: str):
        """Read the YAML config, load the model and register Flask routes.

        Expects ``policy_server.ckpt_source``, ``policy_server.ckpt_path``
        and ``general.root_paths`` in the config. ``host``, ``port`` and
        ``use_bf16`` are optional (defaults: "0.0.0.0", 5000, True).
        """
        with open(config_path, "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f)
        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]

        # NOTE(review): ckpt_source is stored but not read anywhere in this file.
        self.ckpt_source = policy_server_cfg["ckpt_source"]
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)

        print("Loading StarVLA model...")
        self.model = self.load_model()
        print("Model loaded.")

        self.app = Flask(__name__)
        self.register_routes()

    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        """Turn ``scheme://relative/path`` into a local checkpoint path.

        A value without a URL scheme is returned unchanged. Otherwise the
        scheme selects a root directory from ``root_paths``, and netloc + path
        are joined under it.

        Raises:
            KeyError: if ``root_paths`` has no non-empty entry for the scheme.
        """
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            return ckpt_url
        root = root_paths.get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find the checkpoint root path in root_paths: {root_paths}"
            )
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)

    def load_model(self):
        """Build the StarVLA framework from the checkpoint config and load weights.

        Also stores the checkpoint's normalization statistics on
        ``self.norm_stats``; they are used to (un)normalize states/actions.
        """
        # Import from the package itself: importing "<pkg>.__init__" by name
        # creates a second module object and re-runs the package's init code.
        from starVLA.model.framework import build_framework
        from starVLA.model.framework.share_tools import (
            dict_to_namespace,
            read_mode_config,
        )

        model_config, norm_stats = read_mode_config(self.ckpt_path)
        cfg = dict_to_namespace(model_config)
        # Weights are loaded explicitly from self.ckpt_path below; presumably
        # this stops build_framework from loading another checkpoint.
        cfg.trainer.pretrained_checkpoint = None
        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats

        if self.use_bf16:
            model = model.to(torch.bfloat16)
        model = model.eval()

        # SECURITY: torch.load unpickles the checkpoint file; load trusted files only.
        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model = model.to("cuda")

        self.norm_stats = norm_stats
        return model

    def parse_observation(self, obs, target_size=(320, 180)):
        """Extract model inputs from a raw observation dict.

        Args:
            obs: Dict providing ``obs["rgb"]["head_camera"]`` (an image array
                accepted by ``cv2.resize``), ``obs["state"]`` and
                ``obs["prompt"]``.
            target_size: ``(width, height)`` passed to ``cv2.resize``.

        Returns:
            Tuple ``(head image as PIL.Image, normalized state array, prompt)``.
        """
        head_rgb = obs["rgb"]["head_camera"]
        img_head = Image.fromarray(cv2.resize(head_rgb, target_size))
        state_vec = normalize_states(obs["state"], self.norm_stats)
        return img_head, state_vec, obs["prompt"]

    @staticmethod
    def _arm_action_dict(arm_actions: np.ndarray) -> dict:
        """Split one arm's ``(chunk_len, 31)`` action block into named lists."""
        return {
            "ee_delta_position_chunks": arm_actions[:, :3].tolist(),
            "ee_delta_rot6d_chunks": arm_actions[:, 3:9].tolist(),
            "finger_chunks": arm_actions[:, 9:31].tolist(),
        }

    def inference(self, observation: dict) -> dict:
        """Run one policy step and return the un-normalized action chunk per arm.

        Returns:
            ``{"left_arm": {...}, "right_arm": {...}}``. Each arm dict has
            ``ee_delta_position_chunks`` (3 values per step),
            ``ee_delta_rot6d_chunks`` (6) and ``finger_chunks`` (22) as nested
            lists. The left arm uses action columns 0-30, the right 31-61.

        Raises:
            KeyError: if the model output lacks ``"normalized_actions"``.
            ValueError: if the actions are not ``(chunk_len, >=62)`` after
                dropping a leading batch dimension.
        """
        img_head, state_vec, prompt = self.parse_observation(observation)

        vla_input = {
            "image": [img_head],
            "lang": prompt,
            "state": state_vec[None, :],  # add a leading batch dimension
        }

        with torch.no_grad():
            output = self.model.predict_action(examples=vla_input)

        actions = output.get("normalized_actions")
        if actions is None:
            raise KeyError("model output has no 'normalized_actions'")
        if isinstance(actions, torch.Tensor):
            actions = actions.detach()
            # NumPy has no bfloat16 dtype; Tensor.numpy() would raise on it.
            if actions.dtype == torch.bfloat16:
                actions = actions.float()
            actions = actions.cpu().numpy()
        actions = np.asarray(actions)
        if actions.ndim == 3:
            actions = actions[0]  # drop batch dim -> (chunk_len, action_dim)
        if actions.ndim != 2 or actions.shape[1] < self._ACTION_DIM:
            raise ValueError(
                f"expected actions of shape (chunk_len, >={self._ACTION_DIM}), "
                f"got {actions.shape}"
            )
        actions = unnormalize_actions(actions, self.norm_stats)

        arm = self._ARM_DIM
        return {
            "left_arm": self._arm_action_dict(actions[:, :arm]),
            "right_arm": self._arm_action_dict(actions[:, arm:2 * arm]),
        }

    def register_routes(self):
        """Register the inference and health endpoints on ``self.app``."""

        @self.app.route("/policy/inference", methods=["POST"])
        def policy():
            # SECURITY: pickle.loads can execute arbitrary code embedded in the
            # request body. Expose this endpoint only to trusted clients.
            try:
                data = pickle.loads(request.data)
                result = self.inference(data)
                body = pickle.dumps(result, protocol=4)
                return Response(body, mimetype="application/octet-stream")
            except Exception as e:
                # Request boundary: log server-side, then return the error to
                # the client as a pickled dict with HTTP 500.
                tb = traceback.format_exc()
                print(tb, flush=True)
                err_obj = {
                    "error": str(e),
                    "traceback": tb,
                }
                body = pickle.dumps(err_obj, protocol=4)
                return Response(
                    body,
                    mimetype="application/octet-stream",
                    status=500,
                )

        @self.app.route("/policy/health", methods=["GET"])
        def health():
            # NOTE(review): bodies are plain text despite the JSON mimetype.
            if self.model is None:
                return Response(
                    "Failed to load model", mimetype="application/json", status=500
                )
            return Response("OK", mimetype="application/json")

    def run(self):
        """Start Flask's built-in server on the configured host and port."""
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")
        # NOTE(review): threaded=True lets concurrent requests call the same
        # model object; serialize access if clients can overlap.
        self.app.run(host=self.host, port=self.port, threaded=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="StarVLA inference server")
    parser.add_argument(
        "--config",
        type=str,
        default="./benchmark.yaml",
        help="Path to benchmark.yaml",
    )
    args = parser.parse_args()

    server = StarvlaInferenceServer(args.config)
    server.run()