update
This commit is contained in:
@@ -7,7 +7,9 @@ import requests
|
||||
from joysim.annotations.config_class import configclass, field
|
||||
from joysim.annotations.stereotype import stereotype
|
||||
from joysim.controllers.spawnable_controller import SpawnableController
|
||||
from joysim.core.robots.configs.actuator_configs.grippers import GripperDriveJointConfig
|
||||
from joysim.controllers.visualize_controller import VisualizeController
|
||||
from joysim.utils.namespace import PoseVisualType, SimulatorType
|
||||
from joysim.unisim.robots.configs.actuator_configs.grippers import GripperDriveJointConfig
|
||||
from joysim.extensions.benchmark.action import RobotAction
|
||||
from joysim.extensions.benchmark.benchmark import (
|
||||
BenchmarkAction,
|
||||
@@ -24,6 +26,13 @@ class StarvlaPolicyConfig(PolicyConfig):
|
||||
|
||||
robot_name: str = field(default="None", required=True, comment="The name of the robot")
|
||||
gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
|
||||
visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose")
|
||||
visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose")
|
||||
visualize_bounding_box_targets: list[str] = field(
|
||||
default_factory=list,
|
||||
required=False,
|
||||
comment="Spawnable object names to draw bounding boxes for (empty = off)",
|
||||
)
|
||||
sensor_names: list[str] = field(
|
||||
default=["Hand_Camera", "Left_Camera", "Right_Camera"],
|
||||
required=True,
|
||||
@@ -59,6 +68,10 @@ class StarvlaPolicy(Policy):
|
||||
self.server_url = config.server_url
|
||||
self.prompt = config.prompt
|
||||
self.gripper_width_mapper = json.load(open(config.gripper_width_mapper_file, "r"))
|
||||
self.visualize_action_ee_pose = config.visualize_action_ee_pose
|
||||
self.visualize_state_ee_pose = config.visualize_state_ee_pose
|
||||
self.visualize_bounding_box_targets = list(config.visualize_bounding_box_targets or [])
|
||||
|
||||
def reset(self) -> None:
|
||||
self.current_ee_position_state = None
|
||||
self.current_ee_rot6d_state = None
|
||||
@@ -70,7 +83,7 @@ class StarvlaPolicy(Policy):
|
||||
for joint_name, joint_config in self.drive_joints.items():
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": joint_config.position_control_effort_limit}).unwrap()
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 50}).unwrap()
|
||||
self.max_width = float("-inf")
|
||||
self.min_width = float("inf")
|
||||
for entry in self.gripper_width_mapper:
|
||||
@@ -103,6 +116,7 @@ class StarvlaPolicy(Policy):
|
||||
ee_pose_base = robot_obs["ee_pose"]["base_frame"]
|
||||
ee_position, ee_rot6d = ee_pose_base["position"],ee_pose_base["rot6d"]
|
||||
normalized_gripper_width = self.__map_joint_position_to_normalized_width(robot_obs["gripper_drive_state"]["position"][0])
|
||||
Log.debug(f"input normalized_gripper_width state: {round(normalized_gripper_width, 2)}")
|
||||
state = np.concatenate([ee_position,ee_rot6d,np.array([normalized_gripper_width])])
|
||||
rgb_data = {}
|
||||
for sensor_name in self.sensor_names:
|
||||
@@ -136,11 +150,15 @@ class StarvlaPolicy(Policy):
|
||||
|
||||
return result
|
||||
def __map_joint_position_to_normalized_width(self, joint_position: float) -> float:
|
||||
if joint_position < 0:
|
||||
joint_position = 0
|
||||
if joint_position > 0.8:
|
||||
joint_position = 0.8
|
||||
for entry in self.gripper_width_mapper:
|
||||
if round(entry["angel"], 2) == round(joint_position, 2):
|
||||
return 1-(entry["width"] - self.min_width) / (self.max_width - self.min_width)
|
||||
|
||||
raise ValueError(f"Joint position {joint_position} not found in gripper width mapper")
|
||||
|
||||
|
||||
def __map_gripper_joint_position(self, normalized_gripper_width: float) -> float:
|
||||
|
||||
@@ -182,8 +200,44 @@ class StarvlaPolicy(Policy):
|
||||
ee_pose=curr_action_ee_pose
|
||||
)
|
||||
)
|
||||
self._visualize_base_frame_ee_poses(curr_state_ee_pose, curr_action_ee_pose)
|
||||
self._visualize_bounding_boxes()
|
||||
self.current_chunk_id += 1
|
||||
if self.current_chunk_id == self.run_trunk_size:
|
||||
self.current_chunk_id = 0
|
||||
self.current_chunk_result = None
|
||||
return benchmark_action
|
||||
return benchmark_action
|
||||
|
||||
# ------------------- Visualization -------------------
|
||||
def _visualize_base_frame_ee_poses(
|
||||
self, pose_state_base: Pose, pose_action_base: Pose
|
||||
) -> None:
|
||||
if not self.visualize_action_ee_pose and not self.visualize_state_ee_pose:
|
||||
return
|
||||
robot_base_world = SpawnableController.control_robot(
|
||||
self.robot_name, "get_pose"
|
||||
).unwrap()
|
||||
if self.visualize_state_ee_pose:
|
||||
VisualizeController.create_pose_visualization(
|
||||
robot_base_world * pose_state_base,
|
||||
name=f"{self.robot_name}/starvla_state_ee",
|
||||
simulator=SimulatorType.ISAACSIM,
|
||||
pose_type=PoseVisualType.COORDINATE,
|
||||
extra_params={"axis_length": 0.08, "thickness": 0.006},
|
||||
).unwrap()
|
||||
if self.visualize_action_ee_pose:
|
||||
VisualizeController.create_pose_visualization(
|
||||
robot_base_world * pose_action_base,
|
||||
name=f"{self.robot_name}/starvla_action_ee",
|
||||
simulator=SimulatorType.ISAACSIM,
|
||||
pose_type=PoseVisualType.COORDINATE,
|
||||
extra_params={"axis_length": 0.1, "thickness": 0.006},
|
||||
).unwrap()
|
||||
|
||||
def _visualize_bounding_boxes(self) -> None:
|
||||
if not self.visualize_bounding_box_targets:
|
||||
return
|
||||
for target_name in self.visualize_bounding_box_targets:
|
||||
VisualizeController.visualize_target_bounding_box(
|
||||
target_name, simulator=SimulatorType.ISAACSIM
|
||||
).unwrap()
|
||||
Reference in New Issue
Block a user