update
This commit is contained in:
193
benchmark.yaml
193
benchmark.yaml
@@ -1,77 +1,80 @@
|
||||
general:
|
||||
scan_project: true
|
||||
root_paths:
|
||||
asset: /home/ubuntu/projects/gen_data/data
|
||||
checkpoints: /home/ubuntu/data/models
|
||||
output: /home/ubuntu/output
|
||||
asset: /home/ubuntu/xionghao/sim_hofee/sim_hofee/assets
|
||||
checkpoints: /home/ubuntu/xionghao/starVLA-starVLA/playground/Checkpoints
|
||||
output: /home/ubuntu/xionghao/sim_hofee
|
||||
|
||||
simulation:
|
||||
stereotype: isaaclab
|
||||
intiailize_steps: 300
|
||||
launch_config:
|
||||
device: cuda
|
||||
enable_cameras: true
|
||||
headless: false
|
||||
livestream: 0
|
||||
|
||||
|
||||
scene:
|
||||
name: kujiale_multispace
|
||||
name: 827313_home
|
||||
base_config:
|
||||
stereotype: usd
|
||||
name: _827313_home_workspace_00
|
||||
name: _827313_home_workspace_01
|
||||
source: local
|
||||
asset_path: asset://scenes/kujiale_multispace/827313_home/workspace_00.usd
|
||||
asset_path: asset://scenes/kujiale_multispace/827313_home/workspace_01.usd
|
||||
object_cfg_dict:
|
||||
omni6DPose_can_016:
|
||||
name: omni6DPose_can_016
|
||||
omni6DPose_timer_017:
|
||||
name: omni6DPose_timer_017
|
||||
stereotype: rigid
|
||||
source: local
|
||||
asset_path: asset://objects/omni6DPose/can/omni6DPose_can_016/Aligned.usd
|
||||
asset_path: asset://objects/omni6DPose/timer/omni6DPose_timer_017/Aligned.usd
|
||||
scale:
|
||||
- 0.001
|
||||
- 0.001
|
||||
- 0.001
|
||||
position:
|
||||
- 0.2
|
||||
- -4.15243
|
||||
- 0.5
|
||||
- 0.552364
|
||||
- -4.0582599999999995
|
||||
- 0.524713118
|
||||
quaternion:
|
||||
- -0.304408012043137
|
||||
- -0.304408012043137
|
||||
- 0.638228612805745
|
||||
- 0.6382286128057448
|
||||
- 0.166210542394157
|
||||
- 0.166210542394157
|
||||
- 0.6872947370648492
|
||||
- 0.6872947370648491
|
||||
axis_y_up: true
|
||||
omni6DPose_plug_001:
|
||||
name: omni6DPose_plug_001
|
||||
omni6DPose_book_031:
|
||||
name: omni6DPose_book_031
|
||||
stereotype: rigid
|
||||
source: local
|
||||
asset_path: asset://objects/omni6DPose/plug/omni6DPose_plug_001/Aligned.usd
|
||||
asset_path: asset://objects/omni6DPose/book/omni6DPose_book_031/Aligned.usd
|
||||
scale:
|
||||
- 0.001
|
||||
- 0.001
|
||||
- 0.001
|
||||
position:
|
||||
- 0.219859
|
||||
- -3.852430000000001
|
||||
- 0.510259093
|
||||
- 0.6623640000000001
|
||||
- -3.7882599999999997
|
||||
- 0.5101601435
|
||||
quaternion:
|
||||
- 0.7070997233636068
|
||||
- 0.707099723363607
|
||||
- 0.003159306745230588
|
||||
- 0.003159306745230588
|
||||
- 0.7063055546421202
|
||||
- 0.7063055546421203
|
||||
- -0.03365209475927027
|
||||
- -0.033652094759270265
|
||||
axis_y_up: true
|
||||
robot_cfg_dict:
|
||||
Franka:
|
||||
name: Franka
|
||||
Franka_Robotiq_2f85:
|
||||
name: Franka_Robotiq_2f85
|
||||
asset_path: asset://Franka/franka_robotiq_2f85_zedmini.usd
|
||||
position:
|
||||
- -0.3779152929316859
|
||||
- -3.943187336951741
|
||||
- 0.3130090550954169
|
||||
rotation:
|
||||
- 0.9987576855319703
|
||||
- 1.082364
|
||||
- -3.92826
|
||||
- 0.47629299999999997
|
||||
rotation:
|
||||
- 7.549799991308018e-08
|
||||
- 0.0
|
||||
- 0.0
|
||||
- 0.04983056883903615
|
||||
stereotype: single_gripper_arm_robot
|
||||
- 0.9999999999999973
|
||||
stereotype: modular_robot
|
||||
source: local
|
||||
ee_link_name: panda_link8
|
||||
ik_joint_names:
|
||||
@@ -83,18 +86,43 @@ scene:
|
||||
- panda_joint6
|
||||
- panda_joint7
|
||||
init_joint_position:
|
||||
# panda_joint1: 0.18641542
|
||||
# panda_joint2: 0.47660449
|
||||
# panda_joint3: -0.03320411
|
||||
# panda_joint4: -2.27693725
|
||||
# panda_joint5: 0.98161776
|
||||
# panda_joint6: 2.20247197
|
||||
# panda_joint7: 0.71794897
|
||||
panda_joint2: -0.1633
|
||||
panda_joint4: -1.07
|
||||
panda_joint6: 0.8933
|
||||
panda_joint7: 0.785
|
||||
arm_actuator_name: franka_arm
|
||||
gripper_actuator_name: robotiq_2f_85
|
||||
use_planner: true
|
||||
planner_cfg:
|
||||
stereotype: curobo
|
||||
lazy_init: true
|
||||
robot_config_file: asset://curobo/franka_robotiq_2f85/franka_robotiq_2f85.yml
|
||||
world_config_source: stage
|
||||
arm_modules:
|
||||
main_arm:
|
||||
arm_actuator_name: franka_arm
|
||||
ee_link_name: panda_link8
|
||||
ee_type: gripper
|
||||
ee_actuator_name: robotiq_gripper
|
||||
actuator_cfg_dict:
|
||||
franka_arm:
|
||||
stereotype: arm
|
||||
joint_names_expr: [panda_joint1, panda_joint2, panda_joint3, panda_joint4, panda_joint5, panda_joint6, panda_joint7]
|
||||
stiffness: 3000.0
|
||||
damping: 800.0
|
||||
robotiq_gripper:
|
||||
stereotype: gripper
|
||||
joint_names_expr: [robotiq_85_left_knuckle_joint]
|
||||
stiffness: 10000
|
||||
damping: 500.0
|
||||
close_control_type: velocity
|
||||
open_control_type: position
|
||||
drive_joints:
|
||||
robotiq_85_left_knuckle_joint:
|
||||
close_velocity: 5.0
|
||||
open_velocity: -5.0
|
||||
close_position: 0.8
|
||||
open_position: 0.0
|
||||
use_planner: false
|
||||
sensor_cfg_dict:
|
||||
Hand_Camera:
|
||||
name: Hand_Camera
|
||||
@@ -108,23 +136,23 @@ scene:
|
||||
camera_model: pinhole
|
||||
fix_camera: true
|
||||
focal_length: 2.8
|
||||
horizontal_aperture: 4.890881131191918
|
||||
horizontal_aperture: 4.893416860031241
|
||||
vertical_aperture: 2.7608816125932627
|
||||
convention: opengl
|
||||
attach_to:
|
||||
target_name: Franka
|
||||
target_name: Franka_Robotiq_2f85
|
||||
is_articulation_part: true
|
||||
articulation_part_name: panda_link8
|
||||
create_fixed_joint: true
|
||||
local_position:
|
||||
- -0.07176474936469446
|
||||
- 0.02890958100382394
|
||||
- 0.01978286477078585
|
||||
- -0.07128738160694643
|
||||
- 0.03551506300731732
|
||||
- 0.018927748370281355
|
||||
local_rotation:
|
||||
- 0.12352531576657892
|
||||
- 0.7000519980574638
|
||||
- -0.6933396337721066
|
||||
- -0.11810524383481495
|
||||
- -0.12117023430710862
|
||||
- -0.6862313269668
|
||||
- 0.7070213671685396
|
||||
- 0.12052023305019997
|
||||
Left_Camera:
|
||||
name: Left_Camera
|
||||
stereotype: camera
|
||||
@@ -137,20 +165,20 @@ scene:
|
||||
camera_model: pinhole
|
||||
fix_camera: false
|
||||
focal_length: 2.1
|
||||
horizontal_aperture: 5.019302546405283
|
||||
horizontal_aperture: 5.030789363390793
|
||||
vertical_aperture: 2.833796298140747
|
||||
convention: opengl
|
||||
attach_to:
|
||||
target_name: Franka
|
||||
target_name: Franka_Robotiq_2f85
|
||||
local_position:
|
||||
- -0.07383269512283744
|
||||
- -0.4566116797983716
|
||||
- 0.5664518136443555
|
||||
- 0.31702696813014064
|
||||
- -0.3844238699868664
|
||||
- 0.6551552990137672
|
||||
local_rotation:
|
||||
- 0.7669289562134721
|
||||
- 0.4437472984144337
|
||||
- -0.232167881174614
|
||||
- -0.4012560108236337
|
||||
- 0.8742457685173938
|
||||
- 0.38378563025938384
|
||||
- -0.11951449178007277
|
||||
- -0.27224843891267797
|
||||
Right_Camera:
|
||||
name: Right_Camera
|
||||
stereotype: camera
|
||||
@@ -163,20 +191,20 @@ scene:
|
||||
camera_model: pinhole
|
||||
fix_camera: false
|
||||
focal_length: 2.1
|
||||
horizontal_aperture: 5.053278483887542
|
||||
horizontal_aperture: 5.050364265142387
|
||||
vertical_aperture: 2.833796298140747
|
||||
convention: opengl
|
||||
attach_to:
|
||||
target_name: Franka
|
||||
target_name: Franka_Robotiq_2f85
|
||||
local_position:
|
||||
- 0.4501524009531093
|
||||
- 0.7206899873248545
|
||||
- 0.27525652672526973
|
||||
- 0.21844487914880717
|
||||
- 0.20172329179193413
|
||||
- 0.30108042236545296
|
||||
local_rotation:
|
||||
- 0.007056820167439314
|
||||
- 0.007107689842383092
|
||||
- 0.709606075042274
|
||||
- 0.7045274304789896
|
||||
- -0.5316249212230874
|
||||
- -0.38697158527836417
|
||||
- 0.44338617110944967
|
||||
- 0.6091277686910994
|
||||
|
||||
extension:
|
||||
extension_cfg_dict:
|
||||
@@ -185,9 +213,10 @@ extension:
|
||||
stereotype: data_collect
|
||||
observer_cfgs:
|
||||
- stereotype: robot_observer
|
||||
name: Franka
|
||||
name: Franka_Robotiq_2f85
|
||||
target_joint_names: [panda_joint1, panda_joint2, panda_joint3, panda_joint4, panda_joint5, panda_joint6, panda_joint7, robotiq_85_left_knuckle_joint]
|
||||
observe_ee_pose: true
|
||||
observe_gripper_drive_state: true
|
||||
observe_ee_state: true
|
||||
observe_joint_position: true
|
||||
observe_joint_velocity: true
|
||||
observe_joint_positions: true
|
||||
@@ -210,22 +239,23 @@ extension:
|
||||
stereotype: benchmark
|
||||
data_collector_name: benchmark_data_collect
|
||||
action_frequency: 15.0
|
||||
timeout_per_episode: 30
|
||||
timeout_per_episode: 300
|
||||
goals:
|
||||
- name: cola on top of book
|
||||
description: check if the cola bottle is on the book
|
||||
stereotype: on_top
|
||||
object_A_name: omni6DPose_plug_001
|
||||
object_B_name: omni6DPose_can_016
|
||||
object_A_name: omni6DPose_book_031
|
||||
object_B_name: omni6DPose_timer_017
|
||||
policy:
|
||||
stereotype: starvla
|
||||
robot_name: Franka
|
||||
robot_name: Franka_Robotiq_2f85
|
||||
arm_name: main_arm
|
||||
sensor_names: [Hand_Camera, Left_Camera, Right_Camera]
|
||||
prompt: pick up the can
|
||||
prompt: pick up the timer and put on the book
|
||||
run_trunk_size: 16
|
||||
gripper_width_mapper_file: ./gripper_width_robotiq_2f85_fixed.json
|
||||
visualize_action_ee_pose: false
|
||||
visualize_state_ee_pose: false
|
||||
visualize_action_ee_pose: true
|
||||
visualize_state_ee_pose: true
|
||||
visualize_bounding_box_targets: [] # [omni6DPose_plug_001, omni6DPose_can_016] # enabling this makes the boxes visible to the policy and affects its inference results
|
||||
|
||||
recorder:
|
||||
@@ -237,7 +267,10 @@ extension:
|
||||
postprocess_list: ["hdf5", "video"]
|
||||
|
||||
policy_server:
|
||||
ckpt_path: checkpoints://0318_qwenpi_droid_pretrain_8node/checkpoints/steps_30000_pytorch_model.pt
|
||||
# ckpt_path: checkpoints://0324_qwenpi_droid_pretrain_8node/checkpoints/steps_30000_pytorch_model.pt
|
||||
# ckpt_path: checkpoints://0405_qwenpi_droid_norm_pretrain_8node/checkpoints/steps_60000_pytorch_model.pt
|
||||
# ckpt_path: checkpoints://0407_qwenpi_droid_postrain/final_model/pytorch_model.pt
|
||||
ckpt_path: checkpoints://0407_qwenpi_droid_from_scratch/final_model/pytorch_model.pt
|
||||
ckpt_source: local
|
||||
host: 0.0.0.0
|
||||
port: 5000
|
||||
|
||||
257
benchmark_can.yaml
Normal file
257
benchmark_can.yaml
Normal file
@@ -0,0 +1,257 @@
|
||||
general:
|
||||
scan_project: true
|
||||
root_paths:
|
||||
asset: /home/ubuntu/xionghao/sim_hofee/sim_hofee/assets
|
||||
checkpoints: /home/ubuntu/xionghao/starVLA-starVLA/playground/Checkpoints
|
||||
output: /home/ubuntu/xionghao/sim_hofee
|
||||
|
||||
simulation:
|
||||
launch_config:
|
||||
device: cuda
|
||||
enable_cameras: true
|
||||
headless: false
|
||||
livestream: 0
|
||||
|
||||
scene:
|
||||
name: kujiale_multispace
|
||||
base_config:
|
||||
stereotype: usd
|
||||
name: _827313_home_workspace_00
|
||||
source: local
|
||||
asset_path: asset://scenes/kujiale_multispace/827313_home/workspace_00.usd
|
||||
object_cfg_dict:
|
||||
omni6DPose_can_016:
|
||||
name: omni6DPose_can_016
|
||||
stereotype: rigid
|
||||
source: local
|
||||
asset_path: asset://objects/omni6DPose/can/omni6DPose_can_016/Aligned.usd
|
||||
scale:
|
||||
- 0.001
|
||||
- 0.001
|
||||
- 0.001
|
||||
position:
|
||||
- 0.2
|
||||
- -4.05243
|
||||
- 0.5
|
||||
quaternion:
|
||||
- -0.304408012043137
|
||||
- -0.304408012043137
|
||||
- 0.638228612805745
|
||||
- 0.6382286128057448
|
||||
axis_y_up: true
|
||||
omni6DPose_can_011:
|
||||
name: omni6DPose_can_011
|
||||
stereotype: rigid
|
||||
source: local
|
||||
asset_path: asset://objects/omni6DPose/can/omni6DPose_can_011/Aligned.usd
|
||||
scale:
|
||||
- 0.001
|
||||
- 0.001
|
||||
- 0.001
|
||||
position:
|
||||
- 0.219859
|
||||
- -3.82430000000001
|
||||
- 0.510259093
|
||||
quaternion:
|
||||
- 0.7070997233636068
|
||||
- 0.707099723363607
|
||||
- 0.003159306745230588
|
||||
- 0.003159306745230588
|
||||
axis_y_up: true
|
||||
robot_cfg_dict:
|
||||
Franka:
|
||||
name: Franka
|
||||
asset_path: asset://Franka/franka_robotiq_2f85_zedmini.usd
|
||||
position:
|
||||
- -0.3779152929316859
|
||||
- -3.943187336951741
|
||||
- 0.3130090550954169
|
||||
rotation:
|
||||
- 0.9987576855319703
|
||||
- 0.0
|
||||
- 0.0
|
||||
- 0.04983056883903615
|
||||
stereotype: single_gripper_arm_robot
|
||||
source: local
|
||||
ee_link_name: panda_link8
|
||||
ik_joint_names:
|
||||
- panda_joint1
|
||||
- panda_joint2
|
||||
- panda_joint3
|
||||
- panda_joint4
|
||||
- panda_joint5
|
||||
- panda_joint6
|
||||
- panda_joint7
|
||||
init_joint_position:
|
||||
panda_joint2: -0.1633
|
||||
panda_joint4: -1.07
|
||||
panda_joint6: 0.8933
|
||||
panda_joint7: 0.785
|
||||
# panda_joint1: 0.00201746
|
||||
# panda_joint2: -0.25797674
|
||||
# panda_joint3: -0.02563748
|
||||
# panda_joint4: -1.68997145
|
||||
# panda_joint5: -0.0735122
|
||||
# panda_joint6: 1.45765436
|
||||
# panda_joint7: 0.56564939
|
||||
arm_actuator_name: franka_arm
|
||||
gripper_actuator_name: robotiq_2f_85
|
||||
use_planner: true
|
||||
planner_cfg:
|
||||
stereotype: curobo
|
||||
lazy_init: true
|
||||
robot_config_file: asset://curobo/franka_robotiq_2f85/franka_robotiq_2f85.yml
|
||||
world_config_source: stage
|
||||
sensor_cfg_dict:
|
||||
Hand_Camera:
|
||||
name: Hand_Camera
|
||||
stereotype: camera
|
||||
data_types:
|
||||
- rgb
|
||||
- depth
|
||||
- normals
|
||||
width: 1280
|
||||
height: 720
|
||||
camera_model: pinhole
|
||||
fix_camera: true
|
||||
focal_length: 2.8
|
||||
horizontal_aperture: 4.890881131191918
|
||||
vertical_aperture: 2.7608816125932627
|
||||
convention: opengl
|
||||
attach_to:
|
||||
target_name: Franka
|
||||
is_articulation_part: true
|
||||
articulation_part_name: panda_link8
|
||||
create_fixed_joint: true
|
||||
local_position:
|
||||
- -0.07176474936469446
|
||||
- 0.02890958100382394
|
||||
- 0.01978286477078585
|
||||
local_rotation:
|
||||
- 0.12352531576657892
|
||||
- 0.7000519980574638
|
||||
- -0.6933396337721066
|
||||
- -0.11810524383481495
|
||||
Left_Camera:
|
||||
name: Left_Camera
|
||||
stereotype: camera
|
||||
data_types:
|
||||
- rgb
|
||||
- depth
|
||||
- normals
|
||||
width: 1280
|
||||
height: 720
|
||||
camera_model: pinhole
|
||||
fix_camera: false
|
||||
focal_length: 2.1
|
||||
horizontal_aperture: 5.019302546405283
|
||||
vertical_aperture: 2.833796298140747
|
||||
convention: opengl
|
||||
attach_to:
|
||||
target_name: Franka
|
||||
local_position:
|
||||
- -0.07383269512283744
|
||||
- -0.4566116797983716
|
||||
- 0.5664518136443555
|
||||
local_rotation:
|
||||
- 0.7669289562134721
|
||||
- 0.4437472984144337
|
||||
- -0.232167881174614
|
||||
- -0.4012560108236337
|
||||
Right_Camera:
|
||||
name: Right_Camera
|
||||
stereotype: camera
|
||||
data_types:
|
||||
- rgb
|
||||
- depth
|
||||
- normals
|
||||
width: 1280
|
||||
height: 720
|
||||
camera_model: pinhole
|
||||
fix_camera: false
|
||||
focal_length: 2.1
|
||||
horizontal_aperture: 5.053278483887542
|
||||
vertical_aperture: 2.833796298140747
|
||||
convention: opengl
|
||||
attach_to:
|
||||
target_name: Franka
|
||||
local_position:
|
||||
- 0.4501524009531093
|
||||
- 0.7206899873248545
|
||||
- 0.27525652672526973
|
||||
local_rotation:
|
||||
- 0.007056820167439314
|
||||
- 0.007107689842383092
|
||||
- 0.709606075042274
|
||||
- 0.7045274304789896
|
||||
|
||||
extension:
|
||||
extension_cfg_dict:
|
||||
benchmark_data_collect:
|
||||
enable: true
|
||||
stereotype: data_collect
|
||||
observer_cfgs:
|
||||
- stereotype: robot_observer
|
||||
name: Franka
|
||||
target_joint_names: [panda_joint1, panda_joint2, panda_joint3, panda_joint4, panda_joint5, panda_joint6, panda_joint7]
|
||||
observe_ee_pose: true
|
||||
observe_gripper_drive_state: true
|
||||
observe_joint_position: true
|
||||
observe_joint_velocity: true
|
||||
observe_joint_positions: true
|
||||
observe_joint_velocities: true
|
||||
observe_joint_accelerations: true
|
||||
observe_joint_position_targets: true
|
||||
observe_joint_velocity_targets: true
|
||||
- stereotype: sensor_observer
|
||||
name: Hand_Camera
|
||||
observe_rgb: true
|
||||
- stereotype: sensor_observer
|
||||
name: Left_Camera
|
||||
observe_rgb: true
|
||||
- stereotype: sensor_observer
|
||||
name: Right_Camera
|
||||
observe_rgb: true
|
||||
|
||||
starvla_benchmark:
|
||||
enable: true
|
||||
stereotype: benchmark
|
||||
data_collector_name: benchmark_data_collect
|
||||
action_frequency: 15.0
|
||||
timeout_per_episode: 300
|
||||
goals:
|
||||
- name: cola on top of book
|
||||
description: check if the cola bottle is on the book
|
||||
stereotype: on_top
|
||||
object_A_name: omni6DPose_can_016
|
||||
object_B_name: omni6DPose_can_011
|
||||
policy:
|
||||
stereotype: starvla
|
||||
robot_name: Franka
|
||||
sensor_names: [Hand_Camera, Left_Camera, Right_Camera]
|
||||
prompt: pick up the red can on the table
|
||||
run_trunk_size: 16
|
||||
gripper_width_mapper_file: ./gripper_width_robotiq_2f85_fixed.json
|
||||
visualize_action_ee_pose: false
|
||||
visualize_state_ee_pose: false
|
||||
visualize_bounding_box_targets: [] # [omni6DPose_plug_001, omni6DPose_can_016] # enabling this makes the boxes visible to the policy and affects its inference results
|
||||
|
||||
recorder:
|
||||
enable: true # set to true to record the data
|
||||
stereotype: record
|
||||
data_collector_name: benchmark_data_collect
|
||||
record_fps: 30
|
||||
backend_root_path: output://benchmark_record
|
||||
postprocess_list: ["hdf5", "video"]
|
||||
|
||||
policy_server:
|
||||
# ckpt_path: checkpoints://0324_qwenpi_droid_pretrain_8node/checkpoints/steps_30000_pytorch_model.pt
|
||||
# ckpt_path: checkpoints://0401_qwenpi_droid_pretrain_8node/checkpoints/steps_20000_pytorch_model.p
|
||||
ckpt_path: checkpoints://0405_qwenpi_droid_norm_pretrain_8node/checkpoints/steps_60000_pytorch_model.pt
|
||||
# ckpt_path: checkpoints://0407_qwenpi_droid_postrain/checkpoints/steps_10000_pytorch_model.pt
|
||||
ckpt_source: local
|
||||
host: 0.0.0.0
|
||||
port: 5000
|
||||
use_bf16: true
|
||||
unnorm_key: oxe_bridge
|
||||
state_mode: ee_pose7
|
||||
@@ -12,6 +12,14 @@ import cv2
|
||||
from flask import Flask, request, Response
|
||||
from PIL import Image
|
||||
|
||||
def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
    """Pad *x* with `value` along *axis* until that axis reaches *target_dim*.

    Arrays already at or beyond the target size are returned unchanged.
    """
    deficit = target_dim - x.shape[axis]
    if deficit <= 0:
        return x
    widths = [(0, 0) for _ in range(x.ndim)]
    widths[axis] = (0, deficit)
    return np.pad(x, widths, constant_values=value)
|
||||
|
||||
class StarvlaInferenceServer:
|
||||
|
||||
@@ -19,7 +27,7 @@ class StarvlaInferenceServer:
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
|
||||
policy_server_cfg = cfg["policy_server"]
|
||||
root_paths = cfg["general"]["root_paths"]
|
||||
self.ckpt_source = policy_server_cfg["ckpt_source"]
|
||||
@@ -88,14 +96,15 @@ class StarvlaInferenceServer:
|
||||
img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
|
||||
|
||||
state_vec = obs["state"]
|
||||
|
||||
# import ipdb;ipdb.set_trace()
|
||||
# state_vec = pad_to_dim(np.array(state_vec), 100, axis=-1)
|
||||
return img_left, img_right, img_wrist, state_vec, obs["prompt"]
|
||||
|
||||
def inference(self, observation: dict) -> dict:
|
||||
|
||||
|
||||
img_left, img_right, img_wrist, state_vec, prompt = \
|
||||
self.parse_observation(observation)
|
||||
|
||||
print(f"{state_vec.shape}")
|
||||
vla_input = {
|
||||
"batch_images": [[img_left, img_right, img_wrist]],
|
||||
"instructions": [prompt],
|
||||
|
||||
257
starvla_inference_server_unnorm.py
Normal file
257
starvla_inference_server_unnorm.py
Normal file
@@ -0,0 +1,257 @@
|
||||
|
||||
import yaml
|
||||
import pickle
|
||||
import argparse
|
||||
import os
|
||||
import traceback
|
||||
from urllib.parse import urlparse
|
||||
import numpy as np
|
||||
import torch
|
||||
import cv2
|
||||
import json
|
||||
from flask import Flask, request, Response
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
    """Pad *x* with `value` along *axis* until that axis reaches *target_dim*.

    Arrays already at or beyond the target size are returned unchanged.
    """
    deficit = target_dim - x.shape[axis]
    if deficit <= 0:
        return x
    widths = [(0, 0) for _ in range(x.ndim)]
    widths[axis] = (0, deficit)
    return np.pad(x, widths, constant_values=value)
|
||||
|
||||
def normalize_eepose_state(raw_state, stats, start_idx, norm_len=3):
    """Normalize one slice of a raw state array into the [-1.0, 1.0] range.

    :param raw_state: raw state array, typically shaped (batch_size, state_dim)
    :param stats: statistics dict holding "q01" and "q99" percentile lists
    :param start_idx: first index of the slice to normalize
    :param norm_len: length of the slice, default 3 (e.g. xyz position)
    :return: a new array with the slice normalized; the input is not modified
    """
    sl = slice(start_idx, start_idx + norm_len)
    # float16 keeps the precision identical to the stats used at training time
    q01 = np.asarray(stats["q01"][sl], dtype=np.float16)
    q99 = np.asarray(stats["q99"][sl], dtype=np.float16)
    # guard the denominator against division by zero
    span = np.clip(q99 - q01, a_min=1e-5, a_max=None)
    # copy so the caller's array is never mutated
    out = raw_state.copy()
    # forward map: y = 2 * (x - q01) / span - 1
    scaled = 2.0 * (out[:, sl] - q01) / span - 1.0
    # clamp so out-of-range states never leave [-1, 1] before reaching the model
    out[:, sl] = np.clip(scaled, -1.0, 1.0)
    return out
|
||||
|
||||
def unnormalize_eepose_action(normalized_action, stats, start_idx, norm_len=3):
    """Map a normalized action slice from [-1.0, 1.0] back to raw units.

    Inverse of :func:`normalize_eepose_state` for the same slice.

    :param normalized_action: normalized action array, typically (batch_size, action_dim)
    :param stats: statistics dict holding "q01" and "q99" percentile lists
    :param start_idx: first index of the slice to un-normalize
    :param norm_len: length of the slice, default 3 (e.g. xyz position)
    :return: a new array with the slice un-normalized; the input is not modified
    """
    sl = slice(start_idx, start_idx + norm_len)
    # float16 matches the precision used when the actions were normalized
    q01 = np.asarray(stats["q01"][sl], dtype=np.float16)
    q99 = np.asarray(stats["q99"][sl], dtype=np.float16)
    # same clipping as the forward pass, to guard against division by zero
    span = np.clip(q99 - q01, a_min=1e-5, a_max=None)
    # copy so the caller's normalized_action is never mutated
    out = normalized_action.copy()
    # inverse map: x = (y + 1) / 2 * span + q01
    out[:, sl] = (out[:, sl] + 1.0) / 2.0 * span + q01
    return out
|
||||
|
||||
class StarvlaInferenceServer:
    """Flask server that wraps a StarVLA policy checkpoint for remote inference.

    Reads the ``policy_server`` section of a benchmark YAML config, resolves
    and loads the checkpoint, and serves pickled observations on
    ``POST /policy/inference`` (plus a ``GET /policy/health`` probe).
    """

    def __init__(self, config_path: str):
        """Read the benchmark YAML at *config_path*, load the model, build the Flask app."""
        with open(config_path, "r") as f:
            cfg = yaml.safe_load(f)
        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]
        self.ckpt_source = policy_server_cfg["ckpt_source"]
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        # stats used to un-normalize the eepose slice of predicted actions;
        # empty dict when the stats file is missing next to the checkpoint
        self.rel_eepose_stats = self.read_rel_eepose_stats()
        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)
        self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
        self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")

        print("Loading StarVLA model...")
        self.model = self.load_model()
        print("Model loaded.")

        self.app = Flask(__name__)
        self.register_routes()

    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        """Expand a ``scheme://rel/path`` checkpoint URL using *root_paths*.

        A plain filesystem path (no scheme) is returned unchanged.

        :raises KeyError: if the URL scheme has no entry in *root_paths*.
        """
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            return ckpt_url
        root = root_paths.get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find the checkpoint root path in root_paths: {root_paths}"
            )
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)

    def read_rel_eepose_stats(self):
        """Load ``action_delta_eepose_stats.json`` from two levels above the checkpoint.

        Returns an empty dict when the file does not exist. (The former
        ``stats_path is None`` check was dead code — the path expression can
        never be None — and has been removed.)
        """
        stats_path = Path(self.ckpt_path).parents[1] / "action_delta_eepose_stats.json"
        if not stats_path.exists():
            return {}
        with open(stats_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def load_model(self):
        """Build the StarVLA framework from the checkpoint's stored config and load its weights.

        Also caches ``norm_stats`` and the per-``unnorm_key`` action stats on self.
        """
        from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
        from starVLA.model.framework.__init__ import build_framework

        model_config, norm_stats = read_mode_config(self.ckpt_path)

        cfg = dict_to_namespace(model_config)
        # weights come from state_dict below, not from a pretrained checkpoint path
        cfg.trainer.pretrained_checkpoint = None

        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats

        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

        if self.use_bf16:
            model = model.to(torch.bfloat16)

        model = model.to("cuda").eval()

        self.norm_stats = norm_stats
        self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)

        return model

    def parse_observation(self, obs, target_size=(320, 180)):
        """Convert a raw observation dict into model inputs.

        :param obs: dict with ``rgb`` (per-camera images keyed by camera name),
                    ``state`` and ``prompt`` entries
        :param target_size: (width, height) every camera image is resized to
        :return: (img_left, img_right, img_wrist, state_vec, prompt)
        """
        left_rgb, right_rgb, wrist_rgb = obs["rgb"]["Left_Camera"], obs["rgb"]["Right_Camera"], obs["rgb"]["Hand_Camera"]

        img_left = Image.fromarray(cv2.resize(left_rgb, target_size))
        img_right = Image.fromarray(cv2.resize(right_rgb, target_size))
        img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))

        state_vec = obs["state"]
        # pad the state to the fixed 100-dim input the model expects
        state_vec = pad_to_dim(np.array(state_vec), 100, axis=-1)
        return img_left, img_right, img_wrist, state_vec, obs["prompt"]

    def inference(self, observation: dict) -> dict:
        """Run one forward pass and return delta-eepose action chunks.

        :return: dict with ``ee_delta_position_chunks`` (dims 0-2),
                 ``ee_delta_rot6d_chunks`` (dims 3-8) and
                 ``gripper_width_chunks`` (dim 9), each a list over the chunk axis.
        """
        img_left, img_right, img_wrist, state_vec, prompt = \
            self.parse_observation(observation)
        print(f"{state_vec.shape}")
        vla_input = {
            "batch_images": [[img_left, img_right, img_wrist]],
            "instructions": [prompt],
            "state": [state_vec]
        }
        with torch.no_grad():
            output = self.model.predict_action(**vla_input)

        actions = output.get("normalized_actions")

        if isinstance(actions, torch.Tensor):
            actions = actions.cpu().numpy()

        if actions.ndim == 3:
            actions = actions[0]  # drop the batch axis -> (chunks, action_dim)

        # Un-normalize only the position slice (dims 0-2); rot6d and gripper
        # width are consumed as-is. Fixed: skip when the stats file was not
        # found, otherwise unnormalize_eepose_action raised KeyError on the
        # empty stats dict.
        if self.rel_eepose_stats:
            actions[:, :3] = unnormalize_eepose_action(actions[:, :3], self.rel_eepose_stats, start_idx=0, norm_len=3)
        return {"ee_delta_position_chunks": actions[:, :3].tolist(),
                "ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
                "gripper_width_chunks": actions[:, 9:10].tolist()}

    def register_routes(self):
        """Attach the inference and health endpoints to the Flask app."""

        @self.app.route("/policy/inference", methods=["POST"])
        def policy():
            try:
                # NOTE(security): pickle.loads on request data executes arbitrary
                # code if the endpoint is reachable by untrusted clients — only
                # expose this server on a trusted network.
                data = pickle.loads(request.data)
                result = self.inference(data)
                body = pickle.dumps(result, protocol=4)
                return Response(body, mimetype="application/octet-stream")
            except Exception as e:
                # return the failure to the client instead of a bare 500 page
                err_obj = {
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                }
                body = pickle.dumps(err_obj, protocol=4)
                return Response(
                    body,
                    mimetype="application/octet-stream",
                    status=500,
                )

        @self.app.route("/policy/health", methods=["GET"])
        def health():
            if self.model is None:
                return Response("Failed to load model", mimetype="application/json", status=500)
            return Response("OK", mimetype="application/json")

    def run(self):
        """Start the Flask app (blocking)."""
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")

        self.app.run(
            host=self.host,
            port=self.port,
            threaded=True
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="StarVLA inference server")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default="./benchmark.yaml",
|
||||
help="Path to benchmark.yaml",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config_path = args.config
|
||||
server = StarvlaInferenceServer(config_path)
|
||||
server.run()
|
||||
@@ -8,8 +8,9 @@ from joysim.annotations.config_class import configclass, field
|
||||
from joysim.annotations.stereotype import stereotype
|
||||
from joysim.controllers.spawnable_controller import SpawnableController
|
||||
from joysim.controllers.visualize_controller import VisualizeController
|
||||
from joysim.unisim.robots.models.modular_robot import ModularRobot
|
||||
from joysim.utils.namespace import PoseVisualType, SimulatorType
|
||||
from joysim.unisim.robots.configs.actuator_configs.grippers import GripperDriveJointConfig
|
||||
from joysim.unisim.robots.actuator_configs.grippers import GripperDriveJointConfig
|
||||
from joysim.extensions.benchmark.action import RobotAction
|
||||
from joysim.extensions.benchmark.benchmark import (
|
||||
BenchmarkAction,
|
||||
@@ -25,6 +26,8 @@ from joysim.utils.pose import Pose
|
||||
class StarvlaPolicyConfig(PolicyConfig):
|
||||
|
||||
robot_name: str = field(default="None", required=True, comment="The name of the robot")
|
||||
arm_name: str = field(default="main_arm", required=True, comment="The name of the arm module to control")
|
||||
drive_name: str = field(default="robotiq_85_left_knuckle_joint", required=True, comment="The name of the drive module to control")
|
||||
gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
|
||||
visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose")
|
||||
visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose")
|
||||
@@ -64,6 +67,8 @@ class StarvlaPolicy(Policy):
|
||||
super().__init__(config)
|
||||
|
||||
self.robot_name = config.robot_name
|
||||
self.arm_name = config.arm_name
|
||||
self.drive_name = config.drive_name
|
||||
self.sensor_names = config.sensor_names
|
||||
self.server_url = config.server_url
|
||||
self.prompt = config.prompt
|
||||
@@ -79,11 +84,15 @@ class StarvlaPolicy(Policy):
|
||||
self.current_chunk_id = 0
|
||||
self.current_chunk_result = None
|
||||
self.run_trunk_size = self.config.run_trunk_size
|
||||
self.drive_joints: dict[str, GripperDriveJointConfig] = SpawnableController.control_robot(self.robot_name, "get_gripper_drive_joints").unwrap()
|
||||
self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap()
|
||||
self.drive_joints: dict[str, GripperDriveJointConfig] = self.robot.get_arm(self.arm_name).get_ee().get_drive_joints()
|
||||
self.robot_drive_name = list(self.drive_joints.keys())[0]
|
||||
|
||||
for joint_name, joint_config in self.drive_joints.items():
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 50}).unwrap()
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 5000}).unwrap()
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [self.robot_drive_name], "effort_limit": 5000}).unwrap()
|
||||
self.max_width = float("-inf")
|
||||
self.min_width = float("inf")
|
||||
for entry in self.gripper_width_mapper:
|
||||
@@ -113,11 +122,13 @@ class StarvlaPolicy(Policy):
|
||||
|
||||
def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
|
||||
robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
|
||||
ee_pose_base = robot_obs["ee_pose"]["base_frame"]
|
||||
ee_pose_base = robot_obs["ee_pose"][self.arm_name]["base_frame"]
|
||||
ee_position, ee_rot6d = ee_pose_base["position"],ee_pose_base["rot6d"]
|
||||
normalized_gripper_width = self.__map_joint_position_to_normalized_width(robot_obs["gripper_drive_state"]["position"][0])
|
||||
arm_joint_positions = robot_obs["joint_positions"][:7] # 临时多加了一个drive的位置,现在读的最后一个joint值是drive
|
||||
drive_joint_positions = robot_obs["joint_positions"][-1]
|
||||
normalized_gripper_width = self.__map_joint_position_to_normalized_width(drive_joint_positions)
|
||||
Log.debug(f"input normalized_gripper_width state: {round(normalized_gripper_width, 2)}")
|
||||
state = np.concatenate([ee_position,ee_rot6d,np.array([normalized_gripper_width])])
|
||||
state = np.concatenate([ee_position,ee_rot6d,np.array([normalized_gripper_width]), [0]*10, np.array(arm_joint_positions)])
|
||||
rgb_data = {}
|
||||
for sensor_name in self.sensor_names:
|
||||
sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
|
||||
@@ -137,6 +148,7 @@ class StarvlaPolicy(Policy):
|
||||
data=payload,
|
||||
headers={"Content-Type": "application/octet-stream"}
|
||||
)
|
||||
self.test_obs = observation["state"] #TODO
|
||||
self._handle_server_error(response)
|
||||
result = pickle.loads(response.content)
|
||||
max_trunk_size = len(result["ee_delta_position_chunks"])
|
||||
@@ -176,7 +188,9 @@ class StarvlaPolicy(Policy):
|
||||
|
||||
def postprocess_action(self, action: dict) -> BenchmarkAction:
|
||||
benchmark_action = BenchmarkAction()
|
||||
|
||||
Log.debug(f"observation: {self.test_obs}")
|
||||
# import ipdb;ipdb.set_trace()
|
||||
|
||||
# get base frame end-effector pose
|
||||
delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action["ee_delta_rot6d_chunks"][self.current_chunk_id])
|
||||
curr_state_ee_pose = Pose(position=self.current_ee_position_state, rot6d=self.current_ee_rot6d_state)
|
||||
@@ -221,7 +235,7 @@ class StarvlaPolicy(Policy):
|
||||
VisualizeController.create_pose_visualization(
|
||||
robot_base_world * pose_state_base,
|
||||
name=f"{self.robot_name}/starvla_state_ee",
|
||||
simulator=SimulatorType.ISAACSIM,
|
||||
simulator=SimulatorType.ISAACLAB,
|
||||
pose_type=PoseVisualType.COORDINATE,
|
||||
extra_params={"axis_length": 0.08, "thickness": 0.006},
|
||||
).unwrap()
|
||||
@@ -229,7 +243,7 @@ class StarvlaPolicy(Policy):
|
||||
VisualizeController.create_pose_visualization(
|
||||
robot_base_world * pose_action_base,
|
||||
name=f"{self.robot_name}/starvla_action_ee",
|
||||
simulator=SimulatorType.ISAACSIM,
|
||||
simulator=SimulatorType.ISAACLAB,
|
||||
pose_type=PoseVisualType.COORDINATE,
|
||||
extra_params={"axis_length": 0.1, "thickness": 0.006},
|
||||
).unwrap()
|
||||
@@ -239,5 +253,5 @@ class StarvlaPolicy(Policy):
|
||||
return
|
||||
for target_name in self.visualize_bounding_box_targets:
|
||||
VisualizeController.visualize_target_bounding_box(
|
||||
target_name, simulator=SimulatorType.ISAACSIM
|
||||
target_name, simulator=SimulatorType.ISAACLAB
|
||||
).unwrap()
|
||||
Reference in New Issue
Block a user