adapt for starvla-dex
This commit is contained in:
@@ -28,7 +28,7 @@ class StarvlaPolicyConfig(PolicyConfig):
|
||||
robot_name: str = field(default="None", required=True, comment="The name of the robot")
|
||||
arm_name: str = field(default="main_arm", required=True, comment="The name of the arm module to control")
|
||||
drive_name: str = field(default="robotiq_85_left_knuckle_joint", required=True, comment="The name of the drive module to control")
|
||||
gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
|
||||
# gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
|
||||
visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose")
|
||||
visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose")
|
||||
visualize_bounding_box_targets: list[str] = field(
|
||||
@@ -72,7 +72,7 @@ class StarvlaPolicy(Policy):
|
||||
self.sensor_names = config.sensor_names
|
||||
self.server_url = config.server_url
|
||||
self.prompt = config.prompt
|
||||
self.gripper_width_mapper = json.load(open(config.gripper_width_mapper_file, "r"))
|
||||
# self.gripper_width_mapper = json.load(open(config.gripper_width_mapper_file, "r"))
|
||||
self.visualize_action_ee_pose = config.visualize_action_ee_pose
|
||||
self.visualize_state_ee_pose = config.visualize_state_ee_pose
|
||||
self.visualize_bounding_box_targets = list(config.visualize_bounding_box_targets or [])
|
||||
@@ -85,19 +85,21 @@ class StarvlaPolicy(Policy):
|
||||
self.current_chunk_result = None
|
||||
self.run_trunk_size = self.config.run_trunk_size
|
||||
self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap()
|
||||
self.drive_joints: dict[str, GripperDriveJointConfig] = self.robot.get_arm(self.arm_name).get_ee().get_drive_joints()
|
||||
self.robot_drive_name = list(self.drive_joints.keys())[0]
|
||||
# self.robot: ModularRobot = SpawnableController.get_spawnable(self.robot_name)
|
||||
# self.drive_joints: dict[str, GripperDriveJointConfig] = self.robot.get_arm(self.arm_name).get_ee().get_drive_joints()
|
||||
# self.robot_drive_name = list(self.drive_joints.keys())[0]
|
||||
self.robot_drive_name = self.drive_name
|
||||
|
||||
for joint_name, joint_config in self.drive_joints.items():
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 5000}).unwrap()
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [self.robot_drive_name], "effort_limit": 5000}).unwrap()
|
||||
# for joint_name, joint_config in self.drive_joints.items():
|
||||
# SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
|
||||
# SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
|
||||
# SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 5000}).unwrap()
|
||||
# SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [self.robot_drive_name], "effort_limit": 5000}).unwrap()
|
||||
self.max_width = float("-inf")
|
||||
self.min_width = float("inf")
|
||||
for entry in self.gripper_width_mapper:
|
||||
self.max_width = max(self.max_width, entry["width"])
|
||||
self.min_width = min(self.min_width, entry["width"])
|
||||
# for entry in self.gripper_width_mapper:
|
||||
# self.max_width = max(self.max_width, entry["width"])
|
||||
# self.min_width = min(self.min_width, entry["width"])
|
||||
|
||||
def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
|
||||
Log.info(f"Waiting for StarVLA inference server to be ready...")
|
||||
@@ -166,9 +168,9 @@ class StarvlaPolicy(Policy):
|
||||
joint_position = 0
|
||||
if joint_position > 0.8:
|
||||
joint_position = 0.8
|
||||
for entry in self.gripper_width_mapper:
|
||||
if round(entry["angel"], 2) == round(joint_position, 2):
|
||||
return 1-(entry["width"] - self.min_width) / (self.max_width - self.min_width)
|
||||
# for entry in self.gripper_width_mapper:
|
||||
# if round(entry["angel"], 2) == round(joint_position, 2):
|
||||
# return 1-(entry["width"] - self.min_width) / (self.max_width - self.min_width)
|
||||
|
||||
|
||||
|
||||
@@ -254,4 +256,4 @@ class StarvlaPolicy(Policy):
|
||||
for target_name in self.visualize_bounding_box_targets:
|
||||
VisualizeController.visualize_target_bounding_box(
|
||||
target_name, simulator=SimulatorType.ISAACLAB
|
||||
).unwrap()
|
||||
).unwrap()
|
||||
|
||||
Reference in New Issue
Block a user