import argparse
import json
import os
import pickle
import traceback
from pathlib import Path
from urllib.parse import urlparse

import cv2
import numpy as np
import torch
import yaml
from flask import Flask, Response, request
from PIL import Image


def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
    """Pad an array to the target dimension with zeros along the specified axis."""
    current_dim = x.shape[axis]
    if current_dim < target_dim:
        pad_width = [(0, 0)] * len(x.shape)
        pad_width[axis] = (0, target_dim - current_dim)
        return np.pad(x, pad_width, constant_values=value)
    return x


def normalize_eepose_state(raw_state, stats, start_idx, norm_len=3):
    """
    Normalize a specific slice of the raw state using the given statistics, mapping it to [-1.0, 1.0].

    :param raw_state: numpy array of raw states, typically of shape (batch_size, action_dim)
    :param stats: statistics dict required for normalization (contains 'q01' and 'q99')
    :param start_idx: starting index of the features to normalize within the action vector
    :param norm_len: number of features to normalize, default 3 (e.g. xyz coordinates)
    :return: normalized array (a new array; the input is not modified)
    """
    # 1. Extract the relevant q01 / q99 slices, kept in float16 to match the training-time precision.
    q01 = np.array(stats["q01"][start_idx : start_idx + norm_len], dtype=np.float16)
    q99 = np.array(stats["q99"][start_idx : start_idx + norm_len], dtype=np.float16)

    # 2. Compute the denominator with a lower bound to guard against division by zero.
    denom = np.clip(q99 - q01, a_min=1e-5, a_max=None)

    # 3. Locate the target slice.
    target_slice = slice(start_idx, start_idx + norm_len)

    # 4. Work on a copy so the original array stays untouched.
    norm_state = raw_state.copy()
    x = norm_state[:, target_slice]

    # 5. Forward mapping: y = 2 * (x - q01) / denom - 1
    y = 2.0 * (x - q01) / denom - 1.0

    # 6. Clip to [-1.0, 1.0] so the state fed to the model never goes out of range.
    norm_state[:, target_slice] = np.clip(y, -1.0, 1.0)
    return norm_state


def unnormalize_eepose_action(normalized_action, stats, start_idx, norm_len=3):
    """
    Un-normalize a specific slice of the action using the given statistics
    (typically loaded from an action-stats JSON file).

    :param normalized_action: numpy array of normalized actions, typically of shape (batch_size, action_dim)
    :param stats: statistics dict required for un-normalization (contains 'q01' and 'q99')
    :param start_idx: starting index of the features to un-normalize within the action vector
    :param norm_len: number of features to un-normalize, default 3 (e.g. xyz coordinates)
    :return: un-normalized array (a new array; the input is not modified)
    """
    # 1. Extract the relevant q01 / q99 slices; keep float16 to match the normalization precision.
    q01 = np.array(stats["q01"][start_idx : start_idx + norm_len], dtype=np.float16)
    q99 = np.array(stats["q99"][start_idx : start_idx + norm_len], dtype=np.float16)

    # 2. Compute the denominator with the same clipping as in normalization to avoid division by zero.
    denom = np.clip(q99 - q01, a_min=1e-5, a_max=None)

    # 3. Locate the target slice.
    target_slice = slice(start_idx, start_idx + norm_len)

    # 4. Work on a copy so the original normalized_action stays untouched.
    unnorm_action = normalized_action.copy()
    y = unnorm_action[:, target_slice]

    # 5. Inverse mapping: x = (y + 1) / 2 * denom + q01
    unnorm_action[:, target_slice] = (y + 1.0) / 2.0 * denom + q01
    return unnorm_action
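

# A minimal sketch of the normalize/unnormalize round trip implemented above. It is illustrative only
# and never called by the server: the stats values and the (1, 3) state are made-up placeholders, not
# real dataset statistics. Values inside the [q01, q99] range survive the round trip up to float16
# rounding; values outside it are clipped by the forward mapping.
def _example_round_trip():
    stats = {"q01": [-0.05, -0.05, -0.05], "q99": [0.05, 0.05, 0.05]}  # hypothetical quantiles
    raw = np.array([[0.01, -0.02, 0.03]], dtype=np.float32)            # hypothetical xyz delta
    normed = normalize_eepose_state(raw, stats, start_idx=0, norm_len=3)
    recovered = unnormalize_eepose_action(normed, stats, start_idx=0, norm_len=3)
    print("normalized:", normed, "recovered:", recovered)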
class StarvlaInferenceServer:
    def __init__(self, config_path: str):
        with open(config_path, "r") as f:
            cfg = yaml.safe_load(f)

        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]

        self.ckpt_source = policy_server_cfg["ckpt_source"]
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        self.rel_eepose_stats = self.read_rel_eepose_stats()

        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)
        self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
        self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")

        print("Loading StarVLA model...")
        self.model = self.load_model()
        print("Model loaded.")

        self.app = Flask(__name__)
        self.register_routes()

    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        """Resolve a scheme-prefixed checkpoint URL against root_paths; plain paths pass through."""
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            return ckpt_url
        root = root_paths.get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find the checkpoint root path in root_paths: {root_paths}"
            )
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)

    def read_rel_eepose_stats(self):
        """Load action_delta_eepose_stats.json from two levels above the checkpoint, if it exists."""
        stats_path = Path(self.ckpt_path).parents[1] / "action_delta_eepose_stats.json"
        if not stats_path.exists():
            return {}
        with open(stats_path, "r", encoding="utf-8") as f:
            stats = json.load(f)
        return stats

    def load_model(self):
        """Build the StarVLA framework from the checkpoint config and load its weights onto the GPU."""
        from starVLA.model.framework import build_framework
        from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace

        model_config, norm_stats = read_mode_config(self.ckpt_path)
        cfg = dict_to_namespace(model_config)
        cfg.trainer.pretrained_checkpoint = None

        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats

        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

        if self.use_bf16:
            model = model.to(torch.bfloat16)
        model = model.to("cuda").eval()

        self.norm_stats = norm_stats
        self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)
        return model

    def parse_observation(self, obs, target_size=(320, 180)):
        """Resize the three camera images, pad the state vector to 100 dims, and return the prompt."""
        left_rgb = obs["rgb"]["Left_Camera"]
        right_rgb = obs["rgb"]["Right_Camera"]
        wrist_rgb = obs["rgb"]["Hand_Camera"]

        img_left = Image.fromarray(cv2.resize(left_rgb, target_size))
        img_right = Image.fromarray(cv2.resize(right_rgb, target_size))
        img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))

        state_vec = pad_to_dim(np.array(obs["state"]), 100, axis=-1)
        return img_left, img_right, img_wrist, state_vec, obs["prompt"]

    def inference(self, observation: dict) -> dict:
        """Run one forward pass and return the predicted action chunks."""
        img_left, img_right, img_wrist, state_vec, prompt = \
            self.parse_observation(observation)
        print(f"state shape: {state_vec.shape}")

        vla_input = {
            "batch_images": [[img_left, img_right, img_wrist]],
            "instructions": [prompt],
            "state": [state_vec],
        }
        with torch.no_grad():
            output = self.model.predict_action(**vla_input)

        actions = output.get("normalized_actions")
        if isinstance(actions, torch.Tensor):
            actions = actions.cpu().numpy()
        if actions.ndim == 3:
            actions = actions[0]  # (chunk_len, action_dim), e.g. (16, 10)

        # Un-normalize only the translation slice (dims 0-2, the end-effector delta position);
        # the rot6d dims (3-8) and gripper width (dim 9) are returned as predicted.
        actions[:, :3] = unnormalize_eepose_action(
            actions[:, :3], self.rel_eepose_stats, start_idx=0, norm_len=3
        )
        return {
            "ee_delta_position_chunks": actions[:, :3].tolist(),
            "ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
            "gripper_width_chunks": actions[:, 9:10].tolist(),
        }

    def register_routes(self):
        """Register the inference and health-check endpoints on the Flask app."""

        @self.app.route("/policy/inference", methods=["POST"])
        def policy():
            try:
                data = pickle.loads(request.data)
                result = self.inference(data)
                body = pickle.dumps(result, protocol=4)
                return Response(body, mimetype="application/octet-stream")
            except Exception as e:
                err_obj = {
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                }
                body = pickle.dumps(err_obj, protocol=4)
                return Response(
                    body,
                    mimetype="application/octet-stream",
                    status=500,
                )

        @self.app.route("/policy/health", methods=["GET"])
        def health():
            if self.model is None:
                return Response("Failed to load model", mimetype="application/json", status=500)
            return Response("OK", mimetype="application/json")

    def run(self):
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")
        self.app.run(
            host=self.host,
            port=self.port,
            threaded=True,
        )
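

# A minimal client sketch for the /policy/inference route above. It is an illustration, not part of the
# server: the dummy zero images, the "pick up the cube" prompt, the 10-dim state, and the
# localhost:5000 address are all assumptions (the real host/port come from benchmark.yaml), and it
# needs the `requests` package, which the server itself does not import.
def _example_client_call():
    import requests  # assumed to be installed on the client side

    obs = {
        "rgb": {
            "Left_Camera": np.zeros((720, 1280, 3), dtype=np.uint8),
            "Right_Camera": np.zeros((720, 1280, 3), dtype=np.uint8),
            "Hand_Camera": np.zeros((720, 1280, 3), dtype=np.uint8),
        },
        "state": np.zeros(10, dtype=np.float32),
        "prompt": "pick up the cube",
    }
    resp = requests.post(
        "http://localhost:5000/policy/inference",
        data=pickle.dumps(obs, protocol=4),
    )
    result = pickle.loads(resp.content)
    print(result.keys())  # ee_delta_position_chunks, ee_delta_rot6d_chunks, gripper_width_chunks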
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="StarVLA inference server")
    parser.add_argument(
        "--config",
        type=str,
        default="./benchmark.yaml",
        help="Path to benchmark.yaml",
    )
    args = parser.parse_args()
    config_path = args.config

    server = StarvlaInferenceServer(config_path)
    server.run()
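

# For reference, an illustrative benchmark.yaml sketch covering only the keys this server reads.
# The "local" scheme, paths, and values are placeholders, not a real configuration:
#
#   general:
#     root_paths:
#       local: /path/to/checkpoints
#   policy_server:
#     ckpt_source: local
#     ckpt_path: local://starvla_run/checkpoints/step_10000.pt
#     host: 0.0.0.0
#     port: 5000
#     use_bf16: true
#     unnorm_key: oxe_bridge
#     state_mode: ee_pose7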