Files
starvla_benchmark/starvla_inference_server.py
2026-05-22 18:38:16 +08:00

213 lines
7.0 KiB
Python

import yaml
import pickle
import argparse
import os
import traceback
from urllib.parse import urlparse
import numpy as np
import torch
import cv2
from flask import Flask, request, Response
from PIL import Image
def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
    """Pad ``x`` at the end of ``axis`` so that axis has ``target_dim`` entries.

    The new entries are filled with ``value`` (0.0 by default). If ``x`` already
    has at least ``target_dim`` entries along ``axis``, the very same array object
    is returned: no copy and no truncation.
    """
    missing = target_dim - x.shape[axis]
    if missing <= 0:
        return x
    # Only the chosen axis grows, and only on its trailing side.
    widths = [(0, 0) for _ in range(x.ndim)]
    widths[axis] = (0, missing)
    return np.pad(x, widths, constant_values=value)
def normalize_states(states, statistics):
    """Map raw states into [-1, 1] using per-dimension q01/q99 quantiles.

    Args:
        states: Array-like of raw states whose last axis indexes state
            dimensions (assumed to line up with the q01/q99 vectors; TODO
            confirm). Lists and integer arrays are accepted and promoted to
            float32. Floating arrays keep their dtype.
        statistics: Normalization statistics. Values are read from
            ``statistics["new_embodiment"]["state"]["q01"]`` and ``["q99"]``.

    Returns:
        Array with the same shape as ``states``. Dimensions where q01 != q99
        become ``2 * (x - q01) / (q99 - q01) - 1``. Dimensions where q01 == q99
        keep the raw value. Every entry is then clipped to [-1, 1].
    """
    states = np.asarray(states)
    # With an integer dtype, ``zeros_like`` would give an integer buffer and the
    # quantiles would be truncated, so the ratio below would be silently floored.
    if not np.issubdtype(states.dtype, np.floating):
        states = states.astype(np.float32)

    stats = statistics["new_embodiment"]["state"]
    q01 = np.asarray(stats["q01"]).astype(states.dtype)
    q99 = np.asarray(stats["q99"]).astype(states.dtype)

    # The affine map is undefined where q01 == q99; those dims pass through raw.
    mask = q01 != q99
    normalized = np.zeros_like(states)
    ratio = (states[..., mask] - q01[..., mask]) / (q99[..., mask] - q01[..., mask])
    normalized[..., mask] = 2 * ratio - 1
    normalized[..., ~mask] = states[..., ~mask]

    # Clipping also applies to the pass-through (q01 == q99) dimensions.
    return np.clip(normalized, -1, 1)
def unnormalize_actions(normalized_actions, statistics):
    """Invert the [-1, 1] quantile normalization for actions.

    Computes ``(a + 1) / 2 * (q99 - q01) + q01``. The input is not clipped
    first, so values outside [-1, 1] extrapolate past the quantiles.

    Args:
        normalized_actions: Array-like whose last axis indexes action
            dimensions. Lists and integer arrays are promoted to float32.
            Floating arrays keep their dtype.
        statistics: Normalization statistics. Values are read from
            ``statistics["new_embodiment"]["action"]["q01"]`` and ``["q99"]``.

    Returns:
        Array of un-normalized actions with the same shape as the input.
    """
    normalized_actions = np.asarray(normalized_actions)
    # Casting the quantiles to an integer dtype would truncate them.
    if not np.issubdtype(normalized_actions.dtype, np.floating):
        normalized_actions = normalized_actions.astype(np.float32)

    stats = statistics["new_embodiment"]["action"]
    q01 = np.asarray(stats["q01"]).astype(normalized_actions.dtype)
    q99 = np.asarray(stats["q99"]).astype(normalized_actions.dtype)
    return (normalized_actions + 1) / 2 * (q99 - q01) + q01
class StarvlaInferenceServer:
    """Flask HTTP server that serves a StarVLA checkpoint as a policy.

    Configuration comes from the ``policy_server`` and ``general`` sections of a
    YAML file. Requests to and responses from ``/policy/inference`` are
    pickle-encoded.
    """

    def __init__(self, config_path: str):
        """Read the YAML config, load the model, and register the HTTP routes.

        Args:
            config_path: Path to a YAML file with a ``policy_server`` section
                (required: ``ckpt_source``, ``ckpt_path``; optional: ``host``,
                ``port``, ``use_bf16``) and a ``general.root_paths`` mapping.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f)
        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]
        self.ckpt_source = policy_server_cfg["ckpt_source"]
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)
        print("Loading StarVLA model...")
        self.model = self.load_model()
        print("Model loaded.")
        self.app = Flask(__name__)
        self.register_routes()

    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        """Turn a ``scheme://relative/path`` checkpoint URL into a local path.

        A value without a URL scheme is returned unchanged. Otherwise the scheme
        is looked up in ``root_paths``. ``netloc + path``, with leading slashes
        stripped, is joined onto that root.

        Raises:
            KeyError: if ``root_paths`` is empty/None or has no truthy entry for
                the scheme.
        """
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            return ckpt_url
        root = (root_paths or {}).get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find the checkpoint root path for scheme "
                f"'{parsed.scheme}' in root_paths: {root_paths}"
            )
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)

    def load_model(self):
        """Build the StarVLA framework from the checkpoint config and load its weights.

        Side effect: stores the checkpoint's normalization statistics in
        ``self.norm_stats``.

        Returns:
            The model in eval mode, moved to CUDA and cast to bfloat16 when
            ``self.use_bf16`` is true.
        """
        # Import from the package itself, not ``...framework.__init__``. Naming
        # the ``__init__`` submodule explicitly executes the package's
        # __init__.py a second time under a different module name.
        from starVLA.model.framework import build_framework
        from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace

        model_config, norm_stats = read_mode_config(self.ckpt_path)
        cfg = dict_to_namespace(model_config)
        # Presumably stops the framework from loading separate pretrained
        # weights, since the full state dict is loaded below. Confirm.
        cfg.trainer.pretrained_checkpoint = None
        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats
        if self.use_bf16:
            model = model.to(torch.bfloat16)
        model = model.eval()
        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model = model.to("cuda")
        self.norm_stats = norm_stats
        return model

    def parse_observation(self, obs, target_size=(320, 180)):
        """Extract the head-camera image, normalized state, and prompt from ``obs``.

        Args:
            obs: Dict providing ``obs["rgb"]["head_camera"]`` (an image array
                accepted by ``cv2.resize`` and ``PIL.Image.fromarray``),
                ``obs["state"]`` (raw state vector) and ``obs["prompt"]``.
            target_size: ``(width, height)`` passed to ``cv2.resize``.

        Returns:
            Tuple ``(PIL image, normalized state array, prompt)``.
        """
        head_rgb = obs["rgb"]["head_camera"]
        img_head = Image.fromarray(cv2.resize(head_rgb, target_size))
        # Pickled client payloads may carry the state as a plain list.
        # normalize_states needs an ndarray with a floating dtype.
        state = np.asarray(obs["state"])
        if not np.issubdtype(state.dtype, np.floating):
            state = state.astype(np.float32)
        state_vec = normalize_states(state, self.norm_stats)
        # state_vec = pad_to_dim(np.array(state_vec), 100, axis=-1)
        return img_head, state_vec, obs["prompt"]

    @staticmethod
    def _split_arm(arm_actions):
        """Split one arm's 31 action columns into position/rot6d/finger nested lists."""
        return {
            "ee_delta_position_chunks": arm_actions[:, 0:3].tolist(),
            "ee_delta_rot6d_chunks": arm_actions[:, 3:9].tolist(),
            "finger_chunks": arm_actions[:, 9:31].tolist(),
        }

    def inference(self, observation: dict) -> dict:
        """Run one policy step and split the predicted action chunk per arm.

        Args:
            observation: Observation dict in the format read by
                ``parse_observation``.

        Returns:
            ``{"left_arm": {...}, "right_arm": {...}}``. Each arm maps
            ``ee_delta_position_chunks`` (3 columns), ``ee_delta_rot6d_chunks``
            (6 columns) and ``finger_chunks`` (22 columns) to nested lists with
            one row per action step. Columns 0-30 belong to the left arm and
            31-61 to the right arm.

        Raises:
            ValueError: if the model output has no ``"normalized_actions"``.
        """
        img_head, state_vec, prompt = self.parse_observation(observation)
        vla_input = {
            # "batch_images": [[img_left, img_right, img_wrist]],
            "image": [img_head],
            "lang": prompt,
            "state": state_vec[None, :],  # add a leading batch dim of size 1
        }
        with torch.no_grad():
            output = self.model.predict_action(examples=vla_input)

        actions = output.get("normalized_actions")
        if actions is None:
            raise ValueError(
                "model output has no 'normalized_actions'; got keys: "
                f"{list(output)}"
            )
        if isinstance(actions, torch.Tensor):
            actions = actions.detach().cpu()
            # NumPy has no bfloat16, so .numpy() raises on bf16 tensors (a
            # likely output dtype when use_bf16 is enabled). Upcast first.
            if actions.dtype == torch.bfloat16:
                actions = actions.float()
            actions = actions.numpy()
        actions = np.asarray(actions)
        if actions.ndim == 3:
            actions = actions[0]  # drop the batch dimension
        actions = unnormalize_actions(actions, self.norm_stats)
        return {
            "left_arm": self._split_arm(actions[:, 0:31]),
            "right_arm": self._split_arm(actions[:, 31:62]),
        }

    def register_routes(self):
        """Register ``POST /policy/inference`` and ``GET /policy/health`` on ``self.app``."""

        @self.app.route("/policy/inference", methods=["POST"])
        def policy():
            # SECURITY: pickle.loads can execute arbitrary code embedded in the
            # request body. Expose this endpoint only to trusted clients; the
            # default host 0.0.0.0 listens on every interface.
            try:
                data = pickle.loads(request.data)
                result = self.inference(data)
                body = pickle.dumps(result, protocol=4)
                return Response(body, mimetype="application/octet-stream")
            except Exception as e:
                # Top-level boundary: log locally, then send the failure back
                # pickled like a normal reply so the client can unpack it.
                traceback.print_exc()
                err_obj = {
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                }
                body = pickle.dumps(err_obj, protocol=4)
                return Response(
                    body,
                    mimetype="application/octet-stream",
                    status=500,
                )

        @self.app.route("/policy/health", methods=["GET"])
        def health():
            # The bodies are plain text, not JSON.
            if self.model is None:
                return Response("Failed to load model", mimetype="text/plain", status=500)
            return Response("OK", mimetype="text/plain")

    def run(self):
        """Start Flask's built-in server on ``self.host``:``self.port`` (blocks)."""
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")
        # NOTE(review): threaded=True lets concurrent requests call self.model
        # at the same time. Confirm predict_action is thread-safe, or serialize.
        self.app.run(
            host=self.host,
            port=self.port,
            threaded=True,
        )
if __name__ == "__main__":
    # Command-line entry point: build the server from a YAML config and serve.
    cli = argparse.ArgumentParser(description="StarVLA inference server")
    cli.add_argument(
        "--config",
        type=str,
        default="./benchmark.yaml",
        help="Path to benchmark.yaml",
    )
    StarvlaInferenceServer(cli.parse_args().config).run()