update
This commit is contained in:
257
starvla_inference_server_unnorm.py
Normal file
257
starvla_inference_server_unnorm.py
Normal file
@@ -0,0 +1,257 @@
|
||||
|
||||
import yaml
|
||||
import pickle
|
||||
import argparse
|
||||
import os
|
||||
import traceback
|
||||
from urllib.parse import urlparse
|
||||
import numpy as np
|
||||
import torch
|
||||
import cv2
|
||||
import json
|
||||
from flask import Flask, request, Response
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
    """Return *x* padded with *value* along *axis* up to *target_dim*.

    If the array already has at least ``target_dim`` entries along *axis*,
    the original array is returned unchanged (no copy is made).
    """
    size = x.shape[axis]
    if size >= target_dim:
        return x
    widths = [(0, 0) for _ in range(x.ndim)]
    widths[axis] = (0, target_dim - size)
    return np.pad(x, widths, constant_values=value)
|
||||
|
||||
def normalize_eepose_state(raw_state, stats, start_idx, norm_len=3):
    """Normalize a slice of the raw state into the [-1.0, 1.0] range.

    Uses the q01/q99 quantile statistics: ``y = 2 * (x - q01) / (q99 - q01) - 1``,
    clipped to [-1, 1] so out-of-range states cannot leak past the model's
    expected input range.

    :param raw_state: numpy array of raw states, typically shaped (batch_size, state_dim)
    :param stats: dict of normalization statistics containing 'q01' and 'q99' lists
    :param start_idx: start index of the features to normalize within the state vector
    :param norm_len: number of features to normalize, default 3 (e.g. xyz position)
    :return: normalized copy of the array (the input array is not modified)
    """
    # Slice the matching q01/q99 ranges; float16 keeps precision aligned with
    # how the stats were applied during training.
    q01 = np.array(stats["q01"][start_idx : start_idx + norm_len], dtype=np.float16)
    q99 = np.array(stats["q99"][start_idx : start_idx + norm_len], dtype=np.float16)

    # Guard against division by zero when q99 == q01 for some dimension.
    denom = np.clip(q99 - q01, a_min=1e-5, a_max=None)

    target_slice = slice(start_idx, start_idx + norm_len)

    # Work on a copy so the caller's array is left untouched.
    norm_state = raw_state.copy()
    x = norm_state[:, target_slice]

    # Forward mapping followed by clipping to the model's input range.
    y = 2.0 * (x - q01) / denom - 1.0
    norm_state[:, target_slice] = np.clip(y, -1.0, 1.0)

    return norm_state
|
||||
|
||||
def unnormalize_eepose_action(normalized_action, stats, start_idx, norm_len=3):
    """Un-normalize a slice of a normalized action back to its raw range.

    Inverse of :func:`normalize_eepose_state` (without the clipping step):
    ``x = (y + 1) / 2 * (q99 - q01) + q01``.

    :param normalized_action: numpy array of normalized actions, typically
        shaped (batch_size, action_dim)
    :param stats: dict of statistics containing 'q01' and 'q99' lists
        (usually loaded from a JSON file)
    :param start_idx: start index of the features to un-normalize within the action vector
    :param norm_len: number of features to un-normalize, default 3 (e.g. xyz position)
    :return: un-normalized copy of the array (the input array is not modified)
    """
    # float16 keeps precision aligned with the forward normalization.
    q01 = np.array(stats["q01"][start_idx : start_idx + norm_len], dtype=np.float16)
    q99 = np.array(stats["q99"][start_idx : start_idx + norm_len], dtype=np.float16)

    # Same division-by-zero guard as the forward direction.
    denom = np.clip(q99 - q01, a_min=1e-5, a_max=None)

    target_slice = slice(start_idx, start_idx + norm_len)

    # Copy to avoid mutating the caller's normalized_action matrix.
    unnorm_action = normalized_action.copy()
    y = unnorm_action[:, target_slice]
    unnorm_action[:, target_slice] = (y + 1.0) / 2.0 * denom + q01

    return unnorm_action
|
||||
|
||||
class StarvlaInferenceServer:
    """Flask-based inference server wrapping a StarVLA policy checkpoint.

    Loads the model described by a YAML config, then serves
    ``POST /policy/inference`` (pickled observation in, pickled action
    chunks out) and ``GET /policy/health``.
    """

    def __init__(self, config_path: str):
        """Parse the config, resolve the checkpoint path, load the model,
        and build the Flask app (routes registered, server not yet started).

        :param config_path: path to the benchmark YAML; only the
            ``policy_server`` and ``general.root_paths`` sections are read.
        """
        with open(config_path, "r") as f:
            cfg = yaml.safe_load(f)
        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]
        self.ckpt_source = policy_server_cfg["ckpt_source"]
        # Resolve scheme-prefixed checkpoint URLs (e.g. "nas://...") to local paths.
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        # Stats used to un-normalize the delta end-effector position outputs;
        # may be an empty dict when the stats file is absent.
        self.rel_eepose_stats = self.read_rel_eepose_stats()
        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)
        self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
        self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")

        print("Loading StarVLA model...")
        self.model = self.load_model()
        print("Model loaded.")

        self.app = Flask(__name__)
        self.register_routes()

    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        """Map a scheme-prefixed checkpoint URL onto a local filesystem path.

        A URL with no scheme is returned unchanged; otherwise the scheme is
        looked up in ``root_paths`` and joined with the URL remainder.

        :raises KeyError: if the scheme has no entry in ``root_paths``.
        """
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            return ckpt_url
        root = root_paths.get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find the checkpoint root path in root_paths: {root_paths}"
            )
        # netloc + path reassembles everything after "scheme://".
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)

    def read_rel_eepose_stats(self):
        """Load delta-eepose stats stored two directory levels above the checkpoint.

        :return: parsed stats dict, or {} when the JSON file does not exist.
        """
        stats_path = Path(self.ckpt_path).parents[1] / "action_delta_eepose_stats.json"
        # NOTE(review): `stats_path is None` can never be true here (the `/`
        # operator always yields a Path); only the existence check is effective.
        if stats_path is None or not Path(stats_path).exists():
            return {}
        with open(stats_path, 'r', encoding='utf-8') as f:
            stats = json.load(f)
        return stats

    def load_model(self):
        """Build the StarVLA framework, load checkpoint weights, and move to GPU.

        Also caches ``norm_stats`` / ``action_norm_stats`` on the instance.

        :return: the model in eval mode on CUDA (bfloat16 if ``use_bf16``).
        """
        # Local imports keep the heavy starVLA dependency off module import time.
        from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
        from starVLA.model.framework.__init__ import build_framework

        model_config, norm_stats = read_mode_config(self.ckpt_path)

        cfg = dict_to_namespace(model_config)
        # Prevent build_framework from loading a pretrained checkpoint on its
        # own; the weights are loaded explicitly below.
        cfg.trainer.pretrained_checkpoint = None

        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats

        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

        if self.use_bf16:
            model = model.to(torch.bfloat16)

        model = model.to("cuda").eval()

        self.norm_stats = norm_stats
        self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)

        return model

    def parse_observation(self, obs, target_size=(320, 180)):
        """Convert a raw observation dict into model inputs.

        :param obs: dict with ``rgb`` (``Left_Camera``/``Right_Camera``/
            ``Hand_Camera`` image arrays), a ``state`` vector and a
            ``prompt`` string.
        :param target_size: (width, height) each camera image is resized to.
        :return: (left image, right image, wrist image, padded state vector, prompt)
        """
        left_rgb, right_rgb, wrist_rgb = obs["rgb"]["Left_Camera"], obs["rgb"]["Right_Camera"], obs["rgb"]["Hand_Camera"]

        img_left = Image.fromarray(cv2.resize(left_rgb, target_size))
        img_right = Image.fromarray(cv2.resize(right_rgb, target_size))
        img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))

        state_vec = obs["state"]
        # Zero-pad the state to a fixed 100-dim vector — presumably the
        # model's expected state input width; TODO confirm against training.
        state_vec = pad_to_dim(np.array(state_vec), 100, axis=-1)
        return img_left, img_right, img_wrist, state_vec, obs["prompt"]

    def inference(self, observation: dict) -> dict:
        """Run one forward pass and return un-normalized action chunks.

        :param observation: raw observation dict (see ``parse_observation``).
        :return: dict of lists — delta position (dims 0-2, un-normalized),
            rot6d (dims 3-8) and gripper width (dim 9) per predicted step.
        """
        img_left, img_right, img_wrist, state_vec, prompt = \
            self.parse_observation(observation)
        print(f"{state_vec.shape}")
        vla_input = {
            "batch_images": [[img_left, img_right, img_wrist]],
            "instructions": [prompt],
            "state": [state_vec]
        }
        with torch.no_grad():
            output = self.model.predict_action(**vla_input)

        actions = output.get("normalized_actions")

        if isinstance(actions, torch.Tensor):
            actions = actions.cpu().numpy()

        # Drop the batch dimension, keeping the (chunk, action_dim) matrix.
        if actions.ndim == 3:
            actions = actions[0]  # (16, 10)
        # Un-normalize only the xyz position slice (dims 0-2); rot6d and
        # gripper width are returned exactly as produced by the model.
        # NOTE(review): raises KeyError if rel_eepose_stats is {} (stats file
        # missing) — confirm the stats JSON always ships with the checkpoint.
        actions[:, :3] = unnormalize_eepose_action(actions[:, :3], self.rel_eepose_stats, start_idx=0, norm_len=3)
        return {"ee_delta_position_chunks": actions[:, :3].tolist(),
                "ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
                "gripper_width_chunks": actions[:, 9:10].tolist()}

    def register_routes(self):
        """Attach the inference and health endpoints to the Flask app."""

        @self.app.route("/policy/inference", methods=["POST"])
        def policy():
            try:
                # SECURITY: pickle.loads on the raw request body can execute
                # arbitrary code if a client is untrusted — only expose this
                # server on a trusted network.
                data = pickle.loads(request.data)
                result = self.inference(data)
                body = pickle.dumps(result, protocol=4)
                return Response(body, mimetype="application/octet-stream")
            except Exception as e:
                # Ship the error and traceback back to the client instead of
                # an opaque 500 page, so failures are debuggable remotely.
                err_obj = {
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                }
                body = pickle.dumps(err_obj, protocol=4)
                return Response(
                    body,
                    mimetype="application/octet-stream",
                    status=500,
                )

        @self.app.route("/policy/health", methods=["GET"])
        def health():
            # Model is loaded in __init__, so in practice this reports OK
            # whenever the server is up.
            if self.model is None:
                return Response("Failed to load model", mimetype="application/json", status=500)
            return Response("OK", mimetype="application/json")

    def run(self):
        """Start the blocking Flask server loop on the configured host/port."""
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")

        self.app.run(
            host=self.host,
            port=self.port,
            threaded=True
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: read the config path and start the blocking server.
    arg_parser = argparse.ArgumentParser(description="StarVLA inference server")
    arg_parser.add_argument(
        "--config",
        type=str,
        default="./benchmark.yaml",
        help="Path to benchmark.yaml",
    )
    cli_args = arg_parser.parse_args()

    StarvlaInferenceServer(cli_args.config).run()
|
||||
Reference in New Issue
Block a user