This commit is contained in:
hufei.hofee
2026-03-24 19:38:15 +08:00
parent 833ade73fd
commit 7ce2823c56
2 changed files with 81 additions and 7 deletions

View File

@@ -7,7 +7,9 @@ import requests
from joysim.annotations.config_class import configclass, field
from joysim.annotations.stereotype import stereotype
from joysim.controllers.spawnable_controller import SpawnableController
from joysim.core.robots.configs.actuator_configs.grippers import GripperDriveJointConfig
from joysim.controllers.visualize_controller import VisualizeController
from joysim.utils.namespace import PoseVisualType, SimulatorType
from joysim.unisim.robots.configs.actuator_configs.grippers import GripperDriveJointConfig
from joysim.extensions.benchmark.action import RobotAction
from joysim.extensions.benchmark.benchmark import (
BenchmarkAction,
@@ -24,6 +26,13 @@ class StarvlaPolicyConfig(PolicyConfig):
robot_name: str = field(default="None", required=True, comment="The name of the robot")
gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose")
visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose")
visualize_bounding_box_targets: list[str] = field(
default_factory=list,
required=False,
comment="Spawnable object names to draw bounding boxes for (empty = off)",
)
sensor_names: list[str] = field(
default=["Hand_Camera", "Left_Camera", "Right_Camera"],
required=True,
@@ -59,6 +68,10 @@ class StarvlaPolicy(Policy):
self.server_url = config.server_url
self.prompt = config.prompt
self.gripper_width_mapper = json.load(open(config.gripper_width_mapper_file, "r"))
self.visualize_action_ee_pose = config.visualize_action_ee_pose
self.visualize_state_ee_pose = config.visualize_state_ee_pose
self.visualize_bounding_box_targets = list(config.visualize_bounding_box_targets or [])
def reset(self) -> None:
self.current_ee_position_state = None
self.current_ee_rot6d_state = None
@@ -70,7 +83,7 @@ class StarvlaPolicy(Policy):
for joint_name, joint_config in self.drive_joints.items():
SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": joint_config.position_control_effort_limit}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 50}).unwrap()
self.max_width = float("-inf")
self.min_width = float("inf")
for entry in self.gripper_width_mapper:
@@ -103,6 +116,7 @@ class StarvlaPolicy(Policy):
ee_pose_base = robot_obs["ee_pose"]["base_frame"]
ee_position, ee_rot6d = ee_pose_base["position"],ee_pose_base["rot6d"]
normalized_gripper_width = self.__map_joint_position_to_normalized_width(robot_obs["gripper_drive_state"]["position"][0])
Log.debug(f"input normalized_gripper_width state: {round(normalized_gripper_width, 2)}")
state = np.concatenate([ee_position,ee_rot6d,np.array([normalized_gripper_width])])
rgb_data = {}
for sensor_name in self.sensor_names:
@@ -136,11 +150,15 @@ class StarvlaPolicy(Policy):
return result
def __map_joint_position_to_normalized_width(self, joint_position: float) -> float:
if joint_position < 0:
joint_position = 0
if joint_position > 0.8:
joint_position = 0.8
for entry in self.gripper_width_mapper:
if round(entry["angel"], 2) == round(joint_position, 2):
return 1-(entry["width"] - self.min_width) / (self.max_width - self.min_width)
raise ValueError(f"Joint position {joint_position} not found in gripper width mapper")
def __map_gripper_joint_position(self, normalized_gripper_width: float) -> float:
@@ -182,8 +200,44 @@ class StarvlaPolicy(Policy):
ee_pose=curr_action_ee_pose
)
)
self._visualize_base_frame_ee_poses(curr_state_ee_pose, curr_action_ee_pose)
self._visualize_bounding_boxes()
self.current_chunk_id += 1
if self.current_chunk_id == self.run_trunk_size:
self.current_chunk_id = 0
self.current_chunk_result = None
return benchmark_action
return benchmark_action
# ------------------- Visualization -------------------
def _visualize_base_frame_ee_poses(
self, pose_state_base: Pose, pose_action_base: Pose
) -> None:
if not self.visualize_action_ee_pose and not self.visualize_state_ee_pose:
return
robot_base_world = SpawnableController.control_robot(
self.robot_name, "get_pose"
).unwrap()
if self.visualize_state_ee_pose:
VisualizeController.create_pose_visualization(
robot_base_world * pose_state_base,
name=f"{self.robot_name}/starvla_state_ee",
simulator=SimulatorType.ISAACSIM,
pose_type=PoseVisualType.COORDINATE,
extra_params={"axis_length": 0.08, "thickness": 0.006},
).unwrap()
if self.visualize_action_ee_pose:
VisualizeController.create_pose_visualization(
robot_base_world * pose_action_base,
name=f"{self.robot_name}/starvla_action_ee",
simulator=SimulatorType.ISAACSIM,
pose_type=PoseVisualType.COORDINATE,
extra_params={"axis_length": 0.1, "thickness": 0.006},
).unwrap()
def _visualize_bounding_boxes(self) -> None:
if not self.visualize_bounding_box_targets:
return
for target_name in self.visualize_bounding_box_targets:
VisualizeController.visualize_target_bounding_box(
target_name, simulator=SimulatorType.ISAACSIM
).unwrap()