213 lines
7.0 KiB
Python
213 lines
7.0 KiB
Python
|
|
import yaml
|
|
import pickle
|
|
import argparse
|
|
import os
|
|
import traceback
|
|
from urllib.parse import urlparse
|
|
import numpy as np
|
|
import torch
|
|
import cv2
|
|
|
|
from flask import Flask, request, Response
|
|
from PIL import Image
|
|
|
|
def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
    """Right-pad ``x`` along ``axis`` until that axis has ``target_dim`` entries.

    Args:
        x: Array to pad.
        target_dim: Desired size of ``axis`` after padding.
        axis: Axis to extend; negative values count from the last axis.
        value: Constant used to fill the newly added entries.

    Returns:
        A padded copy when ``x.shape[axis] < target_dim``; otherwise ``x``
        itself, unchanged (the array is never truncated).
    """
    missing = target_dim - x.shape[axis]
    if missing <= 0:
        # Already large enough: hand back the very same object.
        return x

    # Pad only at the end of the requested axis; leave every other axis alone.
    widths = [(0, 0) for _ in range(x.ndim)]
    widths[axis] = (0, missing)
    return np.pad(x, widths, constant_values=value)
|
|
|
|
|
|
def normalize_states(states, statistics):
    """Scale a state vector into ``[-1, 1]`` using the q01/q99 statistics.

    Each dimension is mapped with ``2 * (x - q01) / (q99 - q01) - 1``.
    Dimensions where ``q01 == q99`` would divide by zero, so they keep their
    original value. The whole result is finally clipped to ``[-1, 1]``
    (this clipping also applies to the pass-through dimensions).

    Args:
        states: Array-like whose last axis matches the length of the
            statistics vectors. Lists and integer arrays are accepted and are
            promoted to float64 so the scaled values are not truncated;
            floating-point arrays keep their dtype.
        statistics: Mapping that provides
            ``statistics["new_embodiment"]["state"]["q01"]`` and ``["q99"]``.

    Returns:
        A new ``np.ndarray`` with the same shape as ``states``.
    """
    states = np.asarray(states)
    if not np.issubdtype(states.dtype, np.floating):
        # Integer (or bool) inputs would otherwise force integer arithmetic
        # on the output buffer and silently truncate e.g. 0.5 to 0.
        states = states.astype(np.float64)

    stats = statistics["new_embodiment"]["state"]
    q01 = np.asarray(stats["q01"]).astype(states.dtype)
    q99 = np.asarray(stats["q99"]).astype(states.dtype)

    # Only dimensions with a non-degenerate range can be rescaled.
    mask = q01 != q99

    # Start from a copy so degenerate dimensions keep their original value.
    normalized = states.copy()
    scaled = (states[..., mask] - q01[..., mask]) / (q99[..., mask] - q01[..., mask])
    normalized[..., mask] = 2 * scaled - 1

    return np.clip(normalized, -1, 1)
|
|
|
|
|
|
def unnormalize_actions(normalized_actions, statistics):
    """Map actions from the normalized ``[-1, 1]`` range back to raw units.

    Applies ``(a + 1) / 2 * (q99 - q01) + q01`` element-wise, with the
    statistics broadcast against ``normalized_actions``.

    Args:
        normalized_actions: NumPy array of normalized actions; its dtype is
            used for the statistics vectors as well.
        statistics: Mapping that provides
            ``statistics["new_embodiment"]["action"]["q01"]`` and ``["q99"]``.

    Returns:
        Array of un-normalized actions.
    """
    dtype = normalized_actions.dtype
    action_stats = statistics["new_embodiment"]["action"]
    low = np.array(action_stats["q01"]).astype(dtype)
    high = np.array(action_stats["q99"]).astype(dtype)

    # Shift [-1, 1] to [0, 1], then stretch onto [q01, q99].
    unit = (normalized_actions + 1) / 2
    return unit * (high - low) + low
|
|
|
|
class StarvlaInferenceServer:
    """HTTP policy server that wraps a StarVLA model behind a Flask app.

    The server reads its settings from a YAML config, loads the model onto
    CUDA once at construction time and exposes two routes:

    * ``POST /policy/inference`` - pickled observation in, pickled actions out.
    * ``GET  /policy/health``    - liveness probe.
    """

    def __init__(self, config_path: str):
        """Read the YAML config, load the model and register the routes.

        Args:
            config_path: Path to a YAML file with a ``policy_server`` section
                (``ckpt_source``, ``ckpt_path`` and optional ``host``,
                ``port``, ``use_bf16``) and a ``general.root_paths`` mapping.
        """
        with open(config_path, "r") as f:
            cfg = yaml.safe_load(f)

        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]
        self.ckpt_source = policy_server_cfg["ckpt_source"]
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)

        print("Loading StarVLA model...")
        # load_model() also stores the normalization statistics on self.
        self.model = self.load_model()
        print("Model loaded.")

        self.app = Flask(__name__)
        self.register_routes()

    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        """Turn a ``scheme://relative/path`` checkpoint URL into a local path.

        Args:
            ckpt_url: Plain path (returned unchanged) or a URL whose scheme
                names an entry of ``root_paths``.
            root_paths: Mapping from scheme name to root directory.

        Returns:
            ``ckpt_url`` itself when it has no scheme, otherwise the root
            directory joined with the URL's netloc and path.

        Raises:
            KeyError: If the scheme has no (non-empty) entry in ``root_paths``.
        """
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            return ckpt_url
        root = root_paths.get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find the checkpoint root path in root_paths: {root_paths}"
            )
        # urlparse puts the first path component into netloc; glue it back on.
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)

    def load_model(self):
        """Build the StarVLA framework, load the checkpoint and move it to CUDA.

        Side effect: stores the normalization statistics read next to the
        checkpoint on ``self.norm_stats``.

        Returns:
            The model in eval mode on the ``cuda`` device (cast to bfloat16
            first when ``self.use_bf16`` is set).
        """
        # Imported lazily so the module can be imported without starVLA.
        from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
        # NOTE(review): importing from ``__init__`` explicitly is unusual;
        # confirm whether ``from starVLA.model.framework import ...`` would do.
        from starVLA.model.framework.__init__ import build_framework

        model_config, norm_stats = read_mode_config(self.ckpt_path)

        cfg = dict_to_namespace(model_config)
        # Weights come from the state dict loaded below, not from the config.
        cfg.trainer.pretrained_checkpoint = None

        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats

        if self.use_bf16:
            model = model.to(torch.bfloat16)
        model = model.eval()

        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model = model.to("cuda")

        self.norm_stats = norm_stats
        return model

    def parse_observation(self, obs, target_size=(320, 180)):
        """Extract the model inputs from a raw observation dict.

        Args:
            obs: Mapping with ``obs["rgb"]["head_camera"]`` (image array),
                ``obs["state"]`` (array-like state vector) and
                ``obs["prompt"]``.
            target_size: Size handed to ``cv2.resize`` as its ``dsize``
                argument, i.e. ``(width, height)``.

        Returns:
            Tuple ``(head image as PIL.Image, normalized state array, prompt)``.
        """
        head_rgb = obs["rgb"]["head_camera"]
        img_head = Image.fromarray(cv2.resize(head_rgb, target_size))

        # The state may arrive as a plain list after unpickling; the code
        # downstream needs ndarray semantics (.dtype, [None, :]).
        state_vec = normalize_states(np.asarray(obs["state"]), self.norm_stats)
        return img_head, state_vec, obs["prompt"]

    @staticmethod
    def _actions_to_numpy(actions):
        """Convert the model's action output to a 2-D NumPy array.

        Torch tensors are detached and moved to the CPU. bfloat16 tensors are
        upcast to float32 first because NumPy has no bfloat16 dtype. A leading
        batch axis (3-D input) is dropped by taking the first element.
        """
        if isinstance(actions, torch.Tensor):
            if actions.dtype == torch.bfloat16:
                actions = actions.float()
            actions = actions.detach().cpu().numpy()
        else:
            actions = np.asarray(actions)

        if actions.ndim == 3:
            actions = actions[0]
        return actions

    @staticmethod
    def _split_action_chunks(actions) -> dict:
        """Split an action chunk into the per-arm response dict.

        Columns are sliced as 0:3 / 3:9 / 9:31 for the left arm and
        31:34 / 34:40 / 40:62 for the right arm (position, rot6d, fingers).
        NOTE(review): arrays with fewer than 62 columns yield short or empty
        lists instead of an error - confirm whether that should be rejected.
        """
        return {
            "left_arm": {
                "ee_delta_position_chunks": actions[:, :3].tolist(),
                "ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
                "finger_chunks": actions[:, 9:31].tolist(),
            },
            "right_arm": {
                "ee_delta_position_chunks": actions[:, 31:34].tolist(),
                "ee_delta_rot6d_chunks": actions[:, 34:40].tolist(),
                "finger_chunks": actions[:, 40:62].tolist(),
            },
        }

    def inference(self, observation: dict) -> dict:
        """Run one policy step and return un-normalized action chunks.

        Args:
            observation: Raw observation dict, see ``parse_observation``.

        Returns:
            Dict with ``left_arm`` and ``right_arm`` entries, each holding
            ``ee_delta_position_chunks``, ``ee_delta_rot6d_chunks`` and
            ``finger_chunks`` as nested lists.

        Raises:
            KeyError: If the model output has no ``normalized_actions`` entry.
        """
        img_head, state_vec, prompt = self.parse_observation(
            observation, target_size=(410, 224)
        )
        vla_input = {
            "image": [img_head],
            "lang": prompt,
            # Add a leading batch axis; presumably (1, 62) - TODO confirm.
            "state": state_vec[None, :],
        }

        with torch.no_grad():
            output = self.model.predict_action(examples=vla_input)

        actions = output.get("normalized_actions")
        if actions is None:
            # Fail with a clear message instead of AttributeError on None.
            raise KeyError("model output does not contain 'normalized_actions'")

        actions = self._actions_to_numpy(actions)
        actions = unnormalize_actions(actions, self.norm_stats)
        return self._split_action_chunks(actions)

    def register_routes(self):
        """Attach the inference and health endpoints to ``self.app``."""

        @self.app.route("/policy/inference", methods=["POST"])
        def policy():
            try:
                # SECURITY: pickle.loads executes arbitrary code embedded in
                # the payload. Only expose this endpoint to trusted clients.
                data = pickle.loads(request.data)
                result = self.inference(data)
                body = pickle.dumps(result, protocol=4)
                return Response(body, mimetype="application/octet-stream")
            except Exception as e:
                # Top-level boundary: report the failure to the client as a
                # pickled error object with HTTP 500 instead of crashing.
                err_obj = {
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                }
                body = pickle.dumps(err_obj, protocol=4)
                return Response(
                    body,
                    mimetype="application/octet-stream",
                    status=500,
                )

        @self.app.route("/policy/health", methods=["GET"])
        def health():
            # NOTE(review): the bodies are plain text although the mimetype
            # says JSON; kept as-is because clients may match on the text.
            if self.model is None:
                return Response("Failed to load model", mimetype="application/json", status=500)
            return Response("OK", mimetype="application/json")

    def run(self):
        """Print the bind address and start the (blocking) Flask server."""
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")

        self.app.run(
            host=self.host,
            port=self.port,
            threaded=True,
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: read the config path, build the server
    # (which loads the model) and serve until the process is stopped.
    cli = argparse.ArgumentParser(description="StarVLA inference server")
    cli.add_argument("--config", type=str, default="./benchmark.yaml", help="Path to benchmark.yaml")
    config_path = cli.parse_args().config

    StarvlaInferenceServer(config_path).run()
|