179 lines
5.9 KiB
Python
179 lines
5.9 KiB
Python
|
|
import yaml
|
|
import pickle
|
|
import argparse
|
|
import os
|
|
import traceback
|
|
from urllib.parse import urlparse
|
|
import numpy as np
|
|
import torch
|
|
import cv2
|
|
|
|
from flask import Flask, request, Response
|
|
from PIL import Image
|
|
|
|
def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
    """Right-pad ``x`` along ``axis`` with ``value`` until it has ``target_dim`` entries.

    Args:
        x: Array to pad. Array-likes (e.g. a plain list) are converted with
            ``np.asarray``; an existing ndarray is used without copying.
        target_dim: Desired size of ``axis`` after padding.
        axis: Axis to pad; negative values count from the end.
        value: Constant written into the padded positions (default 0.0).

    Returns:
        The padded array. If ``axis`` already has ``target_dim`` or more
        entries, the (ndarray) input is returned unchanged -- longer inputs
        are NOT truncated.
    """
    x = np.asarray(x)
    current_dim = x.shape[axis]
    if current_dim >= target_dim:
        return x
    # Pad only the trailing side of the requested axis; all other axes untouched.
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (0, target_dim - current_dim)
    return np.pad(x, pad_width, constant_values=value)
|
|
|
|
class StarvlaInferenceServer:
    """HTTP inference server for a StarVLA policy checkpoint.

    Reads the ``policy_server`` section and ``general.root_paths`` of a YAML
    config, loads the model onto CUDA at construction time, and exposes two
    Flask routes:

    * ``POST /policy/inference`` -- request body is a pickled observation
      dict; response body is a pickled dict of action chunks, or a pickled
      ``{"error": ..., "traceback": ...}`` dict with HTTP 500 on failure.
    * ``GET /policy/health`` -- returns ``OK`` once the model is loaded.
    """

    def __init__(self, config_path: str):
        """Read the config, load the model and register the Flask routes.

        Args:
            config_path: Path to a YAML file containing a ``policy_server``
                section (``ckpt_source``, ``ckpt_path`` and optional ``host``,
                ``port``, ``use_bf16``, ``unnorm_key``, ``state_mode``) and a
                ``general.root_paths`` mapping.
        """
        with open(config_path, "r") as f:
            cfg = yaml.safe_load(f)

        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]

        self.ckpt_source = policy_server_cfg["ckpt_source"]
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)
        self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
        self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")

        print("Loading StarVLA model...")
        self.model = self.load_model()
        print("Model loaded.")

        self.app = Flask(__name__)
        self.register_routes()

    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        """Map a checkpoint URL onto a local filesystem path.

        A value without a URL scheme is returned unchanged. Otherwise the
        scheme selects a root directory from ``root_paths`` and the rest of
        the URL (``netloc`` + ``path``, leading slashes stripped) is joined
        onto it, e.g. ``"models://run1/ckpt.pt"`` with ``{"models": "/data"}``
        resolves to ``"/data/run1/ckpt.pt"``.

        Raises:
            KeyError: if the scheme has no non-empty entry in ``root_paths``.
        """
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            return ckpt_url
        root = root_paths.get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find the checkpoint root path for scheme "
                f"'{parsed.scheme}' in root_paths: {root_paths}"
            )
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)

    def load_model(self):
        """Build the StarVLA framework from the checkpoint config and load its weights.

        Side effects: sets ``self.norm_stats`` and ``self.action_norm_stats``
        (the ``"action"`` entry under ``self.unnorm_key``, or ``None``).

        Returns:
            The model in eval mode on ``"cuda"``, cast to bfloat16 first when
            ``self.use_bf16`` is true.
        """
        from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
        from starVLA.model.framework.__init__ import build_framework

        model_config, norm_stats = read_mode_config(self.ckpt_path)

        cfg = dict_to_namespace(model_config)
        # Weights are loaded explicitly from self.ckpt_path below; presumably
        # clearing this stops build_framework from loading other weights.
        cfg.trainer.pretrained_checkpoint = None

        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats

        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

        if self.use_bf16:
            model = model.to(torch.bfloat16)
        model = model.to("cuda").eval()

        self.norm_stats = norm_stats
        # NOTE(review): action_norm_stats is not applied anywhere in this
        # class; inference returns the model's "normalized_actions" as-is.
        self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)

        return model

    def parse_observation(self, obs, target_size=(320, 180)):
        """Split an observation dict into resized camera images, state and prompt.

        Args:
            obs: Dict with ``obs["rgb"]`` holding ``"Left_Camera"``,
                ``"Right_Camera"`` and ``"Hand_Camera"`` images (arrays accepted
                by ``cv2.resize`` / ``Image.fromarray``; presumably HxWx3 uint8
                RGB -- TODO confirm), plus ``"state"`` and ``"prompt"`` entries.
            target_size: ``dsize`` passed to ``cv2.resize``, i.e.
                ``(width, height)``.

        Returns:
            ``(img_left, img_right, img_wrist, state_vec, prompt)``: three PIL
            images followed by ``obs["state"]`` and ``obs["prompt"]`` unchanged.
        """
        rgb = obs["rgb"]
        left_rgb = rgb["Left_Camera"]
        right_rgb = rgb["Right_Camera"]
        wrist_rgb = rgb["Hand_Camera"]

        img_left = Image.fromarray(cv2.resize(left_rgb, target_size))
        img_right = Image.fromarray(cv2.resize(right_rgb, target_size))
        img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))

        # The state vector is forwarded without padding or truncation.
        state_vec = obs["state"]
        return img_left, img_right, img_wrist, state_vec, obs["prompt"]

    def inference(self, observation: dict) -> dict:
        """Run the policy on one observation and return its action chunk.

        Args:
            observation: Observation dict in the format expected by
                :meth:`parse_observation`.

        Returns:
            Dict with ``"ee_delta_position_chunks"`` (action columns 0-2),
            ``"ee_delta_rot6d_chunks"`` (columns 3-8) and
            ``"gripper_width_chunks"`` (column 9), each a nested Python list
            with one row per predicted step.

        Raises:
            KeyError: if the model output has no ``"normalized_actions"``.
        """
        img_left, img_right, img_wrist, state_vec, prompt = self.parse_observation(observation)

        # Batch of size 1: one list of three views, one instruction, one state.
        vla_input = {
            "batch_images": [[img_left, img_right, img_wrist]],
            "instructions": [prompt],
            "state": [state_vec],
        }

        # NOTE(review): the Flask app runs with threaded=True, so concurrent
        # requests may call predict_action concurrently; no lock is taken here.
        with torch.no_grad():
            output = self.model.predict_action(**vla_input)

        actions = output.get("normalized_actions")
        if actions is None:
            raise KeyError("model output has no 'normalized_actions' entry")
        return self._format_actions(actions)

    @staticmethod
    def _format_actions(actions) -> dict:
        """Convert a (1, T, >=10) or (T, >=10) action array into the response dict."""
        if isinstance(actions, torch.Tensor):
            # NumPy has no bfloat16 dtype (the default use_bf16=True path), so
            # upcast before converting; other dtypes are converted unchanged.
            if actions.dtype == torch.bfloat16:
                actions = actions.float()
            actions = actions.detach().cpu().numpy()
        else:
            actions = np.asarray(actions)

        if actions.ndim == 3:
            # Drop the batch dimension (e.g. (1, 16, 10) -> (16, 10)).
            actions = actions[0]

        return {
            "ee_delta_position_chunks": actions[:, :3].tolist(),
            "ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
            "gripper_width_chunks": actions[:, 9:10].tolist(),
        }

    def register_routes(self):
        """Attach the ``/policy/inference`` and ``/policy/health`` routes to ``self.app``."""

        @self.app.route("/policy/inference", methods=["POST"])
        def policy():
            # SECURITY: pickle.loads can execute arbitrary code embedded in the
            # request body. Only expose this endpoint to trusted clients and
            # networks (the default host "0.0.0.0" listens on all interfaces).
            try:
                data = pickle.loads(request.data)
                result = self.inference(data)
                body = pickle.dumps(result, protocol=4)
                return Response(body, mimetype="application/octet-stream")
            except Exception as e:
                # Request boundary: return any failure to the client as a
                # pickled error dict with HTTP 500.
                err_obj = {
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                }
                body = pickle.dumps(err_obj, protocol=4)
                return Response(
                    body,
                    mimetype="application/octet-stream",
                    status=500,
                )

        @self.app.route("/policy/health", methods=["GET"])
        def health():
            # NOTE(review): load_model raises on failure, so self.model is not
            # None here in practice; the bodies are plain text despite the
            # application/json mimetype.
            if self.model is None:
                return Response("Failed to load model", mimetype="application/json", status=500)
            return Response("OK", mimetype="application/json")

    def run(self):
        """Start the Flask server on the configured host and port (threaded)."""
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")

        self.app.run(
            host=self.host,
            port=self.port,
            threaded=True,
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: read --config, build the server and run it.
    cli = argparse.ArgumentParser(description="StarVLA inference server")
    cli.add_argument(
        "--config",
        type=str,
        default="./benchmark.yaml",
        help="Path to benchmark.yaml",
    )
    cli_args = cli.parse_args()
    StarvlaInferenceServer(cli_args.config).run()