Junhan
2026-05-07 19:31:45 +08:00
parent 7ce2823c56
commit 2514cc943d
6 changed files with 5976 additions and 94 deletions

View File

@@ -1,77 +1,80 @@
 general:
   scan_project: true
   root_paths:
-    asset: /home/ubuntu/projects/gen_data/data
-    checkpoints: /home/ubuntu/data/models
-    output: /home/ubuntu/output
+    asset: /home/ubuntu/xionghao/sim_hofee/sim_hofee/assets
+    checkpoints: /home/ubuntu/xionghao/starVLA-starVLA/playground/Checkpoints
+    output: /home/ubuntu/xionghao/sim_hofee
 simulation:
+  stereotype: isaaclab
+  intiailize_steps: 300
   launch_config:
     device: cuda
     enable_cameras: true
     headless: false
     livestream: 0
 scene:
-  name: kujiale_multispace
+  name: 827313_home
   base_config:
     stereotype: usd
-    name: _827313_home_workspace_00
+    name: _827313_home_workspace_01
     source: local
-    asset_path: asset://scenes/kujiale_multispace/827313_home/workspace_00.usd
+    asset_path: asset://scenes/kujiale_multispace/827313_home/workspace_01.usd
   object_cfg_dict:
-    omni6DPose_can_016:
-      name: omni6DPose_can_016
+    omni6DPose_timer_017:
+      name: omni6DPose_timer_017
       stereotype: rigid
       source: local
-      asset_path: asset://objects/omni6DPose/can/omni6DPose_can_016/Aligned.usd
+      asset_path: asset://objects/omni6DPose/timer/omni6DPose_timer_017/Aligned.usd
       scale:
       - 0.001
       - 0.001
       - 0.001
       position:
-      - 0.2
-      - -4.15243
-      - 0.5
+      - 0.552364
+      - -4.0582599999999995
+      - 0.524713118
       quaternion:
-      - -0.304408012043137
-      - -0.304408012043137
-      - 0.638228612805745
-      - 0.6382286128057448
+      - 0.166210542394157
+      - 0.166210542394157
+      - 0.6872947370648492
+      - 0.6872947370648491
       axis_y_up: true
-    omni6DPose_plug_001:
-      name: omni6DPose_plug_001
+    omni6DPose_book_031:
+      name: omni6DPose_book_031
       stereotype: rigid
       source: local
-      asset_path: asset://objects/omni6DPose/plug/omni6DPose_plug_001/Aligned.usd
+      asset_path: asset://objects/omni6DPose/book/omni6DPose_book_031/Aligned.usd
       scale:
       - 0.001
       - 0.001
       - 0.001
       position:
-      - 0.219859
-      - -3.852430000000001
-      - 0.510259093
+      - 0.6623640000000001
+      - -3.7882599999999997
+      - 0.5101601435
       quaternion:
-      - 0.7070997233636068
-      - 0.707099723363607
-      - 0.003159306745230588
-      - 0.003159306745230588
+      - 0.7063055546421202
+      - 0.7063055546421203
+      - -0.03365209475927027
+      - -0.033652094759270265
       axis_y_up: true
   robot_cfg_dict:
-    Franka:
-      name: Franka
+    Franka_Robotiq_2f85:
+      name: Franka_Robotiq_2f85
       asset_path: asset://Franka/franka_robotiq_2f85_zedmini.usd
       position:
-      - -0.3779152929316859
-      - -3.943187336951741
-      - 0.3130090550954169
+      - 1.082364
+      - -3.92826
+      - 0.47629299999999997
       rotation:
-      - 0.9987576855319703
+      - 7.549799991308018e-08
       - 0.0
       - 0.0
-      - 0.04983056883903615
+      - 0.9999999999999973
-      stereotype: single_gripper_arm_robot
+      stereotype: modular_robot
       source: local
       ee_link_name: panda_link8
       ik_joint_names:
@@ -83,18 +86,43 @@ scene:
       - panda_joint6
       - panda_joint7
       init_joint_position:
+        # panda_joint1: 0.18641542
+        # panda_joint2: 0.47660449
+        # panda_joint3: -0.03320411
+        # panda_joint4: -2.27693725
+        # panda_joint5: 0.98161776
+        # panda_joint6: 2.20247197
+        # panda_joint7: 0.71794897
         panda_joint2: -0.1633
         panda_joint4: -1.07
         panda_joint6: 0.8933
         panda_joint7: 0.785
-      arm_actuator_name: franka_arm
-      gripper_actuator_name: robotiq_2f_85
-      use_planner: true
-      planner_cfg:
-        stereotype: curobo
-        lazy_init: true
-        robot_config_file: asset://curobo/franka_robotiq_2f85/franka_robotiq_2f85.yml
-        world_config_source: stage
+      arm_modules:
+        main_arm:
+          arm_actuator_name: franka_arm
+          ee_link_name: panda_link8
+          ee_type: gripper
+          ee_actuator_name: robotiq_gripper
+      actuator_cfg_dict:
+        franka_arm:
+          stereotype: arm
+          joint_names_expr: [panda_joint1, panda_joint2, panda_joint3, panda_joint4, panda_joint5, panda_joint6, panda_joint7]
+          stiffness: 3000.0
+          damping: 800.0
+        robotiq_gripper:
+          stereotype: gripper
+          joint_names_expr: [robotiq_85_left_knuckle_joint]
+          stiffness: 10000
+          damping: 500.0
+          close_control_type: velocity
+          open_control_type: position
+          drive_joints:
+            robotiq_85_left_knuckle_joint:
+              close_velocity: 5.0
+              open_velocity: -5.0
+              close_position: 0.8
+              open_position: 0.0
+      use_planner: false
   sensor_cfg_dict:
     Hand_Camera:
       name: Hand_Camera
@@ -108,23 +136,23 @@ scene:
       camera_model: pinhole
       fix_camera: true
       focal_length: 2.8
-      horizontal_aperture: 4.890881131191918
+      horizontal_aperture: 4.893416860031241
       vertical_aperture: 2.7608816125932627
       convention: opengl
       attach_to:
-        target_name: Franka
+        target_name: Franka_Robotiq_2f85
         is_articulation_part: true
         articulation_part_name: panda_link8
         create_fixed_joint: true
         local_position:
-        - -0.07176474936469446
-        - 0.02890958100382394
-        - 0.01978286477078585
+        - -0.07128738160694643
+        - 0.03551506300731732
+        - 0.018927748370281355
         local_rotation:
-        - 0.12352531576657892
-        - 0.7000519980574638
-        - -0.6933396337721066
-        - -0.11810524383481495
+        - -0.12117023430710862
+        - -0.6862313269668
+        - 0.7070213671685396
+        - 0.12052023305019997
     Left_Camera:
       name: Left_Camera
       stereotype: camera
@@ -137,20 +165,20 @@ scene:
       camera_model: pinhole
       fix_camera: false
       focal_length: 2.1
-      horizontal_aperture: 5.019302546405283
+      horizontal_aperture: 5.030789363390793
       vertical_aperture: 2.833796298140747
       convention: opengl
       attach_to:
-        target_name: Franka
+        target_name: Franka_Robotiq_2f85
         local_position:
-        - -0.07383269512283744
-        - -0.4566116797983716
-        - 0.5664518136443555
+        - 0.31702696813014064
+        - -0.3844238699868664
+        - 0.6551552990137672
         local_rotation:
-        - 0.7669289562134721
-        - 0.4437472984144337
-        - -0.232167881174614
-        - -0.4012560108236337
+        - 0.8742457685173938
+        - 0.38378563025938384
+        - -0.11951449178007277
+        - -0.27224843891267797
     Right_Camera:
       name: Right_Camera
       stereotype: camera
@@ -163,20 +191,20 @@ scene:
       camera_model: pinhole
       fix_camera: false
       focal_length: 2.1
-      horizontal_aperture: 5.053278483887542
+      horizontal_aperture: 5.050364265142387
       vertical_aperture: 2.833796298140747
       convention: opengl
       attach_to:
-        target_name: Franka
+        target_name: Franka_Robotiq_2f85
         local_position:
-        - 0.4501524009531093
-        - 0.7206899873248545
-        - 0.27525652672526973
+        - 0.21844487914880717
+        - 0.20172329179193413
+        - 0.30108042236545296
         local_rotation:
-        - 0.007056820167439314
-        - 0.007107689842383092
-        - 0.709606075042274
-        - 0.7045274304789896
+        - -0.5316249212230874
+        - -0.38697158527836417
+        - 0.44338617110944967
+        - 0.6091277686910994
 extension:
   extension_cfg_dict:
@@ -185,9 +213,10 @@ extension:
       stereotype: data_collect
       observer_cfgs:
       - stereotype: robot_observer
-        name: Franka
+        name: Franka_Robotiq_2f85
+        target_joint_names: [panda_joint1, panda_joint2, panda_joint3, panda_joint4, panda_joint5, panda_joint6, panda_joint7, robotiq_85_left_knuckle_joint]
         observe_ee_pose: true
-        observe_gripper_drive_state: true
+        observe_ee_state: true
         observe_joint_position: true
         observe_joint_velocity: true
         observe_joint_positions: true
@@ -210,22 +239,23 @@ extension:
       stereotype: benchmark
       data_collector_name: benchmark_data_collect
       action_frequency: 15.0
-      timeout_per_episode: 30
+      timeout_per_episode: 300
       goals:
       - name: cola on top of book
         description: check if the cola bottle is on the book
         stereotype: on_top
-        object_A_name: omni6DPose_plug_001
-        object_B_name: omni6DPose_can_016
+        object_A_name: omni6DPose_book_031
+        object_B_name: omni6DPose_timer_017
       policy:
         stereotype: starvla
-        robot_name: Franka
+        robot_name: Franka_Robotiq_2f85
+        arm_name: main_arm
         sensor_names: [Hand_Camera, Left_Camera, Right_Camera]
-        prompt: pick up the can
+        prompt: pick up the timer and put on the book
         run_trunk_size: 16
         gripper_width_mapper_file: ./gripper_width_robotiq_2f85_fixed.json
-        visualize_action_ee_pose: false
-        visualize_state_ee_pose: false
+        visualize_action_ee_pose: true
+        visualize_state_ee_pose: true
         visualize_bounding_box_targets: [] # [omni6DPose_plug_001, omni6DPose_can_016] # when enabled, the boxes are visible to the policy and can affect its inference
       recorder:
@@ -237,7 +267,10 @@ extension:
postprocess_list: ["hdf5", "video"] postprocess_list: ["hdf5", "video"]
policy_server: policy_server:
ckpt_path: checkpoints://0318_qwenpi_droid_pretrain_8node/checkpoints/steps_30000_pytorch_model.pt # ckpt_path: checkpoints://0324_qwenpi_droid_pretrain_8node/checkpoints/steps_30000_pytorch_model.pt
# ckpt_path: checkpoints://0405_qwenpi_droid_norm_pretrain_8node/checkpoints/steps_60000_pytorch_model.pt
# ckpt_path: checkpoints://0407_qwenpi_droid_postrain/final_model/pytorch_model.pt
ckpt_path: checkpoints://0407_qwenpi_droid_from_scratch/final_model/pytorch_model.pt
ckpt_source: local ckpt_source: local
host: 0.0.0.0 host: 0.0.0.0
port: 5000 port: 5000
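
The asset:// and checkpoints:// URLs in both versions of this config resolve against general.root_paths by URL scheme. A minimal standalone sketch of that lookup, mirroring the _resolve_ckpt_path helper added to the inference server later in this commit (the function name here is illustrative, not part of the diff):

import os
from urllib.parse import urlparse

def resolve_url(url: str, root_paths: dict) -> str:
    # "asset://scenes/.../workspace_01.usd" -> root_paths["asset"] + relative part
    parsed = urlparse(url)
    if not parsed.scheme:  # plain filesystem path: return unchanged
        return url
    rel = (parsed.netloc + parsed.path).lstrip("/")
    return os.path.join(root_paths[parsed.scheme], rel)

root_paths = {"asset": "/home/ubuntu/xionghao/sim_hofee/sim_hofee/assets"}
print(resolve_url("asset://Franka/franka_robotiq_2f85_zedmini.usd", root_paths))
# /home/ubuntu/xionghao/sim_hofee/sim_hofee/assets/Franka/franka_robotiq_2f85_zedmini.usd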

benchmark_can.yaml (new file, 257 lines)
View File

@@ -0,0 +1,257 @@
general:
  scan_project: true
  root_paths:
    asset: /home/ubuntu/xionghao/sim_hofee/sim_hofee/assets
    checkpoints: /home/ubuntu/xionghao/starVLA-starVLA/playground/Checkpoints
    output: /home/ubuntu/xionghao/sim_hofee
simulation:
  launch_config:
    device: cuda
    enable_cameras: true
    headless: false
    livestream: 0
scene:
  name: kujiale_multispace
  base_config:
    stereotype: usd
    name: _827313_home_workspace_00
    source: local
    asset_path: asset://scenes/kujiale_multispace/827313_home/workspace_00.usd
  object_cfg_dict:
    omni6DPose_can_016:
      name: omni6DPose_can_016
      stereotype: rigid
      source: local
      asset_path: asset://objects/omni6DPose/can/omni6DPose_can_016/Aligned.usd
      scale:
      - 0.001
      - 0.001
      - 0.001
      position:
      - 0.2
      - -4.05243
      - 0.5
      quaternion:
      - -0.304408012043137
      - -0.304408012043137
      - 0.638228612805745
      - 0.6382286128057448
      axis_y_up: true
    omni6DPose_can_011:
      name: omni6DPose_can_011
      stereotype: rigid
      source: local
      asset_path: asset://objects/omni6DPose/can/omni6DPose_can_011/Aligned.usd
      scale:
      - 0.001
      - 0.001
      - 0.001
      position:
      - 0.219859
      - -3.82430000000001
      - 0.510259093
      quaternion:
      - 0.7070997233636068
      - 0.707099723363607
      - 0.003159306745230588
      - 0.003159306745230588
      axis_y_up: true
  robot_cfg_dict:
    Franka:
      name: Franka
      asset_path: asset://Franka/franka_robotiq_2f85_zedmini.usd
      position:
      - -0.3779152929316859
      - -3.943187336951741
      - 0.3130090550954169
      rotation:
      - 0.9987576855319703
      - 0.0
      - 0.0
      - 0.04983056883903615
      stereotype: single_gripper_arm_robot
      source: local
      ee_link_name: panda_link8
      ik_joint_names:
      - panda_joint1
      - panda_joint2
      - panda_joint3
      - panda_joint4
      - panda_joint5
      - panda_joint6
      - panda_joint7
      init_joint_position:
        panda_joint2: -0.1633
        panda_joint4: -1.07
        panda_joint6: 0.8933
        panda_joint7: 0.785
        # panda_joint1: 0.00201746
        # panda_joint2: -0.25797674
        # panda_joint3: -0.02563748
        # panda_joint4: -1.68997145
        # panda_joint5: -0.0735122
        # panda_joint6: 1.45765436
        # panda_joint7: 0.56564939
      arm_actuator_name: franka_arm
      gripper_actuator_name: robotiq_2f_85
      use_planner: true
      planner_cfg:
        stereotype: curobo
        lazy_init: true
        robot_config_file: asset://curobo/franka_robotiq_2f85/franka_robotiq_2f85.yml
        world_config_source: stage
  sensor_cfg_dict:
    Hand_Camera:
      name: Hand_Camera
      stereotype: camera
      data_types:
      - rgb
      - depth
      - normals
      width: 1280
      height: 720
      camera_model: pinhole
      fix_camera: true
      focal_length: 2.8
      horizontal_aperture: 4.890881131191918
      vertical_aperture: 2.7608816125932627
      convention: opengl
      attach_to:
        target_name: Franka
        is_articulation_part: true
        articulation_part_name: panda_link8
        create_fixed_joint: true
        local_position:
        - -0.07176474936469446
        - 0.02890958100382394
        - 0.01978286477078585
        local_rotation:
        - 0.12352531576657892
        - 0.7000519980574638
        - -0.6933396337721066
        - -0.11810524383481495
    Left_Camera:
      name: Left_Camera
      stereotype: camera
      data_types:
      - rgb
      - depth
      - normals
      width: 1280
      height: 720
      camera_model: pinhole
      fix_camera: false
      focal_length: 2.1
      horizontal_aperture: 5.019302546405283
      vertical_aperture: 2.833796298140747
      convention: opengl
      attach_to:
        target_name: Franka
        local_position:
        - -0.07383269512283744
        - -0.4566116797983716
        - 0.5664518136443555
        local_rotation:
        - 0.7669289562134721
        - 0.4437472984144337
        - -0.232167881174614
        - -0.4012560108236337
    Right_Camera:
      name: Right_Camera
      stereotype: camera
      data_types:
      - rgb
      - depth
      - normals
      width: 1280
      height: 720
      camera_model: pinhole
      fix_camera: false
      focal_length: 2.1
      horizontal_aperture: 5.053278483887542
      vertical_aperture: 2.833796298140747
      convention: opengl
      attach_to:
        target_name: Franka
        local_position:
        - 0.4501524009531093
        - 0.7206899873248545
        - 0.27525652672526973
        local_rotation:
        - 0.007056820167439314
        - 0.007107689842383092
        - 0.709606075042274
        - 0.7045274304789896
extension:
  extension_cfg_dict:
    benchmark_data_collect:
      enable: true
      stereotype: data_collect
      observer_cfgs:
      - stereotype: robot_observer
        name: Franka
        target_joint_names: [panda_joint1, panda_joint2, panda_joint3, panda_joint4, panda_joint5, panda_joint6, panda_joint7]
        observe_ee_pose: true
        observe_gripper_drive_state: true
        observe_joint_position: true
        observe_joint_velocity: true
        observe_joint_positions: true
        observe_joint_velocities: true
        observe_joint_accelerations: true
        observe_joint_position_targets: true
        observe_joint_velocity_targets: true
      - stereotype: sensor_observer
        name: Hand_Camera
        observe_rgb: true
      - stereotype: sensor_observer
        name: Left_Camera
        observe_rgb: true
      - stereotype: sensor_observer
        name: Right_Camera
        observe_rgb: true
    starvla_benchmark:
      enable: true
      stereotype: benchmark
      data_collector_name: benchmark_data_collect
      action_frequency: 15.0
      timeout_per_episode: 300
      goals:
      - name: cola on top of book
        description: check if the cola bottle is on the book
        stereotype: on_top
        object_A_name: omni6DPose_can_016
        object_B_name: omni6DPose_can_011
      policy:
        stereotype: starvla
        robot_name: Franka
        sensor_names: [Hand_Camera, Left_Camera, Right_Camera]
        prompt: pick up the red can on the table
        run_trunk_size: 16
        gripper_width_mapper_file: ./gripper_width_robotiq_2f85_fixed.json
        visualize_action_ee_pose: false
        visualize_state_ee_pose: false
        visualize_bounding_box_targets: [] # [omni6DPose_plug_001, omni6DPose_can_016] # when enabled, the boxes are visible to the policy and can affect its inference
      recorder:
        enable: true # set to true to record the data
        stereotype: record
        data_collector_name: benchmark_data_collect
        record_fps: 30
        backend_root_path: output://benchmark_record
        postprocess_list: ["hdf5", "video"]
policy_server:
  # ckpt_path: checkpoints://0324_qwenpi_droid_pretrain_8node/checkpoints/steps_30000_pytorch_model.pt
  # ckpt_path: checkpoints://0401_qwenpi_droid_pretrain_8node/checkpoints/steps_20000_pytorch_model.p
  ckpt_path: checkpoints://0405_qwenpi_droid_norm_pretrain_8node/checkpoints/steps_60000_pytorch_model.pt
  # ckpt_path: checkpoints://0407_qwenpi_droid_postrain/checkpoints/steps_10000_pytorch_model.pt
  ckpt_source: local
  host: 0.0.0.0
  port: 5000
  use_bf16: true
  unnorm_key: oxe_bridge
  state_mode: ee_pose7
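
Once the inference server (added below in this commit) is running with this config, the /policy/health route it registers can be probed before starting the benchmark. A minimal sketch, assuming the server is reachable at the configured host and port (substitute the server's real address when it runs remotely):

import yaml
import requests

with open("benchmark_can.yaml") as f:
    cfg = yaml.safe_load(f)

ps = cfg["policy_server"]
resp = requests.get(f"http://{ps['host']}:{ps['port']}/policy/health", timeout=5)
print(resp.status_code, resp.text)  # expect "200 OK" once the model has finished loading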

View File

@@ -12,6 +12,14 @@ import cv2
 from flask import Flask, request, Response
 from PIL import Image
+def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
+    """Pad an array to the target dimension with a constant value along the specified axis."""
+    current_dim = x.shape[axis]
+    if current_dim < target_dim:
+        pad_width = [(0, 0)] * len(x.shape)
+        pad_width[axis] = (0, target_dim - current_dim)
+        return np.pad(x, pad_width, constant_values=value)
+    return x
class StarvlaInferenceServer: class StarvlaInferenceServer:
@@ -88,14 +96,15 @@ class StarvlaInferenceServer:
         img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
         state_vec = obs["state"]
+        # import ipdb;ipdb.set_trace()
+        # state_vec = pad_to_dim(np.array(state_vec), 100, axis=-1)
         return img_left, img_right, img_wrist, state_vec, obs["prompt"]

     def inference(self, observation: dict) -> dict:
         img_left, img_right, img_wrist, state_vec, prompt = \
             self.parse_observation(observation)
+        print(f"{state_vec.shape}")
         vla_input = {
             "batch_images": [[img_left, img_right, img_wrist]],
             "instructions": [prompt],

View File

@@ -0,0 +1,257 @@
import yaml
import pickle
import argparse
import os
import traceback
from urllib.parse import urlparse
import numpy as np
import torch
import cv2
import json
from flask import Flask, request, Response
from PIL import Image
from pathlib import Path


def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
    """Pad an array to the target dimension with a constant value along the specified axis."""
    current_dim = x.shape[axis]
    if current_dim < target_dim:
        pad_width = [(0, 0)] * len(x.shape)
        pad_width[axis] = (0, target_dim - current_dim)
        return np.pad(x, pad_width, constant_values=value)
    return x


def normalize_eepose_state(raw_state, stats, start_idx, norm_len=3):
    """
    Normalize a slice of the raw state into [-1.0, 1.0] using the given statistics.

    :param raw_state: numpy array holding the raw state, typically (batch_size, state_dim)
    :param stats: dict of normalization statistics (must contain 'q01' and 'q99')
    :param start_idx: start index of the features to normalize within the state vector
    :param norm_len: number of features to normalize, default 3 (e.g. xyz coordinates)
    :return: the normalized state (a new array; the input is not modified)
    """
    # 1. Slice out the matching q01 and q99; keep float16 to match training precision
    q01 = np.array(stats["q01"][start_idx : start_idx + norm_len], dtype=np.float16)
    q99 = np.array(stats["q99"][start_idx : start_idx + norm_len], dtype=np.float16)
    # 2. Compute the denominator, clipped to guard against division by zero
    denom = np.clip(q99 - q01, a_min=1e-5, a_max=None)
    # 3. Locate the target slice
    target_slice = slice(start_idx, start_idx + norm_len)
    # 4. Copy first so the caller's array is not mutated
    norm_state = raw_state.copy()
    x = norm_state[:, target_slice]
    # 5. Forward mapping: y = 2 * (x - q01) / denom - 1
    y = 2.0 * (x - q01) / denom - 1.0
    # 6. Clip to [-1.0, 1.0] so the model never sees out-of-range inputs
    norm_state[:, target_slice] = np.clip(y, -1.0, 1.0)
    return norm_state


def unnormalize_eepose_action(normalized_action, stats, start_idx, norm_len=3):
    """
    Un-normalize a slice of the action using the given statistics.

    :param normalized_action: numpy array holding the normalized actions, typically (batch_size, action_dim)
    :param stats: dict of statistics, usually loaded from a JSON file
    :param start_idx: start index of the features to un-normalize within the action vector
    :param norm_len: number of features to un-normalize, default 3 (e.g. xyz coordinates)
    :return: the un-normalized action (a new array; the input is not modified)
    """
    # 1. Slice out the matching q01 and q99; keep float16 to match the normalization precision
    # import ipdb;ipdb.set_trace()
    q01 = np.array(stats["q01"][start_idx : start_idx + norm_len], dtype=np.float16)
    q99 = np.array(stats["q99"][start_idx : start_idx + norm_len], dtype=np.float16)
    # 2. Same clipped denominator as in normalization (guards against division by zero)
    denom = np.clip(q99 - q01, a_min=1e-5, a_max=None)
    # 3. Locate the target slice
    target_slice = slice(start_idx, start_idx + norm_len)
    # 4. Inverse mapping: x = (y + 1) / 2 * denom + q01
    #    Copy first so the caller's normalized_action is not mutated
    unnorm_action = normalized_action.copy()
    y = unnorm_action[:, target_slice]
    unnorm_action[:, target_slice] = (y + 1.0) / 2.0 * denom + q01
    return unnorm_action


class StarvlaInferenceServer:
    def __init__(self, config_path: str):
        with open(config_path, "r") as f:
            cfg = yaml.safe_load(f)
        # import ipdb;ipdb.set_trace()
        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]
        self.ckpt_source = policy_server_cfg["ckpt_source"]
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        self.rel_eepose_stats = self.read_rel_eepose_stats()
        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)
        self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
        self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")
        print("Loading StarVLA model...")
        self.model = self.load_model()
        print("Model loaded.")
        self.app = Flask(__name__)
        self.register_routes()

    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            return ckpt_url
        root = root_paths.get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find the checkpoint root path in root_paths: {root_paths}"
            )
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)

    def read_rel_eepose_stats(self):
        stats_path = Path(self.ckpt_path).parents[1] / "action_delta_eepose_stats.json"
        if stats_path is None or not Path(stats_path).exists():
            return {}
        with open(stats_path, "r", encoding="utf-8") as f:
            stats = json.load(f)
        return stats

    def load_model(self):
        from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
        from starVLA.model.framework.__init__ import build_framework

        model_config, norm_stats = read_mode_config(self.ckpt_path)
        cfg = dict_to_namespace(model_config)
        cfg.trainer.pretrained_checkpoint = None
        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats
        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        if self.use_bf16:
            model = model.to(torch.bfloat16)
        model = model.to("cuda").eval()
        self.norm_stats = norm_stats
        self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)
        return model

    def parse_observation(self, obs, target_size=(320, 180)):
        left_rgb, right_rgb, wrist_rgb = obs["rgb"]["Left_Camera"], obs["rgb"]["Right_Camera"], obs["rgb"]["Hand_Camera"]
        img_left = Image.fromarray(cv2.resize(left_rgb, target_size))
        img_right = Image.fromarray(cv2.resize(right_rgb, target_size))
        img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
        state_vec = obs["state"]  # [:10]
        # import ipdb;ipdb.set_trace()
        state_vec = pad_to_dim(np.array(state_vec), 100, axis=-1)
        return img_left, img_right, img_wrist, state_vec, obs["prompt"]

    def inference(self, observation: dict) -> dict:
        img_left, img_right, img_wrist, state_vec, prompt = \
            self.parse_observation(observation)
        print(f"{state_vec.shape}")
        vla_input = {
            "batch_images": [[img_left, img_right, img_wrist]],
            "instructions": [prompt],
            "state": [state_vec]
        }
        # import ipdb;ipdb.set_trace()
        with torch.no_grad():
            output = self.model.predict_action(**vla_input)
        actions = output.get("normalized_actions")
        if isinstance(actions, torch.Tensor):
            actions = actions.cpu().numpy()
        if actions.ndim == 3:
            actions = actions[0]  # (16, 10)
        # Un-normalize specific slices (assuming the part to un-normalize is dims 0-8
        # of the action vector, i.e. ee_pose7 + gripper_width)
        actions[:, :3] = unnormalize_eepose_action(actions[:, :3], self.rel_eepose_stats, start_idx=0, norm_len=3)
        return {"ee_delta_position_chunks": actions[:, :3].tolist(),
                "ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
                "gripper_width_chunks": actions[:, 9:10].tolist()}

    def register_routes(self):
        @self.app.route("/policy/inference", methods=["POST"])
        def policy():
            try:
                data = pickle.loads(request.data)
                result = self.inference(data)
                body = pickle.dumps(result, protocol=4)
                return Response(body, mimetype="application/octet-stream")
            except Exception as e:
                err_obj = {
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                }
                body = pickle.dumps(err_obj, protocol=4)
                return Response(
                    body,
                    mimetype="application/octet-stream",
                    status=500,
                )

        @self.app.route("/policy/health", methods=["GET"])
        def health():
            if self.model is None:
                return Response("Failed to load model", mimetype="application/json", status=500)
            return Response("OK", mimetype="application/json")

    def run(self):
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")
        self.app.run(
            host=self.host,
            port=self.port,
            threaded=True
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="StarVLA inference server")
    parser.add_argument(
        "--config",
        type=str,
        default="./benchmark.yaml",
        help="Path to benchmark.yaml",
    )
    args = parser.parse_args()
    config_path = args.config
    server = StarvlaInferenceServer(config_path)
    server.run()
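
normalize_eepose_state and unnormalize_eepose_action are exact inverses wherever the input stays inside the [q01, q99] band (outside it, the forward clip loses information). A round-trip sanity check with made-up statistics:

import numpy as np

stats = {"q01": [-0.05, -0.05, -0.05], "q99": [0.05, 0.05, 0.05]}
raw = np.array([[0.01, -0.02, 0.03]], dtype=np.float32)  # (batch, 3) xyz deltas

norm = normalize_eepose_state(raw, stats, start_idx=0, norm_len=3)
rec = unnormalize_eepose_action(norm, stats, start_idx=0, norm_len=3)

print(norm)  # y = 2*(x - q01)/(q99 - q01) - 1 -> [[0.2, -0.4, 0.6]]
print(np.allclose(rec, raw, atol=1e-3))  # True, up to the float16 rounding of q01/q99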

View File

@@ -8,8 +8,9 @@ from joysim.annotations.config_class import configclass, field
 from joysim.annotations.stereotype import stereotype
 from joysim.controllers.spawnable_controller import SpawnableController
 from joysim.controllers.visualize_controller import VisualizeController
+from joysim.unisim.robots.models.modular_robot import ModularRobot
 from joysim.utils.namespace import PoseVisualType, SimulatorType
-from joysim.unisim.robots.configs.actuator_configs.grippers import GripperDriveJointConfig
+from joysim.unisim.robots.actuator_configs.grippers import GripperDriveJointConfig
 from joysim.extensions.benchmark.action import RobotAction
 from joysim.extensions.benchmark.benchmark import (
     BenchmarkAction,
@@ -25,6 +26,8 @@ from joysim.utils.pose import Pose
 class StarvlaPolicyConfig(PolicyConfig):
     robot_name: str = field(default="None", required=True, comment="The name of the robot")
+    arm_name: str = field(default="main_arm", required=True, comment="The name of the arm module to control")
+    drive_name: str = field(default="robotiq_85_left_knuckle_joint", required=True, comment="The name of the drive joint to control")
     gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
     visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose")
     visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose")
@@ -64,6 +67,8 @@ class StarvlaPolicy(Policy):
         super().__init__(config)
         self.robot_name = config.robot_name
+        self.arm_name = config.arm_name
+        self.drive_name = config.drive_name
         self.sensor_names = config.sensor_names
         self.server_url = config.server_url
         self.prompt = config.prompt
@@ -79,11 +84,15 @@ class StarvlaPolicy(Policy):
         self.current_chunk_id = 0
         self.current_chunk_result = None
         self.run_trunk_size = self.config.run_trunk_size
-        self.drive_joints: dict[str, GripperDriveJointConfig] = SpawnableController.control_robot(self.robot_name, "get_gripper_drive_joints").unwrap()
+        self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap()
+        self.drive_joints: dict[str, GripperDriveJointConfig] = self.robot.get_arm(self.arm_name).get_ee().get_drive_joints()
+        self.robot_drive_name = list(self.drive_joints.keys())[0]
         for joint_name, joint_config in self.drive_joints.items():
             SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
             SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
-            SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 50}).unwrap()
+            SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 5000}).unwrap()
+        SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [self.robot_drive_name], "effort_limit": 5000}).unwrap()
         self.max_width = float("-inf")
         self.min_width = float("inf")
         for entry in self.gripper_width_mapper:
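
__map_joint_position_to_normalized_width itself is outside this hunk; the mapper file is a table relating knuckle-joint position to gripper opening width, and the loop above scans it for min_width/max_width. A hedged sketch of such a mapping, assuming each entry carries joint_position and width keys (the key names are a guess, not taken from the diff):

import numpy as np

def map_joint_position_to_normalized_width(joint_position, mapper, min_width, max_width):
    # sort the lookup table by joint position, interpolate the width, normalize to [0, 1]
    xs = np.array([entry["joint_position"] for entry in mapper])
    ys = np.array([entry["width"] for entry in mapper])
    order = np.argsort(xs)
    width = np.interp(joint_position, xs[order], ys[order])
    return float((width - min_width) / (max_width - min_width))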
@@ -113,11 +122,13 @@ class StarvlaPolicy(Policy):
     def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
         robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
-        ee_pose_base = robot_obs["ee_pose"]["base_frame"]
+        ee_pose_base = robot_obs["ee_pose"][self.arm_name]["base_frame"]
         ee_position, ee_rot6d = ee_pose_base["position"], ee_pose_base["rot6d"]
-        normalized_gripper_width = self.__map_joint_position_to_normalized_width(robot_obs["gripper_drive_state"]["position"][0])
+        arm_joint_positions = robot_obs["joint_positions"][:7]  # a drive position was temporarily appended, so the last joint value read here is the drive joint
+        drive_joint_positions = robot_obs["joint_positions"][-1]
+        normalized_gripper_width = self.__map_joint_position_to_normalized_width(drive_joint_positions)
         Log.debug(f"input normalized_gripper_width state: {round(normalized_gripper_width, 2)}")
-        state = np.concatenate([ee_position, ee_rot6d, np.array([normalized_gripper_width])])
+        state = np.concatenate([ee_position, ee_rot6d, np.array([normalized_gripper_width]), [0]*10, np.array(arm_joint_positions)])
         rgb_data = {}
         for sensor_name in self.sensor_names:
             sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
@@ -137,6 +148,7 @@ class StarvlaPolicy(Policy):
             data=payload,
             headers={"Content-Type": "application/octet-stream"}
         )
+        self.test_obs = observation["state"]  # TODO
         self._handle_server_error(response)
         result = pickle.loads(response.content)
         max_trunk_size = len(result["ee_delta_position_chunks"])
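
After this change the state vector sent to the server is 27-dimensional: xyz (3) + rot6d (6) + normalized gripper width (1) + 10 reserved zeros + the 7 arm joint positions; the server then pads it to 100 dims with pad_to_dim. A sketch of the layout with placeholder values:

import numpy as np

ee_position = np.zeros(3)   # base-frame xyz
ee_rot6d = np.zeros(6)      # 6D rotation representation
gripper = np.array([0.5])   # normalized width in [0, 1]
arm_joints = np.zeros(7)    # panda_joint1 .. panda_joint7

state = np.concatenate([ee_position, ee_rot6d, gripper, [0] * 10, arm_joints])
assert state.shape == (27,)  # padded to (100,) on the server side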
@@ -176,6 +188,8 @@ class StarvlaPolicy(Policy):
     def postprocess_action(self, action: dict) -> BenchmarkAction:
         benchmark_action = BenchmarkAction()
+        Log.debug(f"observation: {self.test_obs}")
+        # import ipdb;ipdb.set_trace()
         # get base frame end-effector pose
         delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action["ee_delta_rot6d_chunks"][self.current_chunk_id])
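
The rot6d fields are the continuous 6D rotation representation (the first two columns of a rotation matrix); Pose presumably rebuilds the full matrix internally. The standard Gram-Schmidt reconstruction, shown for reference only (not part of this commit):

import numpy as np

def rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray:
    a1, a2 = rot6d[:3], rot6d[3:6]
    b1 = a1 / np.linalg.norm(a1)        # first column: normalize
    a2p = a2 - np.dot(b1, a2) * b1      # remove the component along b1
    b2 = a2p / np.linalg.norm(a2p)      # second column
    b3 = np.cross(b1, b2)               # third column: right-handed cross product
    return np.stack([b1, b2, b3], axis=-1)  # (3, 3) rotation matrix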
@@ -221,7 +235,7 @@ class StarvlaPolicy(Policy):
             VisualizeController.create_pose_visualization(
                 robot_base_world * pose_state_base,
                 name=f"{self.robot_name}/starvla_state_ee",
-                simulator=SimulatorType.ISAACSIM,
+                simulator=SimulatorType.ISAACLAB,
                 pose_type=PoseVisualType.COORDINATE,
                 extra_params={"axis_length": 0.08, "thickness": 0.006},
             ).unwrap()
@@ -229,7 +243,7 @@ class StarvlaPolicy(Policy):
             VisualizeController.create_pose_visualization(
                 robot_base_world * pose_action_base,
                 name=f"{self.robot_name}/starvla_action_ee",
-                simulator=SimulatorType.ISAACSIM,
+                simulator=SimulatorType.ISAACLAB,
                 pose_type=PoseVisualType.COORDINATE,
                 extra_params={"axis_length": 0.1, "thickness": 0.006},
             ).unwrap()
@@ -239,5 +253,5 @@ class StarvlaPolicy(Policy):
             return
         for target_name in self.visualize_bounding_box_targets:
             VisualizeController.visualize_target_bounding_box(
-                target_name, simulator=SimulatorType.ISAACSIM
+                target_name, simulator=SimulatorType.ISAACLAB
             ).unwrap()
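
End to end, the policy pickles an observation dict and POSTs it to the server's /policy/inference route, then unpickles the chunked deltas, exactly as the data=payload call above does. A minimal standalone client sketch with dummy inputs (image sizes and the local address are assumptions; the server resizes to 320x180 itself):

import pickle
import numpy as np
import requests

obs = {
    "rgb": {name: np.zeros((720, 1280, 3), dtype=np.uint8)
            for name in ("Left_Camera", "Right_Camera", "Hand_Camera")},
    "state": np.zeros(27, dtype=np.float32),  # padded to 100 dims server-side
    "prompt": "pick up the timer and put on the book",
}
resp = requests.post(
    "http://127.0.0.1:5000/policy/inference",
    data=pickle.dumps(obs, protocol=4),
    headers={"Content-Type": "application/octet-stream"},
)
result = pickle.loads(resp.content)
print(len(result["ee_delta_position_chunks"]))  # one entry per action chunk, e.g. 16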

task.yaml (new file, 5312 lines)

File diff suppressed because it is too large.