import os
import pickle
from urllib.parse import urlparse

import cv2
import numpy as np
import torch
import yaml
from flask import Flask, request, Response
from PIL import Image


class StarvlaInferenceServer:
    """Flask server that exposes a StarVLA policy checkpoint over a /policy HTTP endpoint."""

    def __init__(self, config_path: str):
        with open(config_path, "r") as f:
            cfg = yaml.safe_load(f)

        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]
        self.ckpt_source = policy_server_cfg["ckpt_source"]
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)
        self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
        self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")

        print("Loading StarVLA model...")
        self.model = self.load_model()
        print("Model loaded.")

        self.app = Flask(__name__)
        self.register_routes()
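
    # benchmark.yaml sketch matching the keys read above. The values (and the
    # "hdfs" scheme under root_paths) are placeholders, not the project's
    # actual defaults:
    #
    #   general:
    #     root_paths:
    #       hdfs: /mnt/hdfs
    #   policy_server:
    #     ckpt_source: local
    #     ckpt_path: /path/to/checkpoint.pt
    #     host: 0.0.0.0
    #     port: 5000
    #     use_bf16: true
    #     unnorm_key: oxe_bridge
    #     state_mode: ee_pose7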

    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            return ckpt_url
        root = root_paths.get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find the checkpoint root path in root_paths: {root_paths}"
            )
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)
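
    # Resolution example (hypothetical values): with root_paths = {"hdfs": "/mnt/hdfs"},
    # "hdfs://checkpoints/starvla/model.pt" resolves to
    # "/mnt/hdfs/checkpoints/starvla/model.pt"; a plain filesystem path with no
    # URL scheme is returned unchanged.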

    def load_model(self):
        from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
        from starVLA.model.framework import build_framework

        # Rebuild the framework from the config associated with the checkpoint,
        # then load the trained weights.
        model_config, norm_stats = read_mode_config(self.ckpt_path)

        cfg = dict_to_namespace(model_config)
        cfg.trainer.pretrained_checkpoint = None

        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats

        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

        if self.use_bf16:
            model = model.to(torch.bfloat16)

        model = model.to("cuda").eval()

        # Cache the normalization statistics for the configured dataset key.
        self.norm_stats = norm_stats
        self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)

        return model

    def parse_observation(self, obs):
        rgb = obs["rgb"][-1]
        state = obs["state"][-1]
        joint = obs.get("joint", None)
        prompt = obs["prompt"]

        # The latest RGB frame stacks the left, right, and wrist camera views
        # along the channel axis (3 channels each).
        left = rgb[:, :, :3]
        right = rgb[:, :, 3:6]
        wrist = rgb[:, :, 6:9]

        target_size = (320, 180)

        left = cv2.resize(left, target_size)
        right = cv2.resize(right, target_size)
        wrist = cv2.resize(wrist, target_size)

        img_left = Image.fromarray(left)
        img_right = Image.fromarray(right)
        img_wrist = Image.fromarray(wrist)

        if self.state_mode == "joint8":
            # Joint-space state: last joint reading plus the gripper value.
            joint_last = joint[-1]
            gripper = state[9]

            state_vec = np.concatenate(
                [joint_last, np.array([gripper])],
                axis=0,
            )
        else:
            # End-effector state (ee_pose7): xyz, first three rotation
            # components, and the gripper value.
            xyz = state[0:3]
            rot6d = state[3:9]
            gripper = state[9]

            state_vec = np.concatenate(
                [xyz, rot6d[:3], np.array([gripper])],
                axis=0,
            )

        return img_left, img_right, img_wrist, state_vec, prompt
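
    # Expected observation contract (inferred from the slicing above; the exact
    # producer-side shapes are assumptions):
    #   obs["rgb"]:    sequence of frames, last frame H x W x 9 (left/right/wrist
    #                  views stacked along the channel axis)
    #   obs["state"]:  sequence of state vectors, last entry at least 10-dim
    #                  (xyz, 6D rotation, gripper at index 9)
    #   obs["joint"]:  sequence of joint vectors, required when state_mode == "joint8"
    #   obs["prompt"]: natural-language instruction string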

    def inference(self, observation: dict) -> dict:
        img_left, img_right, img_wrist, state_vec, prompt = \
            self.parse_observation(observation)

        vla_input = {
            "batch_images": [[img_left, img_right, img_wrist]],
            "instructions": [prompt],
            "state": [state_vec],
        }

        with torch.no_grad():
            output = self.model.predict_action(**vla_input)

        actions = output.get("normalized_actions")

        if isinstance(actions, torch.Tensor):
            actions = actions.cpu().numpy()

        # Drop the leading batch dimension if present.
        if actions.ndim == 3:
            actions = actions[0]

        return {"action": actions.astype(np.float32)}

    def register_routes(self):
        @self.app.route("/policy", methods=["POST"])
        def policy():
            # Wire format: pickled observation dict in, pickled {"action": ...} out
            # (see the client sketch at the bottom of this file).
            data = pickle.loads(request.data)
            result = self.inference(data)
            body = pickle.dumps(result, protocol=4)
            return Response(body, mimetype="application/octet-stream")

    def run(self):
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")

        self.app.run(
            host=self.host,
            port=self.port,
            threaded=True,
        )


if __name__ == "__main__":
    config_path = "./benchmark.yaml"
    server = StarvlaInferenceServer(config_path)
    server.run()
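
# Example client (sketch; not executed by the server). The /policy endpoint
# expects a pickled observation dict and returns a pickled {"action": ...}
# payload. The shapes and the localhost URL below are illustrative assumptions.
#
#   import pickle
#   import numpy as np
#   import requests
#
#   obs = {
#       "rgb": np.zeros((1, 180, 320, 9), dtype=np.uint8),   # stacked left/right/wrist views
#       "state": np.zeros((1, 10), dtype=np.float32),        # xyz + 6D rotation + gripper
#       "prompt": "pick up the object",
#   }
#   resp = requests.post(
#       "http://localhost:5000/policy",
#       data=pickle.dumps(obs, protocol=4),
#   )
#   action = pickle.loads(resp.content)["action"]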