commit cc9815f3b816f6130583a659bd9cb10628e1b885
Author: hufei.hofee
Date:   Wed Mar 18 18:12:34 2026 +0800

    finish load inference server model

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..600d2d3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+.vscode
\ No newline at end of file
diff --git a/benchmark.yaml b/benchmark.yaml
new file mode 100644
index 0000000..8affca8
--- /dev/null
+++ b/benchmark.yaml
@@ -0,0 +1,179 @@
+general:
+  scan_project: true
+  root_paths:
+    asset: /home/ubuntu/projects/gen_data/data
+    output: /home/ubuntu/output
+    checkpoints: /home/ubuntu/data/models
+
+simulation:
+  launch_config:
+    device: cuda
+    enable_cameras: true
+    headless: false
+    livestream: 0
+
+scene:
+  name: default_scene_name
+  position: [0, 0, 0]
+  rotation: [1, 0, 0, 0]
+  base_config:
+    name: default_base
+    source: primitive
+    stereotype: ground_plane
+    ground_size: [100, 100]
+
+  object_cfg_dict:
+    table:
+      name: simple_table
+      position: [0.5, 0, 0.25]
+      source: primitive
+      stereotype: rigid
+      primitive_type: cuboid
+      primitive_size: [0.5, 1, 0.5]
+      mass: 1e4
+
+    target:
+      name: target
+      position: [0.4, 0.0, 0.5]
+      scale: [0.001, 0.001, 0.001]
+      axis_y_up: true
+      asset_path: asset://objects/omni6DPose/ball/omni6DPose_ball_020/Aligned.usd
+      stereotype: rigid
+      source: local
+
+  robot_cfg_dict:
+    robot:
+      name: my_robot
+      asset_path: asset://Franka/franka_robotiq_2f85_zedmini.usd
+      position: [0, 0, 0]
+      stereotype: single_gripper_arm_robot
+      source: local
+      init_joint_position:
+        panda_joint2: -0.1633
+        panda_joint4: -1.070
+        panda_joint6: 0.8933
+        panda_joint7: 0.785
+
+      arm_actuator_name: franka_arm
+      gripper_actuator_name: robotiq_2f_85
+
+      use_planner: true
+      planner_cfg:
+        stereotype: curobo
+        lazy_init: true
+        robot_config_file: asset://curobo/franka_robotiq_2f85/franka_robotiq_2f85.yml
+        world_config_source: stage
+        world_stage_ignore_substrings: [my_robot]
+        world_stage_only_paths: [/World]
+        world_stage_reference_prim_path: /World/Robot/SingleGripperArmRobot/my_robot
+
+  sensor_cfg_dict:
+    front_camera:
+      name: front_camera
+      stereotype: camera
+      position: [0.8, 0.0, 0.8]
+      data_types: [rgb, depth, normals]
+      width: 1280
+      height: 720
+      camera_model: pinhole
+      fix_camera: true
+
+    left_camera:
+      name: left_camera
+      stereotype: camera
+      position: [0.6, 0.7, 0.8]
+      data_types: [rgb, depth, normals]
+      width: 1280
+      height: 720
+      camera_model: pinhole
+      fix_camera: true
+
+    right_camera:
+      name: right_camera
+      stereotype: camera
+      position: [0.6, -0.7, 0.8]
+      data_types: [rgb, depth, normals]
+      width: 1280
+      height: 720
+      camera_model: pinhole
+      fix_camera: true
+
+extension:
+  extension_cfg_dict:
+    my_data_collect:
+      enable: true
+      stereotype: data_collect
+      observer_cfgs:
+        - stereotype: robot_observer
+          name: my_robot
+          observe_joint_positions: true
+          observe_joint_velocities: true
+          observe_joint_accelerations: true
+          observe_joint_position_targets: true
+          observe_joint_velocity_targets: true
+          observe_position: true
+          observe_rotation: true
+          observe_ee_pose: true
+          observe_gripper_state: true
+          observe_gripper_drive_state: true
+        - stereotype: sensor_observer
+          name: front_camera
+          observe_intrinsic_matrix: true
+          observe_extrinsic_matrix: true
+          observe_rgb: true
+          observe_depth: true
+          observe_normals: true
+        - stereotype: sensor_observer
+          name: left_camera
+          observe_intrinsic_matrix: true
+          observe_extrinsic_matrix: true
+          observe_rgb: true
+          observe_depth: true
+          observe_normals: true
+        - stereotype: sensor_observer
+          name: right_camera
+          observe_intrinsic_matrix: true
+          observe_extrinsic_matrix: true
+          observe_rgb: true
+          observe_depth: true
+          observe_normals: true
+
+        - stereotype: task_observer
+          name: task
+
+        - stereotype: object_observer
+          name: target
+          observe_position: true
+          observe_rotation: true
+          observe_scale: true
+
+    my_benchmark:
+      enable: true
+      stereotype: benchmark
+      data_collector_name: my_data_collect
+      goals:
+        - name: reach_target
+          description: Reach the target
+          stereotype: pose
+          pose_A_source: ee
+          pose_A_params:
+            robot_name: my_robot
+          pose_B_source: spawnable
+          pose_B_params:
+            spawnable_name: target
+          position_tolerance: 0.005
+      policy:
+        stereotype: starvla
+        robot_name: my_robot
+        object_name: target
+        prompt: pick the cola bottle and place it on the book
+
+policy_server:
+  ckpt_path: checkpoints://0309_qwenpi_droid_cola_post/final_model/pytorch_model.pt
+  ckpt_source: local
+  host: 0.0.0.0
+  port: 5000
+  use_bf16: true
+  unnorm_key: oxe_bridge
+  state_mode: ee_pose7
\ No newline at end of file
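Note on the path convention above: the keys under general.root_paths double as URL schemes, so asset:// and checkpoints:// paths elsewhere in the config resolve against those roots. A minimal sketch of that resolution (the resolve name is illustrative; the server file below implements the same logic for the checkpoint in _resolve_ckpt_path):

    import os
    from urllib.parse import urlparse

    root_paths = {
        "asset": "/home/ubuntu/projects/gen_data/data",
        "checkpoints": "/home/ubuntu/data/models",
    }

    def resolve(url: str) -> str:
        parsed = urlparse(url)
        if not parsed.scheme:  # plain filesystem path, returned unchanged
            return url
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root_paths[parsed.scheme], rel)

    # checkpoints://0309_qwenpi_droid_cola_post/final_model/pytorch_model.pt
    # -> /home/ubuntu/data/models/0309_qwenpi_droid_cola_post/final_model/pytorch_model.pt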
diff --git a/starvla_inference_server.py b/starvla_inference_server.py
new file mode 100644
index 0000000..d489bfa
--- /dev/null
+++ b/starvla_inference_server.py
@@ -0,0 +1,173 @@
+import yaml
+import pickle
+import os
+from urllib.parse import urlparse
+import numpy as np
+import torch
+import cv2
+
+from flask import Flask, request, Response
+from PIL import Image
+
+
+class StarvlaInferenceServer:
+
+    def __init__(self, config_path: str):
+
+        with open(config_path, "r") as f:
+            cfg = yaml.safe_load(f)
+
+        policy_server_cfg = cfg["policy_server"]
+        root_paths = cfg["general"]["root_paths"]
+        self.ckpt_source = policy_server_cfg["ckpt_source"]
+        self.ckpt_path = self._resolve_ckpt_path(
+            ckpt_url=policy_server_cfg["ckpt_path"],
+            root_paths=root_paths,
+        )
+        self.host = policy_server_cfg.get("host", "0.0.0.0")
+        self.port = policy_server_cfg.get("port", 5000)
+        self.use_bf16 = policy_server_cfg.get("use_bf16", True)
+        self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
+        self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")
+
+        print("Loading StarVLA model...")
+        self.model = self.load_model()
+        print("Model loaded.")
+
+        self.app = Flask(__name__)
+        self.register_routes()
+
+    @staticmethod
+    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
+        parsed = urlparse(ckpt_url)
+        if not parsed.scheme:
+            return ckpt_url
+        root = root_paths.get(parsed.scheme)
+        if not root:
+            raise KeyError(
+                f"cannot find the checkpoint root path in root_paths: {root_paths}"
+            )
+        rel = (parsed.netloc + parsed.path).lstrip("/")
+        return os.path.join(root, rel)
+
+    def load_model(self):
+
+        from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
+        from starVLA.model.framework import build_framework
+
+        model_config, norm_stats = read_mode_config(self.ckpt_path)
+
+        cfg = dict_to_namespace(model_config)
+        cfg.trainer.pretrained_checkpoint = None
+
+        model = build_framework(cfg=cfg)
+        model.norm_stats = norm_stats
+
+        state_dict = torch.load(self.ckpt_path, map_location="cpu")
+        model.load_state_dict(state_dict, strict=True)
+
+        if self.use_bf16:
+            model = model.to(torch.bfloat16)
+
+        model = model.to("cuda").eval()
+
+        self.norm_stats = norm_stats
+        self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)
+
+        return model
+
+    def parse_observation(self, obs):
+
+        # Every array field carries a leading history axis; use the latest frame.
+        rgb = obs["rgb"][-1]
+        state = obs["state"][-1]
+        joint = obs.get("joint", None)
+        prompt = obs["prompt"]
+
+        # The three camera views arrive stacked along the channel axis.
+        left = rgb[:, :, :3]
+        right = rgb[:, :, 3:6]
+        wrist = rgb[:, :, 6:9]
+
+        target_size = (320, 180)
+
+        left = cv2.resize(left, target_size)
+        right = cv2.resize(right, target_size)
+        wrist = cv2.resize(wrist, target_size)
+
+        img_left = Image.fromarray(left)
+        img_right = Image.fromarray(right)
+        img_wrist = Image.fromarray(wrist)
+
+        if self.state_mode == "joint8":
+            # 7 joint positions plus the gripper state.
+            joint_last = joint[-1]
+            gripper = state[9]
+
+            state_vec = np.concatenate(
+                [joint_last, np.array([gripper])],
+                axis=0
+            )
+        else:
+            # ee_pose7: xyz, the first half of the 6D rotation, gripper state.
+            xyz = state[0:3]
+            rot6d = state[3:9]
+            gripper = state[9]
+
+            state_vec = np.concatenate(
+                [xyz, rot6d[:3], np.array([gripper])],
+                axis=0
+            )
+
+        return img_left, img_right, img_wrist, state_vec, prompt
+
+    def inference(self, observation: dict) -> dict:
+
+        img_left, img_right, img_wrist, state_vec, prompt = \
+            self.parse_observation(observation)
+
+        vla_input = {
+            "batch_images": [[img_left, img_right, img_wrist]],
+            "instructions": [prompt],
+            "state": [state_vec]
+        }
+
+        with torch.no_grad():
+            output = self.model.predict_action(**vla_input)
+
+        actions = output.get("normalized_actions")
+
+        if isinstance(actions, torch.Tensor):
+            actions = actions.cpu().numpy()
+
+        # Drop the batch dimension if the model returns one.
+        if actions.ndim == 3:
+            actions = actions[0]
+
+        return {"action": actions.astype(np.float32)}
+
+    def register_routes(self):
+
+        @self.app.route("/policy", methods=["POST"])
+        def policy():
+            data = pickle.loads(request.data)
+            result = self.inference(data)
+            body = pickle.dumps(result, protocol=4)
+            return Response(body, mimetype="application/octet-stream")
+
+    def run(self):
+
+        print("StarVLA policy server running")
+        print(f"Host: {self.host}")
+        print(f"Port: {self.port}")
+
+        self.app.run(
+            host=self.host,
+            port=self.port,
+            threaded=True
+        )
+
+
+if __name__ == "__main__":
+    config_path = "./benchmark.yaml"
+    server = StarvlaInferenceServer(config_path)
+    server.run()
\ No newline at end of file
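A caveat on the transport: pickle.loads on a raw request body will execute whatever callables are embedded in the stream, so the /policy route above should only be exposed on a trusted network. If the server ever needs to face anything beyond localhost, one common mitigation is a restricted unpickler (a sketch only, following the allowlisting pattern from the Python pickle docs; the allowlist below is an assumption and would need to match the exact types the client sends):

    import io
    import pickle

    _SAFE_BUILTINS = {"dict", "list", "tuple", "set", "str", "bytes", "int", "float", "bool"}
    _SAFE_GLOBALS = {
        ("numpy", "ndarray"),
        ("numpy", "dtype"),
        ("numpy.core.multiarray", "_reconstruct"),
        ("numpy.core.multiarray", "scalar"),
    }

    class RestrictedUnpickler(pickle.Unpickler):
        def find_class(self, module, name):
            # Resolve only an explicit allowlist of globals; refuse everything else.
            if module == "builtins" and name in _SAFE_BUILTINS:
                return super().find_class(module, name)
            if (module, name) in _SAFE_GLOBALS:
                return super().find_class(module, name)
            raise pickle.UnpicklingError(f"blocked global: {module}.{name}")

    def safe_loads(data: bytes):
        return RestrictedUnpickler(io.BytesIO(data)).load()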
diff --git a/starvla_policy.py b/starvla_policy.py
new file mode 100644
index 0000000..b6010a6
--- /dev/null
+++ b/starvla_policy.py
@@ -0,0 +1,123 @@
+from joysim.annotations.config_class import configclass, field
+from joysim.annotations.stereotype import stereotype
+from joysim.app import JoySim
+from joysim.core.scene_manager import SceneManager
+from joysim.extensions.benchmark.action import RobotAction
+from joysim.extensions.benchmark.benchmark import (
+    BenchmarkAction,
+    BenchmarkObservation,
+    ControlMode,
+)
+from joysim.extensions.benchmark.policy import Policy, PolicyConfig
+
+import numpy as np
+import pickle
+import requests
+
+
+@configclass
+@stereotype.register_config("starvla")
+class StarvlaPolicyConfig(PolicyConfig):
+
+    robot_name: str = field(default="my_robot", required=True, comment="The name of the robot")
+    object_name: str = field(default="target", required=True, comment="The name of the object")
+
+    server_url: str = field(
+        default="http://127.0.0.1:5000/policy",
+        required=True,
+        comment="StarVLA policy server url"
+    )
+
+    prompt: str = field(
+        default="pick the object",
+        required=True,
+        comment="task instruction"
+    )
+
+
+@stereotype.register_model("starvla")
+class StarvlaPolicy(Policy):
+
+    def __init__(self, config: StarvlaPolicyConfig):
+        super().__init__(config)
+
+        self.robot_name = config.robot_name
+        self.object_name = config.object_name
+        self.server_url = config.server_url
+        self.prompt = config.prompt
+
+    def reset(self) -> None:
+        pass
+
+    def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
+        pass
+
+    def needs_observation(self) -> bool:
+        return True
+
+    def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
+
+        robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
+        joint_positions = robot_obs["joint_positions"]
+        robot_position = robot_obs["position"]
+        robot_quaternion = robot_obs["rotation"]
+
+        # State layout sent to the server: position (3), quaternion (4) and a
+        # gripper placeholder (1).
+        state = np.concatenate([
+            robot_position,
+            robot_quaternion,
+            np.array([0.0])
+        ])
+
+        camera_obs = benchmark_observation.get_sensor_observations()
+        rgb = camera_obs["rgb"]
+
+        # The server reads the latest frame of a leading history axis.
+        obs = {
+            "state": np.expand_dims(state, axis=0),
+            "joint": np.expand_dims(joint_positions, axis=0),
+            "rgb": np.expand_dims(rgb, axis=0),
+            "prompt": self.prompt
+        }
+
+        return obs
+
+    def compute_action(self, observation: dict) -> dict:
+
+        payload = pickle.dumps(observation)
+
+        response = requests.post(
+            self.server_url,
+            data=payload,
+            headers={"Content-Type": "application/octet-stream"}
+        )
+
+        if response.status_code != 200:
+            raise RuntimeError(f"StarVLA server error: {response.text}")
+
+        result = pickle.loads(response.content)
+
+        return result
+
+    def postprocess_action(self, action: dict) -> BenchmarkAction:
+
+        benchmark_action = BenchmarkAction()
+
+        robot = SceneManager.get_robot(self.robot_name)
+        joint_names = robot.get_planner().get_plannable_joint_names()
+
+        # Execute only the first step of the predicted action chunk.
+        joint_positions = action["action"][0]
+
+        benchmark_action.add_robot_action(
+            RobotAction(
+                control_mode=ControlMode.POSITION,
+                robot_name=self.robot_name,
+                joint_names=joint_names,
+                joint_positions=joint_positions
+            )
+        )
+
+        return benchmark_action
+
+
+if __name__ == "__main__":
+    js = JoySim("./benchmark.yaml")
+    js.start()
\ No newline at end of file
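For a quick end-to-end check of the server without launching JoySim, an observation with the layout parse_observation expects can be posted directly. A hedged sketch follows; the shapes are inferred from the server's indexing, e.g. a 10-dim state for the default ee_pose7 mode (note that preprocess_observation above concatenates only 8 state values, so one of the two sides presumably still needs adjusting):

    import pickle
    import numpy as np
    import requests

    obs = {
        # Leading axis is the history dimension; the server reads [-1].
        "rgb": np.zeros((1, 720, 1280, 9), dtype=np.uint8),   # left/right/wrist stacked on channels
        "state": np.zeros((1, 10), dtype=np.float32),         # xyz + 6D rotation + gripper
        "joint": np.zeros((1, 7), dtype=np.float32),
        "prompt": "pick the cola bottle and place it on the book",
    }

    resp = requests.post(
        "http://127.0.0.1:5000/policy",
        data=pickle.dumps(obs),
        headers={"Content-Type": "application/octet-stream"},
    )
    resp.raise_for_status()
    action = pickle.loads(resp.content)["action"]
    print(action.shape, action.dtype)  # expected: (horizon, action_dim) float32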