# starvla_benchmark/starvla_inference_server.py
# StarVLA policy inference server (Flask, pickle-over-HTTP).
# Last touched: 2026-03-18
import yaml
import pickle
import os
from urllib.parse import urlparse
import numpy as np
import torch
import cv2
from flask import Flask, request, Response
from PIL import Image
class StarvlaInferenceServer:
def __init__(self, config_path: str):
with open(config_path, "r") as f:
cfg = yaml.safe_load(f)
policy_server_cfg = cfg["policy_server"]
root_paths = cfg["general"]["root_paths"]
self.ckpt_source = policy_server_cfg["ckpt_source"]
self.ckpt_path = self._resolve_ckpt_path(
ckpt_url=policy_server_cfg["ckpt_path"],
root_paths=root_paths,
)
self.host = policy_server_cfg.get("host", "0.0.0.0")
self.port = policy_server_cfg.get("port", 5000)
self.use_bf16 = policy_server_cfg.get("use_bf16", True)
self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")
print("Loading StarVLA model...")
self.model = self.load_model()
print("Model loaded.")
self.app = Flask(__name__)
self.register_routes()
@staticmethod
def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
parsed = urlparse(ckpt_url)
if not parsed.scheme:
return ckpt_url
root = root_paths.get(parsed.scheme)
if not root:
raise KeyError(
f"cannot find the checkpoint root path in root_paths: {root_paths}"
)
rel = (parsed.netloc + parsed.path).lstrip("/")
return os.path.join(root, rel)
def load_model(self):
from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
from starVLA.model.framework.__init__ import build_framework
model_config, norm_stats = read_mode_config(self.ckpt_path)
cfg = dict_to_namespace(model_config)
cfg.trainer.pretrained_checkpoint = None
model = build_framework(cfg=cfg)
model.norm_stats = norm_stats
state_dict = torch.load(self.ckpt_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
if self.use_bf16:
model = model.to(torch.bfloat16)
model = model.to("cuda").eval()
self.norm_stats = norm_stats
self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)
return model
def parse_observation(self, obs):
rgb = obs["rgb"][-1]
state = obs["state"][-1]
joint = obs.get("joint", None)
prompt = obs["prompt"]
left = rgb[:, :, :3]
right = rgb[:, :, 3:6]
wrist = rgb[:, :, 6:9]
target_size = (320, 180)
left = cv2.resize(left, target_size)
right = cv2.resize(right, target_size)
wrist = cv2.resize(wrist, target_size)
img_left = Image.fromarray(left)
img_right = Image.fromarray(right)
img_wrist = Image.fromarray(wrist)
if self.state_mode == "joint8":
joint_last = joint[-1]
gripper = state[9]
state_vec = np.concatenate(
[joint_last, np.array([gripper])],
axis=0
)
else:
xyz = state[0:3]
rot6d = state[3:9]
gripper = state[9]
state_vec = np.concatenate(
[xyz, rot6d[:3], np.array([gripper])],
axis=0
)
return img_left, img_right, img_wrist, state_vec, prompt
def inference(self, observation: dict) -> dict:
img_left, img_right, img_wrist, state_vec, prompt = \
self.parse_observation(observation)
vla_input = {
"batch_images": [[img_left, img_right, img_wrist]],
"instructions": [prompt],
"state": [state_vec]
}
with torch.no_grad():
output = self.model.predict_action(**vla_input)
actions = output.get("normalized_actions")
if isinstance(actions, torch.Tensor):
actions = actions.cpu().numpy()
if actions.ndim == 3:
actions = actions[0]
return {"action": actions.astype(np.float32)}
def register_routes(self):
@self.app.route("/policy", methods=["POST"])
def policy():
data = pickle.loads(request.data)
result = self.inference(data)
body = pickle.dumps(result, protocol=4)
return Response(body, mimetype="application/octet-stream")
def run(self):
print("StarVLA policy server running")
print(f"Host: {self.host}")
print(f"Port: {self.port}")
self.app.run(
host=self.host,
port=self.port,
threaded=True
)
if __name__ == "__main__":
    import sys

    # Optional CLI override: `python starvla_inference_server.py my_cfg.yaml`.
    # Falls back to the original hard-coded default when no argument given.
    config_path = sys.argv[1] if len(sys.argv) > 1 else "./benchmark.yaml"
    server = StarvlaInferenceServer(config_path)
    server.run()