update
This commit is contained in:
193
benchmark.yaml
193
benchmark.yaml
@@ -1,77 +1,80 @@
|
|||||||
general:
|
general:
|
||||||
scan_project: true
|
scan_project: true
|
||||||
root_paths:
|
root_paths:
|
||||||
asset: /home/ubuntu/projects/gen_data/data
|
asset: /home/ubuntu/xionghao/sim_hofee/sim_hofee/assets
|
||||||
checkpoints: /home/ubuntu/data/models
|
checkpoints: /home/ubuntu/xionghao/starVLA-starVLA/playground/Checkpoints
|
||||||
output: /home/ubuntu/output
|
output: /home/ubuntu/xionghao/sim_hofee
|
||||||
|
|
||||||
simulation:
|
simulation:
|
||||||
|
stereotype: isaaclab
|
||||||
|
intiailize_steps: 300
|
||||||
launch_config:
|
launch_config:
|
||||||
device: cuda
|
device: cuda
|
||||||
enable_cameras: true
|
enable_cameras: true
|
||||||
headless: false
|
headless: false
|
||||||
livestream: 0
|
livestream: 0
|
||||||
|
|
||||||
|
|
||||||
scene:
|
scene:
|
||||||
name: kujiale_multispace
|
name: 827313_home
|
||||||
base_config:
|
base_config:
|
||||||
stereotype: usd
|
stereotype: usd
|
||||||
name: _827313_home_workspace_00
|
name: _827313_home_workspace_01
|
||||||
source: local
|
source: local
|
||||||
asset_path: asset://scenes/kujiale_multispace/827313_home/workspace_00.usd
|
asset_path: asset://scenes/kujiale_multispace/827313_home/workspace_01.usd
|
||||||
object_cfg_dict:
|
object_cfg_dict:
|
||||||
omni6DPose_can_016:
|
omni6DPose_timer_017:
|
||||||
name: omni6DPose_can_016
|
name: omni6DPose_timer_017
|
||||||
stereotype: rigid
|
stereotype: rigid
|
||||||
source: local
|
source: local
|
||||||
asset_path: asset://objects/omni6DPose/can/omni6DPose_can_016/Aligned.usd
|
asset_path: asset://objects/omni6DPose/timer/omni6DPose_timer_017/Aligned.usd
|
||||||
scale:
|
scale:
|
||||||
- 0.001
|
- 0.001
|
||||||
- 0.001
|
- 0.001
|
||||||
- 0.001
|
- 0.001
|
||||||
position:
|
position:
|
||||||
- 0.2
|
- 0.552364
|
||||||
- -4.15243
|
- -4.0582599999999995
|
||||||
- 0.5
|
- 0.524713118
|
||||||
quaternion:
|
quaternion:
|
||||||
- -0.304408012043137
|
- 0.166210542394157
|
||||||
- -0.304408012043137
|
- 0.166210542394157
|
||||||
- 0.638228612805745
|
- 0.6872947370648492
|
||||||
- 0.6382286128057448
|
- 0.6872947370648491
|
||||||
axis_y_up: true
|
axis_y_up: true
|
||||||
omni6DPose_plug_001:
|
omni6DPose_book_031:
|
||||||
name: omni6DPose_plug_001
|
name: omni6DPose_book_031
|
||||||
stereotype: rigid
|
stereotype: rigid
|
||||||
source: local
|
source: local
|
||||||
asset_path: asset://objects/omni6DPose/plug/omni6DPose_plug_001/Aligned.usd
|
asset_path: asset://objects/omni6DPose/book/omni6DPose_book_031/Aligned.usd
|
||||||
scale:
|
scale:
|
||||||
- 0.001
|
- 0.001
|
||||||
- 0.001
|
- 0.001
|
||||||
- 0.001
|
- 0.001
|
||||||
position:
|
position:
|
||||||
- 0.219859
|
- 0.6623640000000001
|
||||||
- -3.852430000000001
|
- -3.7882599999999997
|
||||||
- 0.510259093
|
- 0.5101601435
|
||||||
quaternion:
|
quaternion:
|
||||||
- 0.7070997233636068
|
- 0.7063055546421202
|
||||||
- 0.707099723363607
|
- 0.7063055546421203
|
||||||
- 0.003159306745230588
|
- -0.03365209475927027
|
||||||
- 0.003159306745230588
|
- -0.033652094759270265
|
||||||
axis_y_up: true
|
axis_y_up: true
|
||||||
robot_cfg_dict:
|
robot_cfg_dict:
|
||||||
Franka:
|
Franka_Robotiq_2f85:
|
||||||
name: Franka
|
name: Franka_Robotiq_2f85
|
||||||
asset_path: asset://Franka/franka_robotiq_2f85_zedmini.usd
|
asset_path: asset://Franka/franka_robotiq_2f85_zedmini.usd
|
||||||
position:
|
position:
|
||||||
- -0.3779152929316859
|
- 1.082364
|
||||||
- -3.943187336951741
|
- -3.92826
|
||||||
- 0.3130090550954169
|
- 0.47629299999999997
|
||||||
rotation:
|
rotation:
|
||||||
- 0.9987576855319703
|
- 7.549799991308018e-08
|
||||||
- 0.0
|
- 0.0
|
||||||
- 0.0
|
- 0.0
|
||||||
- 0.04983056883903615
|
- 0.9999999999999973
|
||||||
stereotype: single_gripper_arm_robot
|
stereotype: modular_robot
|
||||||
source: local
|
source: local
|
||||||
ee_link_name: panda_link8
|
ee_link_name: panda_link8
|
||||||
ik_joint_names:
|
ik_joint_names:
|
||||||
@@ -83,18 +86,43 @@ scene:
|
|||||||
- panda_joint6
|
- panda_joint6
|
||||||
- panda_joint7
|
- panda_joint7
|
||||||
init_joint_position:
|
init_joint_position:
|
||||||
|
# panda_joint1: 0.18641542
|
||||||
|
# panda_joint2: 0.47660449
|
||||||
|
# panda_joint3: -0.03320411
|
||||||
|
# panda_joint4: -2.27693725
|
||||||
|
# panda_joint5: 0.98161776
|
||||||
|
# panda_joint6: 2.20247197
|
||||||
|
# panda_joint7: 0.71794897
|
||||||
panda_joint2: -0.1633
|
panda_joint2: -0.1633
|
||||||
panda_joint4: -1.07
|
panda_joint4: -1.07
|
||||||
panda_joint6: 0.8933
|
panda_joint6: 0.8933
|
||||||
panda_joint7: 0.785
|
panda_joint7: 0.785
|
||||||
arm_actuator_name: franka_arm
|
arm_modules:
|
||||||
gripper_actuator_name: robotiq_2f_85
|
main_arm:
|
||||||
use_planner: true
|
arm_actuator_name: franka_arm
|
||||||
planner_cfg:
|
ee_link_name: panda_link8
|
||||||
stereotype: curobo
|
ee_type: gripper
|
||||||
lazy_init: true
|
ee_actuator_name: robotiq_gripper
|
||||||
robot_config_file: asset://curobo/franka_robotiq_2f85/franka_robotiq_2f85.yml
|
actuator_cfg_dict:
|
||||||
world_config_source: stage
|
franka_arm:
|
||||||
|
stereotype: arm
|
||||||
|
joint_names_expr: [panda_joint1, panda_joint2, panda_joint3, panda_joint4, panda_joint5, panda_joint6, panda_joint7]
|
||||||
|
stiffness: 3000.0
|
||||||
|
damping: 800.0
|
||||||
|
robotiq_gripper:
|
||||||
|
stereotype: gripper
|
||||||
|
joint_names_expr: [robotiq_85_left_knuckle_joint]
|
||||||
|
stiffness: 10000
|
||||||
|
damping: 500.0
|
||||||
|
close_control_type: velocity
|
||||||
|
open_control_type: position
|
||||||
|
drive_joints:
|
||||||
|
robotiq_85_left_knuckle_joint:
|
||||||
|
close_velocity: 5.0
|
||||||
|
open_velocity: -5.0
|
||||||
|
close_position: 0.8
|
||||||
|
open_position: 0.0
|
||||||
|
use_planner: false
|
||||||
sensor_cfg_dict:
|
sensor_cfg_dict:
|
||||||
Hand_Camera:
|
Hand_Camera:
|
||||||
name: Hand_Camera
|
name: Hand_Camera
|
||||||
@@ -108,23 +136,23 @@ scene:
|
|||||||
camera_model: pinhole
|
camera_model: pinhole
|
||||||
fix_camera: true
|
fix_camera: true
|
||||||
focal_length: 2.8
|
focal_length: 2.8
|
||||||
horizontal_aperture: 4.890881131191918
|
horizontal_aperture: 4.893416860031241
|
||||||
vertical_aperture: 2.7608816125932627
|
vertical_aperture: 2.7608816125932627
|
||||||
convention: opengl
|
convention: opengl
|
||||||
attach_to:
|
attach_to:
|
||||||
target_name: Franka
|
target_name: Franka_Robotiq_2f85
|
||||||
is_articulation_part: true
|
is_articulation_part: true
|
||||||
articulation_part_name: panda_link8
|
articulation_part_name: panda_link8
|
||||||
create_fixed_joint: true
|
create_fixed_joint: true
|
||||||
local_position:
|
local_position:
|
||||||
- -0.07176474936469446
|
- -0.07128738160694643
|
||||||
- 0.02890958100382394
|
- 0.03551506300731732
|
||||||
- 0.01978286477078585
|
- 0.018927748370281355
|
||||||
local_rotation:
|
local_rotation:
|
||||||
- 0.12352531576657892
|
- -0.12117023430710862
|
||||||
- 0.7000519980574638
|
- -0.6862313269668
|
||||||
- -0.6933396337721066
|
- 0.7070213671685396
|
||||||
- -0.11810524383481495
|
- 0.12052023305019997
|
||||||
Left_Camera:
|
Left_Camera:
|
||||||
name: Left_Camera
|
name: Left_Camera
|
||||||
stereotype: camera
|
stereotype: camera
|
||||||
@@ -137,20 +165,20 @@ scene:
|
|||||||
camera_model: pinhole
|
camera_model: pinhole
|
||||||
fix_camera: false
|
fix_camera: false
|
||||||
focal_length: 2.1
|
focal_length: 2.1
|
||||||
horizontal_aperture: 5.019302546405283
|
horizontal_aperture: 5.030789363390793
|
||||||
vertical_aperture: 2.833796298140747
|
vertical_aperture: 2.833796298140747
|
||||||
convention: opengl
|
convention: opengl
|
||||||
attach_to:
|
attach_to:
|
||||||
target_name: Franka
|
target_name: Franka_Robotiq_2f85
|
||||||
local_position:
|
local_position:
|
||||||
- -0.07383269512283744
|
- 0.31702696813014064
|
||||||
- -0.4566116797983716
|
- -0.3844238699868664
|
||||||
- 0.5664518136443555
|
- 0.6551552990137672
|
||||||
local_rotation:
|
local_rotation:
|
||||||
- 0.7669289562134721
|
- 0.8742457685173938
|
||||||
- 0.4437472984144337
|
- 0.38378563025938384
|
||||||
- -0.232167881174614
|
- -0.11951449178007277
|
||||||
- -0.4012560108236337
|
- -0.27224843891267797
|
||||||
Right_Camera:
|
Right_Camera:
|
||||||
name: Right_Camera
|
name: Right_Camera
|
||||||
stereotype: camera
|
stereotype: camera
|
||||||
@@ -163,20 +191,20 @@ scene:
|
|||||||
camera_model: pinhole
|
camera_model: pinhole
|
||||||
fix_camera: false
|
fix_camera: false
|
||||||
focal_length: 2.1
|
focal_length: 2.1
|
||||||
horizontal_aperture: 5.053278483887542
|
horizontal_aperture: 5.050364265142387
|
||||||
vertical_aperture: 2.833796298140747
|
vertical_aperture: 2.833796298140747
|
||||||
convention: opengl
|
convention: opengl
|
||||||
attach_to:
|
attach_to:
|
||||||
target_name: Franka
|
target_name: Franka_Robotiq_2f85
|
||||||
local_position:
|
local_position:
|
||||||
- 0.4501524009531093
|
- 0.21844487914880717
|
||||||
- 0.7206899873248545
|
- 0.20172329179193413
|
||||||
- 0.27525652672526973
|
- 0.30108042236545296
|
||||||
local_rotation:
|
local_rotation:
|
||||||
- 0.007056820167439314
|
- -0.5316249212230874
|
||||||
- 0.007107689842383092
|
- -0.38697158527836417
|
||||||
- 0.709606075042274
|
- 0.44338617110944967
|
||||||
- 0.7045274304789896
|
- 0.6091277686910994
|
||||||
|
|
||||||
extension:
|
extension:
|
||||||
extension_cfg_dict:
|
extension_cfg_dict:
|
||||||
@@ -185,9 +213,10 @@ extension:
|
|||||||
stereotype: data_collect
|
stereotype: data_collect
|
||||||
observer_cfgs:
|
observer_cfgs:
|
||||||
- stereotype: robot_observer
|
- stereotype: robot_observer
|
||||||
name: Franka
|
name: Franka_Robotiq_2f85
|
||||||
|
target_joint_names: [panda_joint1, panda_joint2, panda_joint3, panda_joint4, panda_joint5, panda_joint6, panda_joint7, robotiq_85_left_knuckle_joint]
|
||||||
observe_ee_pose: true
|
observe_ee_pose: true
|
||||||
observe_gripper_drive_state: true
|
observe_ee_state: true
|
||||||
observe_joint_position: true
|
observe_joint_position: true
|
||||||
observe_joint_velocity: true
|
observe_joint_velocity: true
|
||||||
observe_joint_positions: true
|
observe_joint_positions: true
|
||||||
@@ -210,22 +239,23 @@ extension:
|
|||||||
stereotype: benchmark
|
stereotype: benchmark
|
||||||
data_collector_name: benchmark_data_collect
|
data_collector_name: benchmark_data_collect
|
||||||
action_frequency: 15.0
|
action_frequency: 15.0
|
||||||
timeout_per_episode: 30
|
timeout_per_episode: 300
|
||||||
goals:
|
goals:
|
||||||
- name: cola on top of book
|
- name: cola on top of book
|
||||||
description: check if the cola bottle is on the book
|
description: check if the cola bottle is on the book
|
||||||
stereotype: on_top
|
stereotype: on_top
|
||||||
object_A_name: omni6DPose_plug_001
|
object_A_name: omni6DPose_book_031
|
||||||
object_B_name: omni6DPose_can_016
|
object_B_name: omni6DPose_timer_017
|
||||||
policy:
|
policy:
|
||||||
stereotype: starvla
|
stereotype: starvla
|
||||||
robot_name: Franka
|
robot_name: Franka_Robotiq_2f85
|
||||||
|
arm_name: main_arm
|
||||||
sensor_names: [Hand_Camera, Left_Camera, Right_Camera]
|
sensor_names: [Hand_Camera, Left_Camera, Right_Camera]
|
||||||
prompt: pick up the can
|
prompt: pick up the timer and put on the book
|
||||||
run_trunk_size: 16
|
run_trunk_size: 16
|
||||||
gripper_width_mapper_file: ./gripper_width_robotiq_2f85_fixed.json
|
gripper_width_mapper_file: ./gripper_width_robotiq_2f85_fixed.json
|
||||||
visualize_action_ee_pose: false
|
visualize_action_ee_pose: true
|
||||||
visualize_state_ee_pose: false
|
visualize_state_ee_pose: true
|
||||||
visualize_bounding_box_targets: [] # [omni6DPose_plug_001, omni6DPose_can_016] # 打开会被policy看到,会影响policy的推理结果
|
visualize_bounding_box_targets: [] # [omni6DPose_plug_001, omni6DPose_can_016] # 打开会被policy看到,会影响policy的推理结果
|
||||||
|
|
||||||
recorder:
|
recorder:
|
||||||
@@ -237,7 +267,10 @@ extension:
|
|||||||
postprocess_list: ["hdf5", "video"]
|
postprocess_list: ["hdf5", "video"]
|
||||||
|
|
||||||
policy_server:
|
policy_server:
|
||||||
ckpt_path: checkpoints://0318_qwenpi_droid_pretrain_8node/checkpoints/steps_30000_pytorch_model.pt
|
# ckpt_path: checkpoints://0324_qwenpi_droid_pretrain_8node/checkpoints/steps_30000_pytorch_model.pt
|
||||||
|
# ckpt_path: checkpoints://0405_qwenpi_droid_norm_pretrain_8node/checkpoints/steps_60000_pytorch_model.pt
|
||||||
|
# ckpt_path: checkpoints://0407_qwenpi_droid_postrain/final_model/pytorch_model.pt
|
||||||
|
ckpt_path: checkpoints://0407_qwenpi_droid_from_scratch/final_model/pytorch_model.pt
|
||||||
ckpt_source: local
|
ckpt_source: local
|
||||||
host: 0.0.0.0
|
host: 0.0.0.0
|
||||||
port: 5000
|
port: 5000
|
||||||
|
|||||||
257
benchmark_can.yaml
Normal file
257
benchmark_can.yaml
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
general:
|
||||||
|
scan_project: true
|
||||||
|
root_paths:
|
||||||
|
asset: /home/ubuntu/xionghao/sim_hofee/sim_hofee/assets
|
||||||
|
checkpoints: /home/ubuntu/xionghao/starVLA-starVLA/playground/Checkpoints
|
||||||
|
output: /home/ubuntu/xionghao/sim_hofee
|
||||||
|
|
||||||
|
simulation:
|
||||||
|
launch_config:
|
||||||
|
device: cuda
|
||||||
|
enable_cameras: true
|
||||||
|
headless: false
|
||||||
|
livestream: 0
|
||||||
|
|
||||||
|
scene:
|
||||||
|
name: kujiale_multispace
|
||||||
|
base_config:
|
||||||
|
stereotype: usd
|
||||||
|
name: _827313_home_workspace_00
|
||||||
|
source: local
|
||||||
|
asset_path: asset://scenes/kujiale_multispace/827313_home/workspace_00.usd
|
||||||
|
object_cfg_dict:
|
||||||
|
omni6DPose_can_016:
|
||||||
|
name: omni6DPose_can_016
|
||||||
|
stereotype: rigid
|
||||||
|
source: local
|
||||||
|
asset_path: asset://objects/omni6DPose/can/omni6DPose_can_016/Aligned.usd
|
||||||
|
scale:
|
||||||
|
- 0.001
|
||||||
|
- 0.001
|
||||||
|
- 0.001
|
||||||
|
position:
|
||||||
|
- 0.2
|
||||||
|
- -4.05243
|
||||||
|
- 0.5
|
||||||
|
quaternion:
|
||||||
|
- -0.304408012043137
|
||||||
|
- -0.304408012043137
|
||||||
|
- 0.638228612805745
|
||||||
|
- 0.6382286128057448
|
||||||
|
axis_y_up: true
|
||||||
|
omni6DPose_can_011:
|
||||||
|
name: omni6DPose_can_011
|
||||||
|
stereotype: rigid
|
||||||
|
source: local
|
||||||
|
asset_path: asset://objects/omni6DPose/can/omni6DPose_can_011/Aligned.usd
|
||||||
|
scale:
|
||||||
|
- 0.001
|
||||||
|
- 0.001
|
||||||
|
- 0.001
|
||||||
|
position:
|
||||||
|
- 0.219859
|
||||||
|
- -3.82430000000001
|
||||||
|
- 0.510259093
|
||||||
|
quaternion:
|
||||||
|
- 0.7070997233636068
|
||||||
|
- 0.707099723363607
|
||||||
|
- 0.003159306745230588
|
||||||
|
- 0.003159306745230588
|
||||||
|
axis_y_up: true
|
||||||
|
robot_cfg_dict:
|
||||||
|
Franka:
|
||||||
|
name: Franka
|
||||||
|
asset_path: asset://Franka/franka_robotiq_2f85_zedmini.usd
|
||||||
|
position:
|
||||||
|
- -0.3779152929316859
|
||||||
|
- -3.943187336951741
|
||||||
|
- 0.3130090550954169
|
||||||
|
rotation:
|
||||||
|
- 0.9987576855319703
|
||||||
|
- 0.0
|
||||||
|
- 0.0
|
||||||
|
- 0.04983056883903615
|
||||||
|
stereotype: single_gripper_arm_robot
|
||||||
|
source: local
|
||||||
|
ee_link_name: panda_link8
|
||||||
|
ik_joint_names:
|
||||||
|
- panda_joint1
|
||||||
|
- panda_joint2
|
||||||
|
- panda_joint3
|
||||||
|
- panda_joint4
|
||||||
|
- panda_joint5
|
||||||
|
- panda_joint6
|
||||||
|
- panda_joint7
|
||||||
|
init_joint_position:
|
||||||
|
panda_joint2: -0.1633
|
||||||
|
panda_joint4: -1.07
|
||||||
|
panda_joint6: 0.8933
|
||||||
|
panda_joint7: 0.785
|
||||||
|
# panda_joint1: 0.00201746
|
||||||
|
# panda_joint2: -0.25797674
|
||||||
|
# panda_joint3: -0.02563748
|
||||||
|
# panda_joint4: -1.68997145
|
||||||
|
# panda_joint5: -0.0735122
|
||||||
|
# panda_joint6: 1.45765436
|
||||||
|
# panda_joint7: 0.56564939
|
||||||
|
arm_actuator_name: franka_arm
|
||||||
|
gripper_actuator_name: robotiq_2f_85
|
||||||
|
use_planner: true
|
||||||
|
planner_cfg:
|
||||||
|
stereotype: curobo
|
||||||
|
lazy_init: true
|
||||||
|
robot_config_file: asset://curobo/franka_robotiq_2f85/franka_robotiq_2f85.yml
|
||||||
|
world_config_source: stage
|
||||||
|
sensor_cfg_dict:
|
||||||
|
Hand_Camera:
|
||||||
|
name: Hand_Camera
|
||||||
|
stereotype: camera
|
||||||
|
data_types:
|
||||||
|
- rgb
|
||||||
|
- depth
|
||||||
|
- normals
|
||||||
|
width: 1280
|
||||||
|
height: 720
|
||||||
|
camera_model: pinhole
|
||||||
|
fix_camera: true
|
||||||
|
focal_length: 2.8
|
||||||
|
horizontal_aperture: 4.890881131191918
|
||||||
|
vertical_aperture: 2.7608816125932627
|
||||||
|
convention: opengl
|
||||||
|
attach_to:
|
||||||
|
target_name: Franka
|
||||||
|
is_articulation_part: true
|
||||||
|
articulation_part_name: panda_link8
|
||||||
|
create_fixed_joint: true
|
||||||
|
local_position:
|
||||||
|
- -0.07176474936469446
|
||||||
|
- 0.02890958100382394
|
||||||
|
- 0.01978286477078585
|
||||||
|
local_rotation:
|
||||||
|
- 0.12352531576657892
|
||||||
|
- 0.7000519980574638
|
||||||
|
- -0.6933396337721066
|
||||||
|
- -0.11810524383481495
|
||||||
|
Left_Camera:
|
||||||
|
name: Left_Camera
|
||||||
|
stereotype: camera
|
||||||
|
data_types:
|
||||||
|
- rgb
|
||||||
|
- depth
|
||||||
|
- normals
|
||||||
|
width: 1280
|
||||||
|
height: 720
|
||||||
|
camera_model: pinhole
|
||||||
|
fix_camera: false
|
||||||
|
focal_length: 2.1
|
||||||
|
horizontal_aperture: 5.019302546405283
|
||||||
|
vertical_aperture: 2.833796298140747
|
||||||
|
convention: opengl
|
||||||
|
attach_to:
|
||||||
|
target_name: Franka
|
||||||
|
local_position:
|
||||||
|
- -0.07383269512283744
|
||||||
|
- -0.4566116797983716
|
||||||
|
- 0.5664518136443555
|
||||||
|
local_rotation:
|
||||||
|
- 0.7669289562134721
|
||||||
|
- 0.4437472984144337
|
||||||
|
- -0.232167881174614
|
||||||
|
- -0.4012560108236337
|
||||||
|
Right_Camera:
|
||||||
|
name: Right_Camera
|
||||||
|
stereotype: camera
|
||||||
|
data_types:
|
||||||
|
- rgb
|
||||||
|
- depth
|
||||||
|
- normals
|
||||||
|
width: 1280
|
||||||
|
height: 720
|
||||||
|
camera_model: pinhole
|
||||||
|
fix_camera: false
|
||||||
|
focal_length: 2.1
|
||||||
|
horizontal_aperture: 5.053278483887542
|
||||||
|
vertical_aperture: 2.833796298140747
|
||||||
|
convention: opengl
|
||||||
|
attach_to:
|
||||||
|
target_name: Franka
|
||||||
|
local_position:
|
||||||
|
- 0.4501524009531093
|
||||||
|
- 0.7206899873248545
|
||||||
|
- 0.27525652672526973
|
||||||
|
local_rotation:
|
||||||
|
- 0.007056820167439314
|
||||||
|
- 0.007107689842383092
|
||||||
|
- 0.709606075042274
|
||||||
|
- 0.7045274304789896
|
||||||
|
|
||||||
|
extension:
|
||||||
|
extension_cfg_dict:
|
||||||
|
benchmark_data_collect:
|
||||||
|
enable: true
|
||||||
|
stereotype: data_collect
|
||||||
|
observer_cfgs:
|
||||||
|
- stereotype: robot_observer
|
||||||
|
name: Franka
|
||||||
|
target_joint_names: [panda_joint1, panda_joint2, panda_joint3, panda_joint4, panda_joint5, panda_joint6, panda_joint7]
|
||||||
|
observe_ee_pose: true
|
||||||
|
observe_gripper_drive_state: true
|
||||||
|
observe_joint_position: true
|
||||||
|
observe_joint_velocity: true
|
||||||
|
observe_joint_positions: true
|
||||||
|
observe_joint_velocities: true
|
||||||
|
observe_joint_accelerations: true
|
||||||
|
observe_joint_position_targets: true
|
||||||
|
observe_joint_velocity_targets: true
|
||||||
|
- stereotype: sensor_observer
|
||||||
|
name: Hand_Camera
|
||||||
|
observe_rgb: true
|
||||||
|
- stereotype: sensor_observer
|
||||||
|
name: Left_Camera
|
||||||
|
observe_rgb: true
|
||||||
|
- stereotype: sensor_observer
|
||||||
|
name: Right_Camera
|
||||||
|
observe_rgb: true
|
||||||
|
|
||||||
|
starvla_benchmark:
|
||||||
|
enable: true
|
||||||
|
stereotype: benchmark
|
||||||
|
data_collector_name: benchmark_data_collect
|
||||||
|
action_frequency: 15.0
|
||||||
|
timeout_per_episode: 300
|
||||||
|
goals:
|
||||||
|
- name: cola on top of book
|
||||||
|
description: check if the cola bottle is on the book
|
||||||
|
stereotype: on_top
|
||||||
|
object_A_name: omni6DPose_can_016
|
||||||
|
object_B_name: omni6DPose_can_011
|
||||||
|
policy:
|
||||||
|
stereotype: starvla
|
||||||
|
robot_name: Franka
|
||||||
|
sensor_names: [Hand_Camera, Left_Camera, Right_Camera]
|
||||||
|
prompt: pick up the red can on the table
|
||||||
|
run_trunk_size: 16
|
||||||
|
gripper_width_mapper_file: ./gripper_width_robotiq_2f85_fixed.json
|
||||||
|
visualize_action_ee_pose: false
|
||||||
|
visualize_state_ee_pose: false
|
||||||
|
visualize_bounding_box_targets: [] # [omni6DPose_plug_001, omni6DPose_can_016] # 打开会被policy看到,会影响policy的推理结果
|
||||||
|
|
||||||
|
recorder:
|
||||||
|
enable: true # set to true to record the data
|
||||||
|
stereotype: record
|
||||||
|
data_collector_name: benchmark_data_collect
|
||||||
|
record_fps: 30
|
||||||
|
backend_root_path: output://benchmark_record
|
||||||
|
postprocess_list: ["hdf5", "video"]
|
||||||
|
|
||||||
|
policy_server:
|
||||||
|
# ckpt_path: checkpoints://0324_qwenpi_droid_pretrain_8node/checkpoints/steps_30000_pytorch_model.pt
|
||||||
|
# ckpt_path: checkpoints://0401_qwenpi_droid_pretrain_8node/checkpoints/steps_20000_pytorch_model.p
|
||||||
|
ckpt_path: checkpoints://0405_qwenpi_droid_norm_pretrain_8node/checkpoints/steps_60000_pytorch_model.pt
|
||||||
|
# ckpt_path: checkpoints://0407_qwenpi_droid_postrain/checkpoints/steps_10000_pytorch_model.pt
|
||||||
|
ckpt_source: local
|
||||||
|
host: 0.0.0.0
|
||||||
|
port: 5000
|
||||||
|
use_bf16: true
|
||||||
|
unnorm_key: oxe_bridge
|
||||||
|
state_mode: ee_pose7
|
||||||
@@ -12,6 +12,14 @@ import cv2
|
|||||||
from flask import Flask, request, Response
|
from flask import Flask, request, Response
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
|
||||||
|
"""Pad an array to the target dimension with zeros along the specified axis."""
|
||||||
|
current_dim = x.shape[axis]
|
||||||
|
if current_dim < target_dim:
|
||||||
|
pad_width = [(0, 0)] * len(x.shape)
|
||||||
|
pad_width[axis] = (0, target_dim - current_dim)
|
||||||
|
return np.pad(x, pad_width, constant_values=value)
|
||||||
|
return x
|
||||||
|
|
||||||
class StarvlaInferenceServer:
|
class StarvlaInferenceServer:
|
||||||
|
|
||||||
@@ -19,7 +27,7 @@ class StarvlaInferenceServer:
|
|||||||
|
|
||||||
with open(config_path, "r") as f:
|
with open(config_path, "r") as f:
|
||||||
cfg = yaml.safe_load(f)
|
cfg = yaml.safe_load(f)
|
||||||
|
|
||||||
policy_server_cfg = cfg["policy_server"]
|
policy_server_cfg = cfg["policy_server"]
|
||||||
root_paths = cfg["general"]["root_paths"]
|
root_paths = cfg["general"]["root_paths"]
|
||||||
self.ckpt_source = policy_server_cfg["ckpt_source"]
|
self.ckpt_source = policy_server_cfg["ckpt_source"]
|
||||||
@@ -88,14 +96,15 @@ class StarvlaInferenceServer:
|
|||||||
img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
|
img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
|
||||||
|
|
||||||
state_vec = obs["state"]
|
state_vec = obs["state"]
|
||||||
|
# import ipdb;ipdb.set_trace()
|
||||||
|
# state_vec = pad_to_dim(np.array(state_vec), 100, axis=-1)
|
||||||
return img_left, img_right, img_wrist, state_vec, obs["prompt"]
|
return img_left, img_right, img_wrist, state_vec, obs["prompt"]
|
||||||
|
|
||||||
def inference(self, observation: dict) -> dict:
|
def inference(self, observation: dict) -> dict:
|
||||||
|
|
||||||
img_left, img_right, img_wrist, state_vec, prompt = \
|
img_left, img_right, img_wrist, state_vec, prompt = \
|
||||||
self.parse_observation(observation)
|
self.parse_observation(observation)
|
||||||
|
print(f"{state_vec.shape}")
|
||||||
vla_input = {
|
vla_input = {
|
||||||
"batch_images": [[img_left, img_right, img_wrist]],
|
"batch_images": [[img_left, img_right, img_wrist]],
|
||||||
"instructions": [prompt],
|
"instructions": [prompt],
|
||||||
|
|||||||
257
starvla_inference_server_unnorm.py
Normal file
257
starvla_inference_server_unnorm.py
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
|
||||||
|
import yaml
|
||||||
|
import pickle
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import traceback
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
import json
|
||||||
|
from flask import Flask, request, Response
|
||||||
|
from PIL import Image
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
|
||||||
|
"""Pad an array to the target dimension with zeros along the specified axis."""
|
||||||
|
current_dim = x.shape[axis]
|
||||||
|
if current_dim < target_dim:
|
||||||
|
pad_width = [(0, 0)] * len(x.shape)
|
||||||
|
pad_width[axis] = (0, target_dim - current_dim)
|
||||||
|
return np.pad(x, pad_width, constant_values=value)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def normalize_eepose_state(raw_state, stats, start_idx, norm_len=3):
|
||||||
|
"""
|
||||||
|
根据统计信息对原始动作的特定切片进行归一化 (Normalize),映射到 [-1.0, 1.0] 区间
|
||||||
|
|
||||||
|
:param raw_action: 包含原始动作的 numpy 数组,形状通常为 (batch_size, action_dim)
|
||||||
|
:param stats: 归一化所需的统计信息字典 (包含 'q01' 和 'q99')
|
||||||
|
:param start_idx: 需要归一化的特征在 action 向量中的起始索引
|
||||||
|
:param norm_len: 需要归一化的特征长度,默认为 3 (如 xyz 坐标)
|
||||||
|
:return: 归一化后的动作数组 (返回新数组,不修改原数组)
|
||||||
|
"""
|
||||||
|
# 1. 提取并截取对应的 q01 和 q99,保持 float16 精度对齐
|
||||||
|
q01 = np.array(stats["q01"][start_idx : start_idx + norm_len], dtype=np.float16)
|
||||||
|
q99 = np.array(stats["q99"][start_idx : start_idx + norm_len], dtype=np.float16)
|
||||||
|
|
||||||
|
# 2. 计算分母 denom,防除零保护
|
||||||
|
denom = np.clip(q99 - q01, a_min=1e-5, a_max=None)
|
||||||
|
|
||||||
|
# 3. 定位切片
|
||||||
|
target_slice = slice(start_idx, start_idx + norm_len)
|
||||||
|
|
||||||
|
# 4. 复制一份以防污染原始数据矩阵
|
||||||
|
norm_state = raw_state.copy()
|
||||||
|
x = norm_state[:, target_slice]
|
||||||
|
|
||||||
|
# 5. 执行正向映射计算:y = 2 * (x - q01) / denom - 1
|
||||||
|
y = 2.0 * (x - q01) / denom - 1.0
|
||||||
|
|
||||||
|
# 6. 核心操作:截断到 [-1.0, 1.0] 区间,防止输入给模型的动作越界
|
||||||
|
norm_state[:, target_slice] = np.clip(y, -1.0, 1.0)
|
||||||
|
|
||||||
|
return norm_state
|
||||||
|
|
||||||
|
def unnormalize_eepose_action(normalized_action, stats, start_idx, norm_len=3):
|
||||||
|
"""
|
||||||
|
读取统计信息 JSON 并对动作的特定切片进行反归一化 (Un-normalize)
|
||||||
|
|
||||||
|
:param normalized_action: 包含归一化动作的 numpy 数组,形状通常为 (batch_size, action_dim)
|
||||||
|
:param stats: 反归一化所需的统计信息,通常是从 JSON 文件中读取的字典
|
||||||
|
:param start_idx: 需要反归一化的特征在 action 向量中的起始索引 (即原代码中的 start)
|
||||||
|
:param norm_len: 需要反归一化的特征长度,默认为 3 (如 xyz 坐标)
|
||||||
|
:return: 反归一化后的动作数组 (返回新数组,不修改原数组)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 2. 提取并截取对应的 q01 和 q99
|
||||||
|
# 注意:为了和原归一化精度对齐,这里继续保持 float16
|
||||||
|
# import ipdb;ipdb.set_trace()
|
||||||
|
q01 = np.array(stats["q01"][start_idx : start_idx + norm_len], dtype=np.float16)
|
||||||
|
q99 = np.array(stats["q99"][start_idx : start_idx + norm_len], dtype=np.float16)
|
||||||
|
|
||||||
|
# 3. 计算分母 denom,保持和原来完全相同的裁剪逻辑防除零
|
||||||
|
denom = np.clip(q99 - q01, a_min=1e-5, a_max=None)
|
||||||
|
|
||||||
|
# 4. 定位切片
|
||||||
|
target_slice = slice(start_idx, start_idx + norm_len)
|
||||||
|
|
||||||
|
# 5. 执行反向计算:x = (y + 1) / 2 * denom + q01
|
||||||
|
# 复制一份以防污染原始的 normalized_action 矩阵
|
||||||
|
unnorm_action = normalized_action.copy()
|
||||||
|
y = unnorm_action[:, target_slice]
|
||||||
|
unnorm_action[:, target_slice] = (y + 1.0) / 2.0 * denom + q01
|
||||||
|
|
||||||
|
return unnorm_action
|
||||||
|
|
||||||
|
class StarvlaInferenceServer:
|
||||||
|
|
||||||
|
def __init__(self, config_path: str):
|
||||||
|
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
cfg = yaml.safe_load(f)
|
||||||
|
# import ipdb;ipdb.set_trace()
|
||||||
|
policy_server_cfg = cfg["policy_server"]
|
||||||
|
root_paths = cfg["general"]["root_paths"]
|
||||||
|
self.ckpt_source = policy_server_cfg["ckpt_source"]
|
||||||
|
self.ckpt_path = self._resolve_ckpt_path(
|
||||||
|
ckpt_url=policy_server_cfg["ckpt_path"],
|
||||||
|
root_paths=root_paths,
|
||||||
|
)
|
||||||
|
self.rel_eepose_stats = self.read_rel_eepose_stats()
|
||||||
|
self.host = policy_server_cfg.get("host", "0.0.0.0")
|
||||||
|
self.port = policy_server_cfg.get("port", 5000)
|
||||||
|
self.use_bf16 = policy_server_cfg.get("use_bf16", True)
|
||||||
|
self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
|
||||||
|
self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")
|
||||||
|
|
||||||
|
print("Loading StarVLA model...")
|
||||||
|
self.model = self.load_model()
|
||||||
|
print("Model loaded.")
|
||||||
|
|
||||||
|
self.app = Flask(__name__)
|
||||||
|
self.register_routes()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
|
||||||
|
parsed = urlparse(ckpt_url)
|
||||||
|
if not parsed.scheme:
|
||||||
|
return ckpt_url
|
||||||
|
root = root_paths.get(parsed.scheme)
|
||||||
|
if not root:
|
||||||
|
raise KeyError(
|
||||||
|
f"cannot find the checkpoint root path in root_paths: {root_paths}"
|
||||||
|
)
|
||||||
|
rel = (parsed.netloc + parsed.path).lstrip("/")
|
||||||
|
return os.path.join(root, rel)
|
||||||
|
|
||||||
|
def read_rel_eepose_stats(self):
|
||||||
|
stats_path = Path(self.ckpt_path).parents[1] / "action_delta_eepose_stats.json"
|
||||||
|
if stats_path is None or not Path(stats_path).exists():
|
||||||
|
return {}
|
||||||
|
with open(stats_path, 'r', encoding='utf-8') as f:
|
||||||
|
stats = json.load(f)
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
|
||||||
|
from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
|
||||||
|
from starVLA.model.framework.__init__ import build_framework
|
||||||
|
|
||||||
|
model_config, norm_stats = read_mode_config(self.ckpt_path)
|
||||||
|
|
||||||
|
cfg = dict_to_namespace(model_config)
|
||||||
|
cfg.trainer.pretrained_checkpoint = None
|
||||||
|
|
||||||
|
model = build_framework(cfg=cfg)
|
||||||
|
model.norm_stats = norm_stats
|
||||||
|
|
||||||
|
state_dict = torch.load(self.ckpt_path, map_location="cpu")
|
||||||
|
model.load_state_dict(state_dict, strict=True)
|
||||||
|
|
||||||
|
if self.use_bf16:
|
||||||
|
model = model.to(torch.bfloat16)
|
||||||
|
|
||||||
|
model = model.to("cuda").eval()
|
||||||
|
|
||||||
|
self.norm_stats = norm_stats
|
||||||
|
self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def parse_observation(self, obs, target_size=(320, 180)):
|
||||||
|
|
||||||
|
left_rgb, right_rgb, wrist_rgb = obs["rgb"]["Left_Camera"], obs["rgb"]["Right_Camera"], obs["rgb"]["Hand_Camera"]
|
||||||
|
|
||||||
|
img_left = Image.fromarray(cv2.resize(left_rgb, target_size))
|
||||||
|
img_right = Image.fromarray(cv2.resize(right_rgb, target_size))
|
||||||
|
img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
|
||||||
|
|
||||||
|
state_vec = obs["state"] #[:10]
|
||||||
|
# import ipdb;ipdb.set_trace()
|
||||||
|
state_vec = pad_to_dim(np.array(state_vec), 100, axis=-1)
|
||||||
|
return img_left, img_right, img_wrist, state_vec, obs["prompt"]
|
||||||
|
|
||||||
|
def inference(self, observation: dict) -> dict:
    """Run one forward pass of the policy and return per-chunk action slices.

    Args:
        observation: Raw observation dict (see ``parse_observation``).

    Returns:
        Dict with "ee_delta_position_chunks" (N, 3), "ee_delta_rot6d_chunks"
        (N, 6) and "gripper_width_chunks" (N, 1), all as nested lists.

    Raises:
        KeyError: If the model output lacks "normalized_actions".
    """
    img_left, img_right, img_wrist, state_vec, prompt = \
        self.parse_observation(observation)

    vla_input = {
        "batch_images": [[img_left, img_right, img_wrist]],
        "instructions": [prompt],
        "state": [state_vec],
    }
    with torch.no_grad():
        output = self.model.predict_action(**vla_input)

    actions = output.get("normalized_actions")
    if actions is None:
        # Fail loudly instead of crashing later on `.ndim` of None.
        raise KeyError("model output is missing 'normalized_actions'")
    if isinstance(actions, torch.Tensor):
        actions = actions.cpu().numpy()
    if actions.ndim == 3:
        actions = actions[0]  # drop the batch dim -> (chunk, action_dim)

    # Un-normalize only the translation part (dims 0-2); rot6d and gripper
    # width are passed through as-is. NOTE(review): the original comment said
    # dims 0-8 (ee_pose7 + gripper_width) should be un-normalized — confirm
    # against the training-time normalization.
    actions[:, :3] = unnormalize_eepose_action(actions[:, :3], self.rel_eepose_stats, start_idx=0, norm_len=3)

    return {
        "ee_delta_position_chunks": actions[:, :3].tolist(),
        "ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
        "gripper_width_chunks": actions[:, 9:10].tolist(),
    }
def register_routes(self):
    """Attach the inference and health-check HTTP endpoints to the Flask app."""

    @self.app.route("/policy/inference", methods=["POST"])
    def policy():
        # SECURITY: pickle.loads on request bytes executes arbitrary code if the
        # client is untrusted — only expose this server on a trusted network.
        try:
            observation = pickle.loads(request.data)
            payload = pickle.dumps(self.inference(observation), protocol=4)
            return Response(payload, mimetype="application/octet-stream")
        except Exception as e:
            failure = {
                "error": str(e),
                "traceback": traceback.format_exc(),
            }
            return Response(
                pickle.dumps(failure, protocol=4),
                mimetype="application/octet-stream",
                status=500,
            )

    @self.app.route("/policy/health", methods=["GET"])
    def health():
        if self.model is None:
            return Response("Failed to load model", mimetype="application/json", status=500)
        return Response("OK", mimetype="application/json")
def run(self):
    """Announce the configured endpoint and start the blocking Flask serve loop."""
    print("StarVLA policy server running")
    print(f"Host: {self.host}")
    print(f"Port: {self.port}")

    # threaded=True lets health checks be answered while an inference runs.
    self.app.run(host=self.host, port=self.port, threaded=True)
if __name__ == "__main__":
    # CLI entry point: read the benchmark config path and serve forever.
    arg_parser = argparse.ArgumentParser(description="StarVLA inference server")
    arg_parser.add_argument(
        "--config",
        type=str,
        default="./benchmark.yaml",
        help="Path to benchmark.yaml",
    )
    cli_args = arg_parser.parse_args()

    StarvlaInferenceServer(cli_args.config).run()
||||||
@@ -8,8 +8,9 @@ from joysim.annotations.config_class import configclass, field
|
|||||||
from joysim.annotations.stereotype import stereotype
|
from joysim.annotations.stereotype import stereotype
|
||||||
from joysim.controllers.spawnable_controller import SpawnableController
|
from joysim.controllers.spawnable_controller import SpawnableController
|
||||||
from joysim.controllers.visualize_controller import VisualizeController
|
from joysim.controllers.visualize_controller import VisualizeController
|
||||||
|
from joysim.unisim.robots.models.modular_robot import ModularRobot
|
||||||
from joysim.utils.namespace import PoseVisualType, SimulatorType
|
from joysim.utils.namespace import PoseVisualType, SimulatorType
|
||||||
from joysim.unisim.robots.configs.actuator_configs.grippers import GripperDriveJointConfig
|
from joysim.unisim.robots.actuator_configs.grippers import GripperDriveJointConfig
|
||||||
from joysim.extensions.benchmark.action import RobotAction
|
from joysim.extensions.benchmark.action import RobotAction
|
||||||
from joysim.extensions.benchmark.benchmark import (
|
from joysim.extensions.benchmark.benchmark import (
|
||||||
BenchmarkAction,
|
BenchmarkAction,
|
||||||
@@ -25,6 +26,8 @@ from joysim.utils.pose import Pose
|
|||||||
class StarvlaPolicyConfig(PolicyConfig):
|
class StarvlaPolicyConfig(PolicyConfig):
|
||||||
|
|
||||||
robot_name: str = field(default="None", required=True, comment="The name of the robot")
|
robot_name: str = field(default="None", required=True, comment="The name of the robot")
|
||||||
|
arm_name: str = field(default="main_arm", required=True, comment="The name of the arm module to control")
|
||||||
|
drive_name: str = field(default="robotiq_85_left_knuckle_joint", required=True, comment="The name of the drive module to control")
|
||||||
gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
|
gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
|
||||||
visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose")
|
visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose")
|
||||||
visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose")
|
visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose")
|
||||||
@@ -64,6 +67,8 @@ class StarvlaPolicy(Policy):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.robot_name = config.robot_name
|
self.robot_name = config.robot_name
|
||||||
|
self.arm_name = config.arm_name
|
||||||
|
self.drive_name = config.drive_name
|
||||||
self.sensor_names = config.sensor_names
|
self.sensor_names = config.sensor_names
|
||||||
self.server_url = config.server_url
|
self.server_url = config.server_url
|
||||||
self.prompt = config.prompt
|
self.prompt = config.prompt
|
||||||
@@ -79,11 +84,15 @@ class StarvlaPolicy(Policy):
|
|||||||
self.current_chunk_id = 0
|
self.current_chunk_id = 0
|
||||||
self.current_chunk_result = None
|
self.current_chunk_result = None
|
||||||
self.run_trunk_size = self.config.run_trunk_size
|
self.run_trunk_size = self.config.run_trunk_size
|
||||||
self.drive_joints: dict[str, GripperDriveJointConfig] = SpawnableController.control_robot(self.robot_name, "get_gripper_drive_joints").unwrap()
|
self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap()
|
||||||
|
self.drive_joints: dict[str, GripperDriveJointConfig] = self.robot.get_arm(self.arm_name).get_ee().get_drive_joints()
|
||||||
|
self.robot_drive_name = list(self.drive_joints.keys())[0]
|
||||||
|
|
||||||
for joint_name, joint_config in self.drive_joints.items():
|
for joint_name, joint_config in self.drive_joints.items():
|
||||||
SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
|
SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
|
||||||
SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
|
SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
|
||||||
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 50}).unwrap()
|
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 5000}).unwrap()
|
||||||
|
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [self.robot_drive_name], "effort_limit": 5000}).unwrap()
|
||||||
self.max_width = float("-inf")
|
self.max_width = float("-inf")
|
||||||
self.min_width = float("inf")
|
self.min_width = float("inf")
|
||||||
for entry in self.gripper_width_mapper:
|
for entry in self.gripper_width_mapper:
|
||||||
@@ -113,11 +122,13 @@ class StarvlaPolicy(Policy):
|
|||||||
|
|
||||||
def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
|
def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
|
||||||
robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
|
robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
|
||||||
ee_pose_base = robot_obs["ee_pose"]["base_frame"]
|
ee_pose_base = robot_obs["ee_pose"][self.arm_name]["base_frame"]
|
||||||
ee_position, ee_rot6d = ee_pose_base["position"],ee_pose_base["rot6d"]
|
ee_position, ee_rot6d = ee_pose_base["position"],ee_pose_base["rot6d"]
|
||||||
normalized_gripper_width = self.__map_joint_position_to_normalized_width(robot_obs["gripper_drive_state"]["position"][0])
|
arm_joint_positions = robot_obs["joint_positions"][:7] # 临时多加了一个drive的位置,现在读的最后一个joint值是drive
|
||||||
|
drive_joint_positions = robot_obs["joint_positions"][-1]
|
||||||
|
normalized_gripper_width = self.__map_joint_position_to_normalized_width(drive_joint_positions)
|
||||||
Log.debug(f"input normalized_gripper_width state: {round(normalized_gripper_width, 2)}")
|
Log.debug(f"input normalized_gripper_width state: {round(normalized_gripper_width, 2)}")
|
||||||
state = np.concatenate([ee_position,ee_rot6d,np.array([normalized_gripper_width])])
|
state = np.concatenate([ee_position,ee_rot6d,np.array([normalized_gripper_width]), [0]*10, np.array(arm_joint_positions)])
|
||||||
rgb_data = {}
|
rgb_data = {}
|
||||||
for sensor_name in self.sensor_names:
|
for sensor_name in self.sensor_names:
|
||||||
sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
|
sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
|
||||||
@@ -137,6 +148,7 @@ class StarvlaPolicy(Policy):
|
|||||||
data=payload,
|
data=payload,
|
||||||
headers={"Content-Type": "application/octet-stream"}
|
headers={"Content-Type": "application/octet-stream"}
|
||||||
)
|
)
|
||||||
|
self.test_obs = observation["state"] #TODO
|
||||||
self._handle_server_error(response)
|
self._handle_server_error(response)
|
||||||
result = pickle.loads(response.content)
|
result = pickle.loads(response.content)
|
||||||
max_trunk_size = len(result["ee_delta_position_chunks"])
|
max_trunk_size = len(result["ee_delta_position_chunks"])
|
||||||
@@ -176,7 +188,9 @@ class StarvlaPolicy(Policy):
|
|||||||
|
|
||||||
def postprocess_action(self, action: dict) -> BenchmarkAction:
|
def postprocess_action(self, action: dict) -> BenchmarkAction:
|
||||||
benchmark_action = BenchmarkAction()
|
benchmark_action = BenchmarkAction()
|
||||||
|
Log.debug(f"observation: {self.test_obs}")
|
||||||
|
# import ipdb;ipdb.set_trace()
|
||||||
|
|
||||||
# get base frame end-effector pose
|
# get base frame end-effector pose
|
||||||
delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action["ee_delta_rot6d_chunks"][self.current_chunk_id])
|
delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action["ee_delta_rot6d_chunks"][self.current_chunk_id])
|
||||||
curr_state_ee_pose = Pose(position=self.current_ee_position_state, rot6d=self.current_ee_rot6d_state)
|
curr_state_ee_pose = Pose(position=self.current_ee_position_state, rot6d=self.current_ee_rot6d_state)
|
||||||
@@ -221,7 +235,7 @@ class StarvlaPolicy(Policy):
|
|||||||
VisualizeController.create_pose_visualization(
|
VisualizeController.create_pose_visualization(
|
||||||
robot_base_world * pose_state_base,
|
robot_base_world * pose_state_base,
|
||||||
name=f"{self.robot_name}/starvla_state_ee",
|
name=f"{self.robot_name}/starvla_state_ee",
|
||||||
simulator=SimulatorType.ISAACSIM,
|
simulator=SimulatorType.ISAACLAB,
|
||||||
pose_type=PoseVisualType.COORDINATE,
|
pose_type=PoseVisualType.COORDINATE,
|
||||||
extra_params={"axis_length": 0.08, "thickness": 0.006},
|
extra_params={"axis_length": 0.08, "thickness": 0.006},
|
||||||
).unwrap()
|
).unwrap()
|
||||||
@@ -229,7 +243,7 @@ class StarvlaPolicy(Policy):
|
|||||||
VisualizeController.create_pose_visualization(
|
VisualizeController.create_pose_visualization(
|
||||||
robot_base_world * pose_action_base,
|
robot_base_world * pose_action_base,
|
||||||
name=f"{self.robot_name}/starvla_action_ee",
|
name=f"{self.robot_name}/starvla_action_ee",
|
||||||
simulator=SimulatorType.ISAACSIM,
|
simulator=SimulatorType.ISAACLAB,
|
||||||
pose_type=PoseVisualType.COORDINATE,
|
pose_type=PoseVisualType.COORDINATE,
|
||||||
extra_params={"axis_length": 0.1, "thickness": 0.006},
|
extra_params={"axis_length": 0.1, "thickness": 0.006},
|
||||||
).unwrap()
|
).unwrap()
|
||||||
@@ -239,5 +253,5 @@ class StarvlaPolicy(Policy):
|
|||||||
return
|
return
|
||||||
for target_name in self.visualize_bounding_box_targets:
|
for target_name in self.visualize_bounding_box_targets:
|
||||||
VisualizeController.visualize_target_bounding_box(
|
VisualizeController.visualize_target_bounding_box(
|
||||||
target_name, simulator=SimulatorType.ISAACSIM
|
target_name, simulator=SimulatorType.ISAACLAB
|
||||||
).unwrap()
|
).unwrap()
|
||||||
Reference in New Issue
Block a user