adapt for starvla-dex

This commit is contained in:
QiyangYan
2026-05-08 18:24:31 +08:00
parent 2514cc943d
commit 53796c1e63
3 changed files with 89 additions and 150 deletions

View File

@@ -28,7 +28,7 @@ class StarvlaPolicyConfig(PolicyConfig):
robot_name: str = field(default="None", required=True, comment="The name of the robot")
arm_name: str = field(default="main_arm", required=True, comment="The name of the arm module to control")
drive_name: str = field(default="robotiq_85_left_knuckle_joint", required=True, comment="The name of the drive module to control")
gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
# gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose")
visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose")
visualize_bounding_box_targets: list[str] = field(
@@ -72,7 +72,7 @@ class StarvlaPolicy(Policy):
self.sensor_names = config.sensor_names
self.server_url = config.server_url
self.prompt = config.prompt
self.gripper_width_mapper = json.load(open(config.gripper_width_mapper_file, "r"))
# self.gripper_width_mapper = json.load(open(config.gripper_width_mapper_file, "r"))
self.visualize_action_ee_pose = config.visualize_action_ee_pose
self.visualize_state_ee_pose = config.visualize_state_ee_pose
self.visualize_bounding_box_targets = list(config.visualize_bounding_box_targets or [])
@@ -85,19 +85,21 @@ class StarvlaPolicy(Policy):
self.current_chunk_result = None
self.run_trunk_size = self.config.run_trunk_size
self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap()
self.drive_joints: dict[str, GripperDriveJointConfig] = self.robot.get_arm(self.arm_name).get_ee().get_drive_joints()
self.robot_drive_name = list(self.drive_joints.keys())[0]
# self.robot: ModularRobot = SpawnableController.get_spawnable(self.robot_name)
# self.drive_joints: dict[str, GripperDriveJointConfig] = self.robot.get_arm(self.arm_name).get_ee().get_drive_joints()
# self.robot_drive_name = list(self.drive_joints.keys())[0]
self.robot_drive_name = self.drive_name
for joint_name, joint_config in self.drive_joints.items():
SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 5000}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [self.robot_drive_name], "effort_limit": 5000}).unwrap()
# for joint_name, joint_config in self.drive_joints.items():
# SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
# SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
# SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 5000}).unwrap()
# SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [self.robot_drive_name], "effort_limit": 5000}).unwrap()
self.max_width = float("-inf")
self.min_width = float("inf")
for entry in self.gripper_width_mapper:
self.max_width = max(self.max_width, entry["width"])
self.min_width = min(self.min_width, entry["width"])
# for entry in self.gripper_width_mapper:
# self.max_width = max(self.max_width, entry["width"])
# self.min_width = min(self.min_width, entry["width"])
def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
Log.info(f"Waiting for StarVLA inference server to be ready...")
@@ -166,9 +168,9 @@ class StarvlaPolicy(Policy):
joint_position = 0
if joint_position > 0.8:
joint_position = 0.8
for entry in self.gripper_width_mapper:
if round(entry["angel"], 2) == round(joint_position, 2):
return 1-(entry["width"] - self.min_width) / (self.max_width - self.min_width)
# for entry in self.gripper_width_mapper:
# if round(entry["angel"], 2) == round(joint_position, 2):
# return 1-(entry["width"] - self.min_width) / (self.max_width - self.min_width)
@@ -254,4 +256,4 @@ class StarvlaPolicy(Policy):
for target_name in self.visualize_bounding_box_targets:
VisualizeController.visualize_target_bounding_box(
target_name, simulator=SimulatorType.ISAACLAB
).unwrap()
).unwrap()