From 3d3da4e17f42c55a7e6bd7276b4f2fd870089016 Mon Sep 17 00:00:00 2001 From: QiyangYan Date: Fri, 22 May 2026 18:38:16 +0800 Subject: [PATCH] adapt for dexterous hands --- benchmark.yaml | 334 +++++++++++++++++++++++------------- starvla_inference_server.py | 78 ++++++--- starvla_policy.py | 164 ++++++++---------- 3 files changed, 340 insertions(+), 236 deletions(-) diff --git a/benchmark.yaml b/benchmark.yaml index b28ff2c..eb06dc0 100644 --- a/benchmark.yaml +++ b/benchmark.yaml @@ -1,143 +1,204 @@ general: - scan_project: true root_paths: - asset: /home/zhiyuan/zhujuan/joysim/gen_data/data # Root directory for assets (robots, objects, scene USDs, etc.) + asset: /home/zhiyuan/zhujuan/joysim_exp/gen_data/data # Root directory for assets (robots, objects, scene USDs, etc.) checkpoints: /home/zhiyuan/zhujuan/checkpoints - output: /home/zhiyuan/zhujuan/joysim/output # Root directory for outputs (recorded data, logs, etc.) + output: /home/zhiyuan/zhujuan/joysim_exp/output # Root directory for outputs (recorded data, logs, etc.) simulation: stereotype: isaaclab - intiailize_steps: 300 launch_config: device: cuda enable_cameras: true headless: false livestream: 0 - - scene: - name: 827313_home + name: kujiale_multispace base_config: stereotype: usd - name: _827313_home_workspace_01 + name: _827313_home_workspace_00 source: platform - asset_path: platform://scenes/kujiale_multispace/827313_home/collect_asset_without_phy.optimized.glb + asset_path: platform://scenes/kujiale_multispace/827313_home/workspace_00.usd object_cfg_dict: - omni6DPose_timer_017: - name: omni6DPose_timer_017 + omni6DPose_can_016: + name: omni6DPose_can_016 stereotype: rigid - source: platform - asset_path: platform://objects/omni6DPose/timer/omni6DPose_timer_017/Aligned.usd + source: local + asset_path: asset://objects/omni6DPose/can/omni6DPose_can_016/Aligned.usd scale: - 0.001 - 0.001 - 0.001 position: - - 0.552364 - - -4.0582599999999995 - - 0.524713118 - quaternion: - - 0.166210542394157 - - 0.166210542394157 - - 0.6872947370648492 - - 0.6872947370648491 - axis_y_up: true + - 0.419859 + - -4.02430000000001 + - 0.510259093 + rotation: + - -0.304408012043137 + - -0.304408012043137 + - 0.638228612805745 + - 0.6382286128057448 omni6DPose_book_031: name: omni6DPose_book_031 stereotype: rigid - source: platform - asset_path: platform://objects/omni6DPose/book/omni6DPose_book_031/Aligned.usd + source: local + asset_path: asset://objects/omni6DPose/book/omni6DPose_book_031/Aligned.usd scale: - 0.001 - 0.001 - 0.001 position: - - 0.6623640000000001 - - -3.7882599999999997 - - 0.5101601435 - quaternion: - - 0.7063055546421202 - - 0.7063055546421203 - - -0.03365209475927027 - - -0.033652094759270265 + - 0.419859 + - -4.152430000000001 + - 0.510259093 + quaternion: [1, 0, 0, 0] axis_y_up: true + robot_cfg_dict: r1pro_dex: name: r1pro_dex asset_path: asset://robots/r1pro/r1pro_dex.usd - position: - - 1.082364 - - -3.92826 - - 0.47629299999999997 - rotation: - - 7.549799991308018e-08 - - 0.0 - - 0.0 - - 0.9999999999999973 + position: [-0.5, -4.0, 0.0] + rotation: [1, 0, 0, 0] stereotype: modular_robot source: local - arm_modules: - main_arm: - arm_actuator_name: dex_arm - ee_actuator_name: robot_hand - ee_type: dexterous_hand + + init_joint_position: + torso_joint1: 0.0 + torso_joint2: 0.0 + torso_joint3: 0.0 + torso_joint4: 0.0 + left_arm_joint1: 0.0 + left_arm_joint2: 0.5 + left_arm_joint3: 0.0 + left_arm_joint4: -1.0 + left_arm_joint5: 0.0 + left_arm_joint6: 0.0 + left_arm_joint7: 0.0 + right_arm_joint1: 0.0 + right_arm_joint2: -0.5 + right_arm_joint3: 0.0 + right_arm_joint4: -1.0 + right_arm_joint5: 0.0 + right_arm_joint6: 0.0 + right_arm_joint7: 0.0 + actuator_cfg_dict: - dex_arm: + left_arm: stereotype: arm - joint_names_expr: [ - left_arm_joint1, right_arm_joint1, left_arm_joint2, right_arm_joint2, - left_arm_joint3, right_arm_joint3, left_arm_joint4, right_arm_joint4, - left_arm_joint5, right_arm_joint5, left_arm_joint6, right_arm_joint6, - left_arm_joint7, right_arm_joint7] - stiffness: 3000.0 - damping: 800.0 - robot_hand: - stereotype: dexterous_hand - joint_names_expr: [ - left_index_MCP_FE, left_middle_MCP_FE, left_pinky_CMC, left_ring_MCP_FE, left_thumb_CMC_FE, - right_index_MCP_FE, right_middle_MCP_FE, right_pinky_CMC, right_ring_MCP_FE, right_thumb_CMC_FE, - left_index_MCP_AA, left_middle_MCP_AA, left_pinky_MCP_FE, left_ring_MCP_AA, left_thumb_CMC_AA, - right_index_MCP_AA, right_middle_MCP_AA, right_pinky_MCP_FE, right_ring_MCP_AA, right_thumb_CMC_AA, - left_index_PIP, left_middle_PIP, left_pinky_MCP_AA, left_ring_PIP, left_thumb_MCP_FE, - right_index_PIP, right_middle_PIP, right_pinky_MCP_AA, right_ring_PIP, right_thumb_MCP_FE, - left_index_DIP, left_middle_DIP, left_pinky_PIP, left_ring_DIP, left_thumb_MCP_AA, - right_index_DIP, right_middle_DIP, right_pinky_PIP, right_ring_DIP, right_thumb_MCP_AA, - left_pinky_DIP, left_thumb_IP, right_pinky_DIP, right_thumb_IP] - stiffness: 10000 - damping: 500.0 - close_control_type: velocity - open_control_type: position + joint_names_expr: [left_arm_joint1, left_arm_joint2, left_arm_joint3, left_arm_joint4, left_arm_joint5, left_arm_joint6, left_arm_joint7] + stiffness: 60000.0 + damping: 4000.0 + right_arm: + stereotype: arm + joint_names_expr: [right_arm_joint1, right_arm_joint2, right_arm_joint3, right_arm_joint4, right_arm_joint5, right_arm_joint6, right_arm_joint7] + stiffness: 60000.0 + damping: 4000.0 + left_hand: + stereotype: arm + joint_names_expr: [left_thumb_CMC_FE, left_thumb_CMC_AA, left_thumb_MCP_FE, left_thumb_MCP_AA, left_thumb_IP, left_index_MCP_FE, left_index_MCP_AA, left_index_PIP, left_index_DIP, left_middle_MCP_FE, left_middle_MCP_AA, left_middle_PIP, left_middle_DIP, left_ring_MCP_FE, left_ring_MCP_AA, left_ring_PIP, left_ring_DIP, left_pinky_CMC, left_pinky_MCP_FE, left_pinky_MCP_AA, left_pinky_PIP, left_pinky_DIP] + stiffness: 50.0 + damping: 5.0 + right_hand: + stereotype: arm + joint_names_expr: [right_thumb_CMC_FE, right_thumb_CMC_AA, right_thumb_MCP_FE, right_thumb_MCP_AA, right_thumb_IP, right_index_MCP_FE, right_index_MCP_AA, right_index_PIP, right_index_DIP, right_middle_MCP_FE, right_middle_MCP_AA, right_middle_PIP, right_middle_DIP, right_ring_MCP_FE, right_ring_MCP_AA, right_ring_PIP, right_ring_DIP, right_pinky_CMC, right_pinky_MCP_FE, right_pinky_MCP_AA, right_pinky_PIP, right_pinky_DIP] + stiffness: 50.0 + damping: 5.0 + torso: + stereotype: arm + joint_names_expr: [torso_joint1, torso_joint2, torso_joint3, torso_joint4] + stiffness: 100000.0 + damping: 8000.0 + base_lock: + stereotype: arm + joint_names_expr: [steer_motor_joint1, steer_motor_joint2, steer_motor_joint3, wheel_motor_joint1, wheel_motor_joint2, wheel_motor_joint3] + stiffness: 100000.0 + damping: 5000.0 + + arm_modules: + left_arm: + arm_actuator_name: left_arm + ee_link_name: left_hand_C_MC + ee_type: dexterous_hand + ee_actuator_name: left_hand + right_arm: + arm_actuator_name: right_arm + ee_link_name: right_hand_C_MC + ee_type: dexterous_hand + ee_actuator_name: right_hand + + extra_modules: + torso: + actuator_name: torso + use_planner: false + sensor_cfg_dict: - Zed_Camera: - name: Zed_Camera + + head_camera: + name: head_camera stereotype: camera - data_types: - - rgb - - depth - - normals + position: [-0.4, -4.0, 1.2] + look_at: + is_point: true + look_at_point: [0.4200, -4.1530, 0.4885] + data_types: [rgb] + width: 1280 + height: 720 + camera_model: pinhole + fix_camera: true + front_camera: + name: front_camera + stereotype: camera + position: [1.0, -4.0, 1.5] + look_at: + is_point: true + look_at_point: [-1.0, -4.0, 1.2] + data_types: [rgb] + width: 1280 + height: 720 + camera_model: pinhole + fix_camera: true + left_camera: + name: left_camera + stereotype: camera + position: [-1.0, -1.0, 1.2] + look_at: + is_point: true + look_at_point: [0.0, -4.1, 1.2] + data_types: [rgb] + width: 1280 + height: 720 + camera_model: pinhole + fix_camera: true + right_camera: + name: right_camera + stereotype: camera + position: [-1.0, -6.5, 1.2] + look_at: + is_point: true + look_at_point: [0.0, -4.1, 1.2] + data_types: [rgb] width: 1280 height: 720 camera_model: pinhole fix_camera: true - focal_length: 2.8 - horizontal_aperture: 4.893416860031241 - vertical_aperture: 2.7608816125932627 - convention: opengl - attach_to: - target_name: r1pro_dex - is_articulation_part: false - create_fixed_joint: true - local_position: - - 0.06 - - 0.0 - - 0.01 - local_rotation: - - -1.0 - - 0.0 - - 0.0 - - 0.0 + light_cfg_dict: + sun: + name: sun + stereotype: general_light + light_type: distant + position: [0, 0, 5] + rotation: [1, 0, 0, 0] + intensity: 1000 + angle: 0.53 + color: [1.0, 1.0, 1.0] + sky: + name: sky + stereotype: general_light + light_type: dome + intensity: 10.0 + color: [1.0, 1.0, 1.0] extension: extension_cfg_dict: benchmark_data_collect: @@ -146,21 +207,51 @@ extension: observer_cfgs: - stereotype: robot_observer name: r1pro_dex - target_joint_names: [steer_motor_joint1, steer_motor_joint2, steer_motor_joint3, torso_joint1, wheel_motor_joint1, - wheel_motor_joint2, wheel_motor_joint3, torso_joint2, torso_joint3, torso_joint4, - left_arm_joint1, right_arm_joint1, left_arm_joint2, right_arm_joint2, - left_arm_joint3, right_arm_joint3, left_arm_joint4, right_arm_joint4, - left_arm_joint5, right_arm_joint5, left_arm_joint6, right_arm_joint6, - left_arm_joint7, right_arm_joint7, - left_index_MCP_FE, left_middle_MCP_FE, left_pinky_CMC, left_ring_MCP_FE, left_thumb_CMC_FE, - right_index_MCP_FE, right_middle_MCP_FE, right_pinky_CMC, right_ring_MCP_FE, right_thumb_CMC_FE, - left_index_MCP_AA, left_middle_MCP_AA, left_pinky_MCP_FE, left_ring_MCP_AA, left_thumb_CMC_AA, - right_index_MCP_AA, right_middle_MCP_AA, right_pinky_MCP_FE, right_ring_MCP_AA, right_thumb_CMC_AA, - left_index_PIP, left_middle_PIP, left_pinky_MCP_AA, left_ring_PIP, left_thumb_MCP_FE, - right_index_PIP, right_middle_PIP, right_pinky_MCP_AA, right_ring_PIP, right_thumb_MCP_FE, - left_index_DIP, left_middle_DIP, left_pinky_PIP, left_ring_DIP, left_thumb_MCP_AA, - right_index_DIP, right_middle_DIP, right_pinky_PIP, right_ring_DIP, right_thumb_MCP_AA, - left_pinky_DIP, left_thumb_IP, right_pinky_DIP, right_thumb_IP] + target_joint_names: + - left_thumb_CMC_FE + - left_thumb_CMC_AA + - left_thumb_MCP_FE + - left_thumb_MCP_AA + - left_thumb_IP + - left_index_MCP_FE + - left_index_MCP_AA + - left_index_PIP + - left_index_DIP + - left_middle_MCP_FE + - left_middle_MCP_AA + - left_middle_PIP + - left_middle_DIP + - left_ring_MCP_FE + - left_ring_MCP_AA + - left_ring_PIP + - left_ring_DIP + - left_pinky_CMC + - left_pinky_MCP_FE + - left_pinky_MCP_AA + - left_pinky_PIP + - left_pinky_DIP + - right_thumb_CMC_FE + - right_thumb_CMC_AA + - right_thumb_MCP_FE + - right_thumb_MCP_AA + - right_thumb_IP + - right_index_MCP_FE + - right_index_MCP_AA + - right_index_PIP + - right_index_DIP + - right_middle_MCP_FE + - right_middle_MCP_AA + - right_middle_PIP + - right_middle_DIP + - right_ring_MCP_FE + - right_ring_MCP_AA + - right_ring_PIP + - right_ring_DIP + - right_pinky_CMC + - right_pinky_MCP_FE + - right_pinky_MCP_AA + - right_pinky_PIP + - right_pinky_DIP observe_ee_pose: true observe_ee_state: true observe_joint_position: true @@ -171,7 +262,16 @@ extension: observe_joint_position_targets: true observe_joint_velocity_targets: true - stereotype: sensor_observer - name: Zed_Camera + name: head_camera + observe_rgb: true + - stereotype: sensor_observer + name: front_camera + observe_rgb: true + - stereotype: sensor_observer + name: left_camera + observe_rgb: true + - stereotype: sensor_observer + name: right_camera observe_rgb: true starvla_benchmark: @@ -181,17 +281,17 @@ extension: action_frequency: 15.0 timeout_per_episode: 300 goals: - - name: cola on top of book - description: check if the cola bottle is on the book + - name: can on top of book + description: check if the can is on the book stereotype: on_top object_A_name: omni6DPose_book_031 - object_B_name: omni6DPose_timer_017 + object_B_name: omni6DPose_can_016 policy: stereotype: starvla robot_name: r1pro_dex - arm_name: main_arm - sensor_names: [Zed_Camera] - prompt: pick up the timer and put on the book + arm_name: right_arm + sensor_names: [head_camera] + prompt: pick up the can and put on the book run_trunk_size: 16 visualize_action_ee_pose: true visualize_state_ee_pose: true @@ -203,7 +303,7 @@ extension: data_collector_name: benchmark_data_collect record_fps: 30 backend_root_path: output://benchmark_record - postprocess_list: ["hdf5", "video"] + postprocess_list: ["hdf5", "video", "preview_video"] policy_server: ckpt_path: checkpoints://egodex_part1_restats_gbs1024/checkpoints/steps_70000_pytorch_model.pt @@ -211,5 +311,3 @@ policy_server: host: 0.0.0.0 port: 5000 use_bf16: true - unnorm_key: oxe_bridge - state_mode: ee_pose7 diff --git a/starvla_inference_server.py b/starvla_inference_server.py index 1562298..bb07bd3 100644 --- a/starvla_inference_server.py +++ b/starvla_inference_server.py @@ -21,6 +21,39 @@ def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0. return np.pad(x, pad_width, constant_values=value) return x + +def normalize_states(states, statistics): + stats = statistics["new_embodiment"]["state"] + q01 = np.array(stats["q01"]).astype(states.dtype) + q99 = np.array(stats["q99"]).astype(states.dtype) + + # In the case of q01 == q99, the normalization will be undefined + # So we set the normalized values to the original values + mask = q01 != q99 + normalized = np.zeros_like(states) + + # Normalize the values where q01 != q99 + # Formula: 2 * (x - q01) / (q99 - q01) - 1 + normalized[..., mask] = (states[..., mask] - q01[..., mask]) / ( + q99[..., mask] - q01[..., mask] + ) + normalized[..., mask] = 2 * normalized[..., mask] - 1 + + # Set the normalized values to the original values where q01 == q99 + normalized[..., ~mask] = states[..., ~mask] + + # Clip the normalized values to be between -1 and 1 + normalized = np.clip(normalized, -1, 1) + return normalized + + +def unnormalize_actions(normalized_actions, statistics): + stats = statistics["new_embodiment"]["action"] + q01 = np.array(stats["q01"]).astype(normalized_actions.dtype) + q99 = np.array(stats["q99"]).astype(normalized_actions.dtype) + + return (normalized_actions + 1) / 2 * (q99 - q01) + q01 + class StarvlaInferenceServer: def __init__(self, config_path: str): @@ -38,8 +71,6 @@ class StarvlaInferenceServer: self.host = policy_server_cfg.get("host", "0.0.0.0") self.port = policy_server_cfg.get("port", 5000) self.use_bf16 = policy_server_cfg.get("use_bf16", True) - self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge") - self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7") print("Loading StarVLA model...") self.model = self.load_model() @@ -74,42 +105,36 @@ class StarvlaInferenceServer: model = build_framework(cfg=cfg) model.norm_stats = norm_stats - state_dict = torch.load(self.ckpt_path, map_location="cpu") - model.load_state_dict(state_dict, strict=True) - if self.use_bf16: model = model.to(torch.bfloat16) + model = model.eval() + + state_dict = torch.load(self.ckpt_path, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + model = model.to("cuda") - model = model.to("cuda").eval() self.norm_stats = norm_stats - self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None) - return model def parse_observation(self, obs, target_size=(320, 180)): - left_rgb, right_rgb, wrist_rgb = obs["rgb"]["Left_Camera"], obs["rgb"]["Right_Camera"], obs["rgb"]["Hand_Camera"] + head_rgb = obs["rgb"]["head_camera"] - img_left = Image.fromarray(cv2.resize(left_rgb, target_size)) - img_right = Image.fromarray(cv2.resize(right_rgb, target_size)) - img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size)) - - state_vec = obs["state"] - # import ipdb;ipdb.set_trace() + img_head = Image.fromarray(cv2.resize(head_rgb, target_size)) + state_vec = normalize_states(obs["state"], self.norm_stats) # state_vec = pad_to_dim(np.array(state_vec), 100, axis=-1) - return img_left, img_right, img_wrist, state_vec, obs["prompt"] + return img_head, state_vec, obs["prompt"] def inference(self, observation: dict) -> dict: - img_left, img_right, img_wrist, state_vec, prompt = \ + img_head, state_vec, prompt = \ self.parse_observation(observation) - print(f"{state_vec.shape}") vla_input = { # "batch_images": [[img_left, img_right, img_wrist]], - "image": [img_left], + "image": [img_head], "lang": prompt, - "state": state_vec + "state": state_vec[None, :], # (1, 62) } with torch.no_grad(): @@ -122,9 +147,16 @@ class StarvlaInferenceServer: if actions.ndim == 3: actions = actions[0] # (16, 10) - return {"ee_delta_position_chunks": actions[:, :3].tolist(), - "ee_delta_rot6d_chunks": actions[:, 3:9].tolist(), - "gripper_width_chunks": actions[:, 9:10].tolist()} + actions = unnormalize_actions(actions, self.norm_stats) + return {"left_arm": { + "ee_delta_position_chunks": actions[:, :3].tolist(), + "ee_delta_rot6d_chunks": actions[:, 3:9].tolist(), + "finger_chunks": actions[:, 9:31].tolist()}, + "right_arm": { + "ee_delta_position_chunks": actions[:, 31:34].tolist(), + "ee_delta_rot6d_chunks": actions[:, 34:40].tolist(), + "finger_chunks": actions[:, 40:62].tolist()} + } def register_routes(self): diff --git a/starvla_policy.py b/starvla_policy.py index fc5fce6..10b7a03 100644 --- a/starvla_policy.py +++ b/starvla_policy.py @@ -4,31 +4,28 @@ import json import numpy as np import requests -from joysim.annotations.config_class import configclass, field -from joysim.annotations.stereotype import stereotype -from joysim.controllers.spawnable_controller import SpawnableController -from joysim.controllers.visualize_controller import VisualizeController -from joysim.unisim.robots.models.modular_robot import ModularRobot -from joysim.utils.namespace import PoseVisualType, SimulatorType -from joysim.unisim.robots.actuator_configs.grippers import GripperDriveJointConfig -from joysim.extensions.benchmark.action import RobotAction -from joysim.extensions.benchmark.benchmark import ( +from fastsim.annotations.config_class import configclass, field +from fastsim.annotations.stereotype import stereotype +from fastsim.controllers.spawnable_controller import SpawnableController +from fastsim.controllers.visualize_controller import VisualizeController +from fastsim.unisim.robots.models.modular_robot import ModularRobot +from fastsim.utils.namespace import PoseVisualType, SimulatorType +from fastsim.unisim.robots.actuator_configs.grippers import GripperDriveJointConfig +from fastsim.extensions.benchmark.action import RobotAction +from fastsim.extensions.benchmark.benchmark import ( BenchmarkAction, BenchmarkObservation, ControlMode, ) -from joysim.extensions.benchmark.policy import Policy, PolicyConfig -from joysim.utils.log import Log -from joysim.utils.pose import Pose +from fastsim.extensions.benchmark.policy import Policy, PolicyConfig +from fastsim.utils.log import Log +from fastsim.utils.pose import Pose @configclass @stereotype.register_config("starvla") class StarvlaPolicyConfig(PolicyConfig): robot_name: str = field(default="None", required=True, comment="The name of the robot") - arm_name: str = field(default="main_arm", required=True, comment="The name of the arm module to control") - drive_name: str = field(default="robotiq_85_left_knuckle_joint", required=True, comment="The name of the drive module to control") - # gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper") visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose") visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose") visualize_bounding_box_targets: list[str] = field( @@ -67,39 +64,29 @@ class StarvlaPolicy(Policy): super().__init__(config) self.robot_name = config.robot_name - self.arm_name = config.arm_name - self.drive_name = config.drive_name self.sensor_names = config.sensor_names self.server_url = config.server_url self.prompt = config.prompt - # self.gripper_width_mapper = json.load(open(config.gripper_width_mapper_file, "r")) self.visualize_action_ee_pose = config.visualize_action_ee_pose self.visualize_state_ee_pose = config.visualize_state_ee_pose self.visualize_bounding_box_targets = list(config.visualize_bounding_box_targets or []) def reset(self) -> None: - self.current_ee_position_state = None - self.current_ee_rot6d_state = None - self.current_gripper_width = None + self.current_state = {} self.current_chunk_id = 0 self.current_chunk_result = None self.run_trunk_size = self.config.run_trunk_size self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap() - # self.robot: ModularRobot = SpawnableController.get_spawnable(self.robot_name) - # self.drive_joints: dict[str, GripperDriveJointConfig] = self.robot.get_arm(self.arm_name).get_ee().get_drive_joints() - # self.robot_drive_name = list(self.drive_joints.keys())[0] - self.robot_drive_name = self.drive_name - - # for joint_name, joint_config in self.drive_joints.items(): - # SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap() - # SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap() - # SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 5000}).unwrap() - # SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [self.robot_drive_name], "effort_limit": 5000}).unwrap() - self.max_width = float("-inf") - self.min_width = float("inf") - # for entry in self.gripper_width_mapper: - # self.max_width = max(self.max_width, entry["width"]) - # self.min_width = min(self.min_width, entry["width"]) + self.left_hand_joints = SpawnableController.control_robot( + self.robot_name, + "get_actuator_joint_names", + parameters={"actuator_name": "left_hand"}, + ).unwrap() + self.right_hand_joints = SpawnableController.control_robot( + self.robot_name, + "get_actuator_joint_names", + parameters={"actuator_name": "right_hand"}, + ).unwrap() def warmup(self, benchmark_observation: BenchmarkObservation) -> None: Log.info(f"Waiting for StarVLA inference server to be ready...") @@ -122,28 +109,41 @@ class StarvlaPolicy(Policy): elif response.status_code != 200: Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}", exit=True) + def split_joints(self, state_or_action, keys=None) -> list[dict]: + if keys is None: + keys = ["left_arm", "right_arm"] + total_dim = 31 * len(keys) + assert state_or_action.shape[-1] == total_dim, f"Expected last dimension to be {total_dim}, got {state_or_action.shape[-1]}" + joints_all = np.split(state_or_action, [31], axis=-1) + return_dict = {} + for key, joints in zip(keys, joints_all): + ee_pos, ee_rot6d, finger_qpos = np.split(joints, [3, 9], axis=-1) + return_dict[key] = { + "ee_pos": ee_pos, + "ee_rot6d": ee_rot6d, + "finger_qpos": finger_qpos + } + return return_dict + def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict: robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"] - ee_pose_base = robot_obs["ee_pose"][self.arm_name]["base_frame"] - ee_position, ee_rot6d = ee_pose_base["position"],ee_pose_base["rot6d"] - arm_joint_positions = robot_obs["joint_positions"][:7] # 临时多加了一个drive的位置,现在读的最后一个joint值是drive - drive_joint_positions = robot_obs["joint_positions"][-1] - normalized_gripper_width = self.__map_joint_position_to_normalized_width(drive_joint_positions) - Log.debug(f"input normalized_gripper_width state: {round(normalized_gripper_width, 2)}") - state = np.concatenate([ee_position,ee_rot6d,np.array([normalized_gripper_width]), [0]*10, np.array(arm_joint_positions)]) + left_ee_pose_base = robot_obs["ee_pose"]["left_arm"]["base_frame"] + left_ee_position, left_ee_rot6d = left_ee_pose_base["position"], left_ee_pose_base["rot6d"] + right_ee_pose_base = robot_obs["ee_pose"]["right_arm"]["base_frame"] + right_ee_position, right_ee_rot6d = right_ee_pose_base["position"], right_ee_pose_base["rot6d"] + finger_positions = robot_obs["joint_positions"] # use finger joints(44) only + state = np.concatenate([left_ee_position, left_ee_rot6d, finger_positions[:22], + right_ee_position, right_ee_rot6d, finger_positions[22:]], axis=-1) # (62,) rgb_data = {} for sensor_name in self.sensor_names: sensor_obs = benchmark_observation.get_sensor_observations(sensor_name) rgb_data[sensor_name] = sensor_obs["rgb"].data.cpu().numpy().astype(np.uint8) obs = {"state": state,"rgb": rgb_data,"prompt": self.prompt} - return obs def compute_action(self, observation: dict) -> dict: if self.current_chunk_result is None: - self.current_ee_position_state = np.array(observation["state"][:3]).astype(np.float64) - self.current_ee_rot6d_state = np.array(observation["state"][3:9]).astype(np.float64) - self.current_gripper_width = np.array([observation["state"][9]]) + self.current_state.update(self.split_joints(observation["state"])) payload = pickle.dumps(observation) response = requests.post( f"{self.server_url}/inference", @@ -153,7 +153,7 @@ class StarvlaPolicy(Policy): self.test_obs = observation["state"] #TODO self._handle_server_error(response) result = pickle.loads(response.content) - max_trunk_size = len(result["ee_delta_position_chunks"]) + max_trunk_size = len(result["right_arm"]["ee_delta_position_chunks"]) if self.run_trunk_size > max_trunk_size: Log.warning(f"Run trunk size {self.run_trunk_size} is greater than the number of chunks {max_trunk_size}. Set run trunk size to {max_trunk_size}.") self.run_trunk_size = max_trunk_size @@ -163,59 +163,33 @@ class StarvlaPolicy(Policy): result = self.current_chunk_result return result - def __map_joint_position_to_normalized_width(self, joint_position: float) -> float: - if joint_position < 0: - joint_position = 0 - if joint_position > 0.8: - joint_position = 0.8 - # for entry in self.gripper_width_mapper: - # if round(entry["angel"], 2) == round(joint_position, 2): - # return 1-(entry["width"] - self.min_width) / (self.max_width - self.min_width) - - - def __map_gripper_joint_position(self, normalized_gripper_width: float) -> float: - - joint_positions = [] - joint_names = [] - if normalized_gripper_width > 0.5: - for joint_name, joint_config in self.drive_joints.items(): - joint_positions.append(joint_config.close_position) - joint_names.append(joint_name) - else: - for joint_name, joint_config in self.drive_joints.items(): - joint_positions.append(joint_config.open_position) - joint_names.append(joint_name) - return joint_positions, joint_names def postprocess_action(self, action: dict) -> BenchmarkAction: benchmark_action = BenchmarkAction() - Log.debug(f"observation: {self.test_obs}") - # import ipdb;ipdb.set_trace() - - # get base frame end-effector pose - delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action["ee_delta_rot6d_chunks"][self.current_chunk_id]) - curr_state_ee_pose = Pose(position=self.current_ee_position_state, rot6d=self.current_ee_rot6d_state) - curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose # action2base = state2base * action2state - curr_action_gripper_width = action["gripper_width_chunks"][self.current_chunk_id] - - gripper_joint_positions, gripper_joint_names = self.__map_gripper_joint_position(curr_action_gripper_width[0]) - Log.debug(f"action_gripper_joint_positions: {gripper_joint_positions}, action_normalized_gripper_width: {round(curr_action_gripper_width[0], 2)}") - benchmark_action.add_robot_action( - RobotAction( - control_mode=ControlMode.POSITION, - robot_name=self.robot_name, - joint_names=gripper_joint_names, - joint_positions=gripper_joint_positions + for arm_key in self.robot['arms'].keys(): + action_arm = action[arm_key] + delta_ee_pose = Pose(position=action_arm["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action_arm["ee_delta_rot6d_chunks"][self.current_chunk_id]) + curr_state_ee_pose = Pose(position=self.current_state[arm_key]["ee_pos"], rot6d=self.current_state[arm_key]["ee_rot6d"]) + curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose # action2base = state2base * action2state + finger_joint_qpos = action_arm["finger_chunks"][self.current_chunk_id] + self.current_state[arm_key]["finger_qpos"] + joint_names = self.left_hand_joints if arm_key == "left_arm" else self.right_hand_joints + benchmark_action.add_robot_action( + RobotAction( + control_mode=ControlMode.POSITION, + robot_name=self.robot_name, + joint_names=joint_names, + joint_positions=finger_joint_qpos + ) ) - ) - benchmark_action.add_robot_action( - RobotAction( - control_mode=ControlMode.EE_POSE, - robot_name=self.robot_name, - ee_pose=curr_action_ee_pose + benchmark_action.add_robot_action( + RobotAction( + control_mode=ControlMode.EE_POSE, + robot_name=self.robot_name, + ee_pose=curr_action_ee_pose, + arm_name=arm_key + ) ) - ) self._visualize_base_frame_ee_poses(curr_state_ee_pose, curr_action_ee_pose) self._visualize_bounding_boxes() self.current_chunk_id += 1