adapt for starvla-dex

This commit is contained in:
QiyangYan
2026-05-08 18:24:31 +08:00
parent 2514cc943d
commit 53796c1e63
3 changed files with 89 additions and 150 deletions

View File

@@ -1,9 +1,9 @@
general: general:
scan_project: true scan_project: true
root_paths: root_paths:
asset: /home/ubuntu/xionghao/sim_hofee/sim_hofee/assets asset: /home/zhiyuan/zhujuan/joysim/gen_data/data # Root directory for assets (robots, objects, scene USDs, etc.)
checkpoints: /home/ubuntu/xionghao/starVLA-starVLA/playground/Checkpoints checkpoints: /home/zhiyuan/zhujuan/checkpoints
output: /home/ubuntu/xionghao/sim_hofee output: /home/zhiyuan/zhujuan/joysim/output # Root directory for outputs (recorded data, logs, etc.)
simulation: simulation:
stereotype: isaaclab stereotype: isaaclab
@@ -20,14 +20,14 @@ scene:
base_config: base_config:
stereotype: usd stereotype: usd
name: _827313_home_workspace_01 name: _827313_home_workspace_01
source: local source: platform
asset_path: asset://scenes/kujiale_multispace/827313_home/workspace_01.usd asset_path: platform://scenes/kujiale_multispace/827313_home/collect_asset_without_phy.optimized.glb
object_cfg_dict: object_cfg_dict:
omni6DPose_timer_017: omni6DPose_timer_017:
name: omni6DPose_timer_017 name: omni6DPose_timer_017
stereotype: rigid stereotype: rigid
source: local source: platform
asset_path: asset://objects/omni6DPose/timer/omni6DPose_timer_017/Aligned.usd asset_path: platform://objects/omni6DPose/timer/omni6DPose_timer_017/Aligned.usd
scale: scale:
- 0.001 - 0.001
- 0.001 - 0.001
@@ -45,8 +45,8 @@ scene:
omni6DPose_book_031: omni6DPose_book_031:
name: omni6DPose_book_031 name: omni6DPose_book_031
stereotype: rigid stereotype: rigid
source: local source: platform
asset_path: asset://objects/omni6DPose/book/omni6DPose_book_031/Aligned.usd asset_path: platform://objects/omni6DPose/book/omni6DPose_book_031/Aligned.usd
scale: scale:
- 0.001 - 0.001
- 0.001 - 0.001
@@ -62,9 +62,9 @@ scene:
- -0.033652094759270265 - -0.033652094759270265
axis_y_up: true axis_y_up: true
robot_cfg_dict: robot_cfg_dict:
Franka_Robotiq_2f85: r1pro_dex:
name: Franka_Robotiq_2f85 name: r1pro_dex
asset_path: asset://Franka/franka_robotiq_2f85_zedmini.usd asset_path: asset://robots/r1pro/r1pro_dex.usd
position: position:
- 1.082364 - 1.082364
- -3.92826 - -3.92826
@@ -76,56 +76,41 @@ scene:
- 0.9999999999999973 - 0.9999999999999973
stereotype: modular_robot stereotype: modular_robot
source: local source: local
ee_link_name: panda_link8
ik_joint_names:
- panda_joint1
- panda_joint2
- panda_joint3
- panda_joint4
- panda_joint5
- panda_joint6
- panda_joint7
init_joint_position:
# panda_joint1: 0.18641542
# panda_joint2: 0.47660449
# panda_joint3: -0.03320411
# panda_joint4: -2.27693725
# panda_joint5: 0.98161776
# panda_joint6: 2.20247197
# panda_joint7: 0.71794897
panda_joint2: -0.1633
panda_joint4: -1.07
panda_joint6: 0.8933
panda_joint7: 0.785
arm_modules: arm_modules:
main_arm: main_arm:
arm_actuator_name: franka_arm arm_actuator_name: dex_arm
ee_link_name: panda_link8 ee_actuator_name: robot_hand
ee_type: gripper ee_type: dexterous_hand
ee_actuator_name: robotiq_gripper
actuator_cfg_dict: actuator_cfg_dict:
franka_arm: dex_arm:
stereotype: arm stereotype: arm
joint_names_expr: [panda_joint1, panda_joint2, panda_joint3, panda_joint4, panda_joint5, panda_joint6, panda_joint7] joint_names_expr: [
left_arm_joint1, right_arm_joint1, left_arm_joint2, right_arm_joint2,
left_arm_joint3, right_arm_joint3, left_arm_joint4, right_arm_joint4,
left_arm_joint5, right_arm_joint5, left_arm_joint6, right_arm_joint6,
left_arm_joint7, right_arm_joint7]
stiffness: 3000.0 stiffness: 3000.0
damping: 800.0 damping: 800.0
robotiq_gripper: robot_hand:
stereotype: gripper stereotype: dexterous_hand
joint_names_expr: [robotiq_85_left_knuckle_joint] joint_names_expr: [
left_index_MCP_FE, left_middle_MCP_FE, left_pinky_CMC, left_ring_MCP_FE, left_thumb_CMC_FE,
right_index_MCP_FE, right_middle_MCP_FE, right_pinky_CMC, right_ring_MCP_FE, right_thumb_CMC_FE,
left_index_MCP_AA, left_middle_MCP_AA, left_pinky_MCP_FE, left_ring_MCP_AA, left_thumb_CMC_AA,
right_index_MCP_AA, right_middle_MCP_AA, right_pinky_MCP_FE, right_ring_MCP_AA, right_thumb_CMC_AA,
left_index_PIP, left_middle_PIP, left_pinky_MCP_AA, left_ring_PIP, left_thumb_MCP_FE,
right_index_PIP, right_middle_PIP, right_pinky_MCP_AA, right_ring_PIP, right_thumb_MCP_FE,
left_index_DIP, left_middle_DIP, left_pinky_PIP, left_ring_DIP, left_thumb_MCP_AA,
right_index_DIP, right_middle_DIP, right_pinky_PIP, right_ring_DIP, right_thumb_MCP_AA,
left_pinky_DIP, left_thumb_IP, right_pinky_DIP, right_thumb_IP]
stiffness: 10000 stiffness: 10000
damping: 500.0 damping: 500.0
close_control_type: velocity close_control_type: velocity
open_control_type: position open_control_type: position
drive_joints:
robotiq_85_left_knuckle_joint:
close_velocity: 5.0
open_velocity: -5.0
close_position: 0.8
open_position: 0.0
use_planner: false use_planner: false
sensor_cfg_dict: sensor_cfg_dict:
Hand_Camera: Zed_Camera:
name: Hand_Camera name: Zed_Camera
stereotype: camera stereotype: camera
data_types: data_types:
- rgb - rgb
@@ -140,71 +125,18 @@ scene:
vertical_aperture: 2.7608816125932627 vertical_aperture: 2.7608816125932627
convention: opengl convention: opengl
attach_to: attach_to:
target_name: Franka_Robotiq_2f85 target_name: r1pro_dex
is_articulation_part: true is_articulation_part: false
articulation_part_name: panda_link8
create_fixed_joint: true create_fixed_joint: true
local_position: local_position:
- -0.07128738160694643 - 0.06
- 0.03551506300731732 - 0.0
- 0.018927748370281355 - 0.01
local_rotation: local_rotation:
- -0.12117023430710862 - -1.0
- -0.6862313269668 - 0.0
- 0.7070213671685396 - 0.0
- 0.12052023305019997 - 0.0
Left_Camera:
name: Left_Camera
stereotype: camera
data_types:
- rgb
- depth
- normals
width: 1280
height: 720
camera_model: pinhole
fix_camera: false
focal_length: 2.1
horizontal_aperture: 5.030789363390793
vertical_aperture: 2.833796298140747
convention: opengl
attach_to:
target_name: Franka_Robotiq_2f85
local_position:
- 0.31702696813014064
- -0.3844238699868664
- 0.6551552990137672
local_rotation:
- 0.8742457685173938
- 0.38378563025938384
- -0.11951449178007277
- -0.27224843891267797
Right_Camera:
name: Right_Camera
stereotype: camera
data_types:
- rgb
- depth
- normals
width: 1280
height: 720
camera_model: pinhole
fix_camera: false
focal_length: 2.1
horizontal_aperture: 5.050364265142387
vertical_aperture: 2.833796298140747
convention: opengl
attach_to:
target_name: Franka_Robotiq_2f85
local_position:
- 0.21844487914880717
- 0.20172329179193413
- 0.30108042236545296
local_rotation:
- -0.5316249212230874
- -0.38697158527836417
- 0.44338617110944967
- 0.6091277686910994
extension: extension:
extension_cfg_dict: extension_cfg_dict:
@@ -213,8 +145,22 @@ extension:
stereotype: data_collect stereotype: data_collect
observer_cfgs: observer_cfgs:
- stereotype: robot_observer - stereotype: robot_observer
name: Franka_Robotiq_2f85 name: r1pro_dex
target_joint_names: [panda_joint1, panda_joint2, panda_joint3, panda_joint4, panda_joint5, panda_joint6, panda_joint7, robotiq_85_left_knuckle_joint] target_joint_names: [steer_motor_joint1, steer_motor_joint2, steer_motor_joint3, torso_joint1, wheel_motor_joint1,
wheel_motor_joint2, wheel_motor_joint3, torso_joint2, torso_joint3, torso_joint4,
left_arm_joint1, right_arm_joint1, left_arm_joint2, right_arm_joint2,
left_arm_joint3, right_arm_joint3, left_arm_joint4, right_arm_joint4,
left_arm_joint5, right_arm_joint5, left_arm_joint6, right_arm_joint6,
left_arm_joint7, right_arm_joint7,
left_index_MCP_FE, left_middle_MCP_FE, left_pinky_CMC, left_ring_MCP_FE, left_thumb_CMC_FE,
right_index_MCP_FE, right_middle_MCP_FE, right_pinky_CMC, right_ring_MCP_FE, right_thumb_CMC_FE,
left_index_MCP_AA, left_middle_MCP_AA, left_pinky_MCP_FE, left_ring_MCP_AA, left_thumb_CMC_AA,
right_index_MCP_AA, right_middle_MCP_AA, right_pinky_MCP_FE, right_ring_MCP_AA, right_thumb_CMC_AA,
left_index_PIP, left_middle_PIP, left_pinky_MCP_AA, left_ring_PIP, left_thumb_MCP_FE,
right_index_PIP, right_middle_PIP, right_pinky_MCP_AA, right_ring_PIP, right_thumb_MCP_FE,
left_index_DIP, left_middle_DIP, left_pinky_PIP, left_ring_DIP, left_thumb_MCP_AA,
right_index_DIP, right_middle_DIP, right_pinky_PIP, right_ring_DIP, right_thumb_MCP_AA,
left_pinky_DIP, left_thumb_IP, right_pinky_DIP, right_thumb_IP]
observe_ee_pose: true observe_ee_pose: true
observe_ee_state: true observe_ee_state: true
observe_joint_position: true observe_joint_position: true
@@ -225,13 +171,7 @@ extension:
observe_joint_position_targets: true observe_joint_position_targets: true
observe_joint_velocity_targets: true observe_joint_velocity_targets: true
- stereotype: sensor_observer - stereotype: sensor_observer
name: Hand_Camera name: Zed_Camera
observe_rgb: true
- stereotype: sensor_observer
name: Left_Camera
observe_rgb: true
- stereotype: sensor_observer
name: Right_Camera
observe_rgb: true observe_rgb: true
starvla_benchmark: starvla_benchmark:
@@ -248,12 +188,11 @@ extension:
object_B_name: omni6DPose_timer_017 object_B_name: omni6DPose_timer_017
policy: policy:
stereotype: starvla stereotype: starvla
robot_name: Franka_Robotiq_2f85 robot_name: r1pro_dex
arm_name: main_arm arm_name: main_arm
sensor_names: [Hand_Camera, Left_Camera, Right_Camera] sensor_names: [Zed_Camera]
prompt: pick up the timer and put on the book prompt: pick up the timer and put on the book
run_trunk_size: 16 run_trunk_size: 16
gripper_width_mapper_file: ./gripper_width_robotiq_2f85_fixed.json
visualize_action_ee_pose: true visualize_action_ee_pose: true
visualize_state_ee_pose: true visualize_state_ee_pose: true
visualize_bounding_box_targets: [] # [omni6DPose_plug_001, omni6DPose_can_016] # 打开会被policy看到会影响policy的推理结果 visualize_bounding_box_targets: [] # [omni6DPose_plug_001, omni6DPose_can_016] # 打开会被policy看到会影响policy的推理结果
@@ -267,13 +206,10 @@ extension:
postprocess_list: ["hdf5", "video"] postprocess_list: ["hdf5", "video"]
policy_server: policy_server:
# ckpt_path: checkpoints://0324_qwenpi_droid_pretrain_8node/checkpoints/steps_30000_pytorch_model.pt ckpt_path: checkpoints://egodex_part1_restats_gbs1024/checkpoints/steps_70000_pytorch_model.pt
# ckpt_path: checkpoints://0405_qwenpi_droid_norm_pretrain_8node/checkpoints/steps_60000_pytorch_model.pt
# ckpt_path: checkpoints://0407_qwenpi_droid_postrain/final_model/pytorch_model.pt
ckpt_path: checkpoints://0407_qwenpi_droid_from_scratch/final_model/pytorch_model.pt
ckpt_source: local ckpt_source: local
host: 0.0.0.0 host: 0.0.0.0
port: 5000 port: 5000
use_bf16: true use_bf16: true
unnorm_key: oxe_bridge unnorm_key: oxe_bridge
state_mode: ee_pose7 state_mode: ee_pose7

View File

@@ -106,13 +106,14 @@ class StarvlaInferenceServer:
self.parse_observation(observation) self.parse_observation(observation)
print(f"{state_vec.shape}") print(f"{state_vec.shape}")
vla_input = { vla_input = {
"batch_images": [[img_left, img_right, img_wrist]], # "batch_images": [[img_left, img_right, img_wrist]],
"instructions": [prompt], "image": [img_left],
"state": [state_vec] "lang": prompt,
"state": state_vec
} }
with torch.no_grad(): with torch.no_grad():
output = self.model.predict_action(**vla_input) output = self.model.predict_action(examples=vla_input)
actions = output.get("normalized_actions") actions = output.get("normalized_actions")
@@ -176,4 +177,4 @@ if __name__ == "__main__":
config_path = args.config config_path = args.config
server = StarvlaInferenceServer(config_path) server = StarvlaInferenceServer(config_path)
server.run() server.run()

View File

@@ -28,7 +28,7 @@ class StarvlaPolicyConfig(PolicyConfig):
robot_name: str = field(default="None", required=True, comment="The name of the robot") robot_name: str = field(default="None", required=True, comment="The name of the robot")
arm_name: str = field(default="main_arm", required=True, comment="The name of the arm module to control") arm_name: str = field(default="main_arm", required=True, comment="The name of the arm module to control")
drive_name: str = field(default="robotiq_85_left_knuckle_joint", required=True, comment="The name of the drive module to control") drive_name: str = field(default="robotiq_85_left_knuckle_joint", required=True, comment="The name of the drive module to control")
gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper") # gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose") visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose")
visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose") visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose")
visualize_bounding_box_targets: list[str] = field( visualize_bounding_box_targets: list[str] = field(
@@ -72,7 +72,7 @@ class StarvlaPolicy(Policy):
self.sensor_names = config.sensor_names self.sensor_names = config.sensor_names
self.server_url = config.server_url self.server_url = config.server_url
self.prompt = config.prompt self.prompt = config.prompt
self.gripper_width_mapper = json.load(open(config.gripper_width_mapper_file, "r")) # self.gripper_width_mapper = json.load(open(config.gripper_width_mapper_file, "r"))
self.visualize_action_ee_pose = config.visualize_action_ee_pose self.visualize_action_ee_pose = config.visualize_action_ee_pose
self.visualize_state_ee_pose = config.visualize_state_ee_pose self.visualize_state_ee_pose = config.visualize_state_ee_pose
self.visualize_bounding_box_targets = list(config.visualize_bounding_box_targets or []) self.visualize_bounding_box_targets = list(config.visualize_bounding_box_targets or [])
@@ -85,19 +85,21 @@ class StarvlaPolicy(Policy):
self.current_chunk_result = None self.current_chunk_result = None
self.run_trunk_size = self.config.run_trunk_size self.run_trunk_size = self.config.run_trunk_size
self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap() self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap()
self.drive_joints: dict[str, GripperDriveJointConfig] = self.robot.get_arm(self.arm_name).get_ee().get_drive_joints() # self.robot: ModularRobot = SpawnableController.get_spawnable(self.robot_name)
self.robot_drive_name = list(self.drive_joints.keys())[0] # self.drive_joints: dict[str, GripperDriveJointConfig] = self.robot.get_arm(self.arm_name).get_ee().get_drive_joints()
# self.robot_drive_name = list(self.drive_joints.keys())[0]
self.robot_drive_name = self.drive_name
for joint_name, joint_config in self.drive_joints.items(): # for joint_name, joint_config in self.drive_joints.items():
SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap() # SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap() # SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 5000}).unwrap() # SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 5000}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [self.robot_drive_name], "effort_limit": 5000}).unwrap() # SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [self.robot_drive_name], "effort_limit": 5000}).unwrap()
self.max_width = float("-inf") self.max_width = float("-inf")
self.min_width = float("inf") self.min_width = float("inf")
for entry in self.gripper_width_mapper: # for entry in self.gripper_width_mapper:
self.max_width = max(self.max_width, entry["width"]) # self.max_width = max(self.max_width, entry["width"])
self.min_width = min(self.min_width, entry["width"]) # self.min_width = min(self.min_width, entry["width"])
def warmup(self, benchmark_observation: BenchmarkObservation) -> None: def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
Log.info(f"Waiting for StarVLA inference server to be ready...") Log.info(f"Waiting for StarVLA inference server to be ready...")
@@ -166,9 +168,9 @@ class StarvlaPolicy(Policy):
joint_position = 0 joint_position = 0
if joint_position > 0.8: if joint_position > 0.8:
joint_position = 0.8 joint_position = 0.8
for entry in self.gripper_width_mapper: # for entry in self.gripper_width_mapper:
if round(entry["angel"], 2) == round(joint_position, 2): # if round(entry["angel"], 2) == round(joint_position, 2):
return 1-(entry["width"] - self.min_width) / (self.max_width - self.min_width) # return 1-(entry["width"] - self.min_width) / (self.max_width - self.min_width)
@@ -254,4 +256,4 @@ class StarvlaPolicy(Policy):
for target_name in self.visualize_bounding_box_targets: for target_name in self.visualize_bounding_box_targets:
VisualizeController.visualize_target_bounding_box( VisualizeController.visualize_target_bounding_box(
target_name, simulator=SimulatorType.ISAACLAB target_name, simulator=SimulatorType.ISAACLAB
).unwrap() ).unwrap()