Finish loading the inference-server model
This commit is contained in:
173
starvla_inference_server.py
Normal file
173
starvla_inference_server.py
Normal file
@@ -0,0 +1,173 @@
|
||||
|
||||
import yaml
|
||||
import pickle
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
import numpy as np
|
||||
import torch
|
||||
import cv2
|
||||
|
||||
from flask import Flask, request, Response
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class StarvlaInferenceServer:
    """Flask-based inference server wrapping a StarVLA policy model.

    Reads a YAML config, resolves and loads the model checkpoint, and
    exposes a single ``POST /policy`` endpoint that accepts a pickled
    observation dict and returns a pickled action dict.
    """

    def __init__(self, config_path: str):
        """Load the config, build the model, and register Flask routes.

        Args:
            config_path: Path to a YAML file containing ``policy_server``
                and ``general.root_paths`` sections.
        """
        with open(config_path, "r") as f:
            cfg = yaml.safe_load(f)

        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]
        self.ckpt_source = policy_server_cfg["ckpt_source"]
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)
        self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
        self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")

        print("Loading StarVLA model...")
        self.model = self.load_model()
        print("Model loaded.")

        self.app = Flask(__name__)
        self.register_routes()

    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        """Map a scheme-prefixed checkpoint URL to a local filesystem path.

        A plain path (no URL scheme) is returned unchanged. Otherwise the
        scheme (e.g. ``s3`` in ``s3://bucket/ckpt.pt``) is looked up in
        ``root_paths`` and the remainder of the URL is joined onto that root.

        Raises:
            KeyError: If the URL scheme has no entry in ``root_paths``.
        """
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            return ckpt_url
        root = root_paths.get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find the checkpoint root path in root_paths: {root_paths}"
            )
        # netloc + path reassembles everything after "scheme://"; lstrip so
        # os.path.join treats it as relative to the configured root.
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)

    def load_model(self):
        """Build the StarVLA framework, load checkpoint weights, and move to GPU.

        Returns:
            The model in eval mode on CUDA, optionally cast to bfloat16,
            with ``norm_stats`` attached.
        """
        # Deferred imports keep the heavy starVLA dependency out of module
        # import time. Import from the package itself, not its __init__
        # module: "package.__init__" registers a duplicate module object.
        from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
        from starVLA.model.framework import build_framework

        model_config, norm_stats = read_mode_config(self.ckpt_path)

        cfg = dict_to_namespace(model_config)
        # Prevent build_framework from loading a pretrained checkpoint on
        # its own; weights are loaded explicitly below.
        cfg.trainer.pretrained_checkpoint = None

        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats

        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

        if self.use_bf16:
            model = model.to(torch.bfloat16)

        model = model.to("cuda").eval()

        self.norm_stats = norm_stats
        # Per-dataset action normalization stats; None if the unnorm_key is
        # absent from the checkpoint's stats.
        self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)

        return model

    def parse_observation(self, obs):
        """Split a raw observation dict into model inputs.

        Assumes ``obs["rgb"][-1]`` is an HxWx9 array of three stacked RGB
        views (left, right, wrist) and ``obs["state"][-1]`` is a 10-dim
        vector laid out as xyz(3) + rot6d(6) + gripper(1) — TODO confirm
        against the client that produces these observations.

        Returns:
            Tuple ``(img_left, img_right, img_wrist, state_vec, prompt)``.
        """
        rgb = obs["rgb"][-1]
        state = obs["state"][-1]
        joint = obs.get("joint", None)
        prompt = obs["prompt"]

        # Unstack the three camera views from the channel dimension.
        left = rgb[:, :, :3]
        right = rgb[:, :, 3:6]
        wrist = rgb[:, :, 6:9]

        # cv2.resize takes (width, height).
        target_size = (320, 180)

        left = cv2.resize(left, target_size)
        right = cv2.resize(right, target_size)
        wrist = cv2.resize(wrist, target_size)

        img_left = Image.fromarray(left)
        img_right = Image.fromarray(right)
        img_wrist = Image.fromarray(wrist)

        if self.state_mode == "joint8":
            # Joint-space state: last joint reading + gripper scalar.
            joint_last = joint[-1]
            gripper = state[9]

            state_vec = np.concatenate(
                [joint_last, np.array([gripper])],
                axis=0
            )

        else:
            # End-effector-pose state (default "ee_pose7").
            xyz = state[0:3]
            rot6d = state[3:9]
            gripper = state[9]

            # NOTE(review): only the first 3 of the 6 rot6d components are
            # used, yielding a 7-dim state — verify this matches the
            # representation the model was trained on.
            state_vec = np.concatenate(
                [xyz, rot6d[:3], np.array([gripper])],
                axis=0
            )

        return img_left, img_right, img_wrist, state_vec, prompt

    def inference(self, observation: dict) -> dict:
        """Run one forward pass of the policy on a single observation.

        Args:
            observation: Raw observation dict (see ``parse_observation``).

        Returns:
            ``{"action": np.ndarray}`` with dtype float32; a leading batch
            dimension of 1 is squeezed off.

        Raises:
            KeyError: If the model output lacks ``"normalized_actions"``.
        """
        img_left, img_right, img_wrist, state_vec, prompt = \
            self.parse_observation(observation)

        vla_input = {
            "batch_images": [[img_left, img_right, img_wrist]],
            "instructions": [prompt],
            "state": [state_vec]
        }

        with torch.no_grad():
            output = self.model.predict_action(**vla_input)

        actions = output.get("normalized_actions")
        # Fail loudly with a clear message instead of an opaque
        # AttributeError further down.
        if actions is None:
            raise KeyError(
                f"model output is missing 'normalized_actions'; got keys: {list(output)}"
            )

        if isinstance(actions, torch.Tensor):
            actions = actions.cpu().numpy()

        # Drop the batch dimension (batch size is always 1 here).
        if actions.ndim == 3:
            actions = actions[0]

        return {"action": actions.astype(np.float32)}

    def register_routes(self):
        """Attach the ``POST /policy`` endpoint to the Flask app."""

        @self.app.route("/policy", methods=["POST"])
        def policy():
            # SECURITY: pickle.loads executes arbitrary code from the
            # request body — only expose this server on a trusted network.
            data = pickle.loads(request.data)
            result = self.inference(data)
            body = pickle.dumps(result, protocol=4)
            return Response(body, mimetype="application/octet-stream")

    def run(self):
        """Start the (blocking) Flask server."""
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")

        self.app.run(
            host=self.host,
            port=self.port,
            threaded=True
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    # Allow overriding the config path as the first CLI argument; the
    # default preserves the original hard-coded behavior.
    config_path = sys.argv[1] if len(sys.argv) > 1 else "./benchmark.yaml"
    server = StarvlaInferenceServer(config_path)
    server.run()
|
||||
Reference in New Issue
Block a user