This commit is contained in:
Junhan
2026-05-07 19:31:45 +08:00
parent 7ce2823c56
commit 2514cc943d
6 changed files with 5976 additions and 94 deletions

View File

@@ -0,0 +1,257 @@
import yaml
import pickle
import argparse
import os
import traceback
from urllib.parse import urlparse
import numpy as np
import torch
import cv2
import json
from flask import Flask, request, Response
from PIL import Image
from pathlib import Path
def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
    """Right-pad *x* along *axis* up to *target_dim* entries, filling with *value*.

    Returns *x* unchanged when the axis is already at least *target_dim* wide.
    """
    have = x.shape[axis]
    if have >= target_dim:
        return x
    widths = [(0, 0) for _ in range(x.ndim)]
    widths[axis] = (0, target_dim - have)
    return np.pad(x, widths, constant_values=value)
def normalize_eepose_state(raw_state, stats, start_idx, norm_len=3):
    """Normalize one slice of the state vectors into the [-1.0, 1.0] range.

    Maps ``raw_state[:, start_idx:start_idx+norm_len]`` through
    ``y = 2 * (x - q01) / (q99 - q01) - 1`` and clips the result to
    [-1.0, 1.0]; every other column is left untouched.

    :param raw_state: (batch, state_dim) numpy array of raw states.
    :param stats: dict holding per-dimension "q01" and "q99" lists.
    :param start_idx: first column of the slice to normalize.
    :param norm_len: width of the slice, default 3 (e.g. xyz position).
    :return: a new array; the input array is not modified.
    """
    sl = slice(start_idx, start_idx + norm_len)
    # float16 keeps the precision aligned with the stored statistics.
    lo = np.asarray(stats["q01"][sl], dtype=np.float16)
    hi = np.asarray(stats["q99"][sl], dtype=np.float16)
    # Guard against division by zero for degenerate dimensions.
    span = np.clip(hi - lo, a_min=1e-5, a_max=None)
    # Work on a copy so the caller's array stays intact.
    out = raw_state.copy()
    scaled = 2.0 * (out[:, sl] - lo) / span - 1.0
    # Clip so the model never sees out-of-range inputs.
    out[:, sl] = np.clip(scaled, -1.0, 1.0)
    return out
def unnormalize_eepose_action(normalized_action, stats, start_idx, norm_len=3):
    """Map a normalized slice of the actions back into raw units.

    Inverse mapping of :func:`normalize_eepose_state` (without the final
    clipping step): ``x = (y + 1) / 2 * (q99 - q01) + q01`` is applied to
    columns ``start_idx:start_idx+norm_len``; other columns stay untouched.

    :param normalized_action: (batch, action_dim) numpy array in [-1, 1] space.
    :param stats: dict holding per-dimension "q01" and "q99" lists.
    :param start_idx: first column of the slice to un-normalize.
    :param norm_len: width of the slice, default 3 (e.g. xyz position).
    :return: a new array; the input array is not modified.
    """
    sl = slice(start_idx, start_idx + norm_len)
    # float16 matches the precision used on the normalization side.
    lo = np.asarray(stats["q01"][sl], dtype=np.float16)
    hi = np.asarray(stats["q99"][sl], dtype=np.float16)
    # Same division-by-zero guard as the forward mapping.
    span = np.clip(hi - lo, a_min=1e-5, a_max=None)
    # Work on a copy so the caller's array stays intact.
    out = normalized_action.copy()
    out[:, sl] = (out[:, sl] + 1.0) / 2.0 * span + lo
    return out
class StarvlaInferenceServer:
    """Flask server that wraps a StarVLA checkpoint for policy inference.

    Reads its configuration from a benchmark YAML file, loads the model once
    at start-up, and serves pickled observation -> action requests over HTTP.
    """

    def __init__(self, config_path: str):
        """Load the YAML config, resolve the checkpoint, and build the app.

        :param config_path: path to a benchmark.yaml containing
            "policy_server" and "general.root_paths" sections.
        """
        with open(config_path, "r") as f:
            cfg = yaml.safe_load(f)
        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]
        self.ckpt_source = policy_server_cfg["ckpt_source"]
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        self.rel_eepose_stats = self.read_rel_eepose_stats()
        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)
        self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
        self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")
        print("Loading StarVLA model...")
        self.model = self.load_model()
        print("Model loaded.")
        self.app = Flask(__name__)
        self.register_routes()

    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        """Turn a "scheme://relative/path" checkpoint URL into a local path.

        A plain path (no scheme) is returned unchanged; otherwise the scheme
        is looked up in *root_paths* and joined with the URL remainder.

        :raises KeyError: if the scheme has no entry in *root_paths*.
        """
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            return ckpt_url
        root = root_paths.get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find the checkpoint root path in root_paths: {root_paths}"
            )
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)

    def read_rel_eepose_stats(self):
        """Load action_delta_eepose_stats.json from two directory levels
        above the checkpoint; return an empty dict when the file is absent."""
        stats_path = Path(self.ckpt_path).parents[1] / "action_delta_eepose_stats.json"
        # stats_path is always a Path here, so only existence needs checking
        # (the original `is None` branch was unreachable).
        if not stats_path.exists():
            return {}
        with open(stats_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def load_model(self):
        """Build the StarVLA framework, load the checkpoint weights, and move
        the model to CUDA (optionally cast to bfloat16).

        Also caches ``norm_stats`` and the per-key ``action_norm_stats`` on
        the instance as a side effect.
        """
        from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
        from starVLA.model.framework.__init__ import build_framework

        model_config, norm_stats = read_mode_config(self.ckpt_path)
        cfg = dict_to_namespace(model_config)
        # Weights are loaded explicitly from the state_dict below.
        cfg.trainer.pretrained_checkpoint = None
        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats
        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        if self.use_bf16:
            model = model.to(torch.bfloat16)
        model = model.to("cuda").eval()
        self.norm_stats = norm_stats
        self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)
        return model

    def parse_observation(self, obs, target_size=(320, 180)):
        """Split a raw observation dict into model inputs.

        :param obs: dict with "rgb" (Left_Camera / Right_Camera / Hand_Camera
            image arrays), "state" (flat vector) and "prompt" (instruction).
        :param target_size: (width, height) each camera image is resized to.
        :return: (left image, right image, wrist image, state vector
            zero-padded to 100 dims, prompt string).
        """
        left_rgb = obs["rgb"]["Left_Camera"]
        right_rgb = obs["rgb"]["Right_Camera"]
        wrist_rgb = obs["rgb"]["Hand_Camera"]
        img_left = Image.fromarray(cv2.resize(left_rgb, target_size))
        img_right = Image.fromarray(cv2.resize(right_rgb, target_size))
        img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
        # The model expects a fixed 100-dim state vector.
        state_vec = pad_to_dim(np.array(obs["state"]), 100, axis=-1)
        return img_left, img_right, img_wrist, state_vec, obs["prompt"]

    def inference(self, observation: dict) -> dict:
        """Run one forward pass and return the predicted action chunks.

        :param observation: raw observation dict (see parse_observation).
        :return: dict with "ee_delta_position_chunks" (un-normalized xyz),
            "ee_delta_rot6d_chunks" and "gripper_width_chunks" lists.
        """
        img_left, img_right, img_wrist, state_vec, prompt = \
            self.parse_observation(observation)
        print(f"{state_vec.shape}")
        vla_input = {
            "batch_images": [[img_left, img_right, img_wrist]],
            "instructions": [prompt],
            "state": [state_vec],
        }
        with torch.no_grad():
            output = self.model.predict_action(**vla_input)
        actions = output.get("normalized_actions")
        if isinstance(actions, torch.Tensor):
            actions = actions.cpu().numpy()
        if actions.ndim == 3:
            actions = actions[0]  # drop the batch dim -> (chunk, action_dim)
        # Only the xyz translation (columns 0-2) is un-normalized here; the
        # rot6d (3-8) and gripper-width (9) columns pass through as predicted.
        actions[:, :3] = unnormalize_eepose_action(
            actions[:, :3], self.rel_eepose_stats, start_idx=0, norm_len=3
        )
        return {
            "ee_delta_position_chunks": actions[:, :3].tolist(),
            "ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
            "gripper_width_chunks": actions[:, 9:10].tolist(),
        }

    def register_routes(self):
        """Attach the /policy/inference and /policy/health HTTP routes."""

        @self.app.route("/policy/inference", methods=["POST"])
        def policy():
            try:
                # SECURITY: pickle.loads on request data can execute arbitrary
                # code from the sender — only expose this to trusted clients.
                data = pickle.loads(request.data)
                result = self.inference(data)
                body = pickle.dumps(result, protocol=4)
                return Response(body, mimetype="application/octet-stream")
            except Exception as e:
                # Ship the error + traceback back to the client for debugging.
                err_obj = {
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                }
                body = pickle.dumps(err_obj, protocol=4)
                return Response(
                    body,
                    mimetype="application/octet-stream",
                    status=500,
                )

        @self.app.route("/policy/health", methods=["GET"])
        def health():
            if self.model is None:
                return Response("Failed to load model", mimetype="application/json", status=500)
            return Response("OK", mimetype="application/json")

    def run(self):
        """Start the threaded Flask development server; blocks forever."""
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")
        self.app.run(
            host=self.host,
            port=self.port,
            threaded=True
        )
if __name__ == "__main__":
    # CLI entry point: parse the config path, then start the server.
    cli = argparse.ArgumentParser(description="StarVLA inference server")
    cli.add_argument(
        "--config",
        type=str,
        default="./benchmark.yaml",
        help="Path to benchmark.yaml",
    )
    parsed_args = cli.parse_args()
    StarvlaInferenceServer(parsed_args.config).run()