This commit is contained in:
Junhan
2026-05-07 19:31:45 +08:00
parent 7ce2823c56
commit 2514cc943d
6 changed files with 5976 additions and 94 deletions

View File

@@ -8,8 +8,9 @@ from joysim.annotations.config_class import configclass, field
from joysim.annotations.stereotype import stereotype
from joysim.controllers.spawnable_controller import SpawnableController
from joysim.controllers.visualize_controller import VisualizeController
from joysim.unisim.robots.models.modular_robot import ModularRobot
from joysim.utils.namespace import PoseVisualType, SimulatorType
from joysim.unisim.robots.configs.actuator_configs.grippers import GripperDriveJointConfig
from joysim.unisim.robots.actuator_configs.grippers import GripperDriveJointConfig
from joysim.extensions.benchmark.action import RobotAction
from joysim.extensions.benchmark.benchmark import (
BenchmarkAction,
@@ -25,6 +26,8 @@ from joysim.utils.pose import Pose
class StarvlaPolicyConfig(PolicyConfig):
robot_name: str = field(default="None", required=True, comment="The name of the robot")
arm_name: str = field(default="main_arm", required=True, comment="The name of the arm module to control")
drive_name: str = field(default="robotiq_85_left_knuckle_joint", required=True, comment="The name of the drive module to control")
gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose")
visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose")
@@ -64,6 +67,8 @@ class StarvlaPolicy(Policy):
super().__init__(config)
self.robot_name = config.robot_name
self.arm_name = config.arm_name
self.drive_name = config.drive_name
self.sensor_names = config.sensor_names
self.server_url = config.server_url
self.prompt = config.prompt
@@ -79,11 +84,15 @@ class StarvlaPolicy(Policy):
self.current_chunk_id = 0
self.current_chunk_result = None
self.run_trunk_size = self.config.run_trunk_size
self.drive_joints: dict[str, GripperDriveJointConfig] = SpawnableController.control_robot(self.robot_name, "get_gripper_drive_joints").unwrap()
self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap()
self.drive_joints: dict[str, GripperDriveJointConfig] = self.robot.get_arm(self.arm_name).get_ee().get_drive_joints()
self.robot_drive_name = list(self.drive_joints.keys())[0]
for joint_name, joint_config in self.drive_joints.items():
SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 50}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 5000}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [self.robot_drive_name], "effort_limit": 5000}).unwrap()
self.max_width = float("-inf")
self.min_width = float("inf")
for entry in self.gripper_width_mapper:
@@ -113,11 +122,13 @@ class StarvlaPolicy(Policy):
def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
ee_pose_base = robot_obs["ee_pose"]["base_frame"]
ee_pose_base = robot_obs["ee_pose"][self.arm_name]["base_frame"]
ee_position, ee_rot6d = ee_pose_base["position"],ee_pose_base["rot6d"]
normalized_gripper_width = self.__map_joint_position_to_normalized_width(robot_obs["gripper_drive_state"]["position"][0])
arm_joint_positions = robot_obs["joint_positions"][:7] # 临时多加了一个drive的位置现在读的最后一个joint值是drive
drive_joint_positions = robot_obs["joint_positions"][-1]
normalized_gripper_width = self.__map_joint_position_to_normalized_width(drive_joint_positions)
Log.debug(f"input normalized_gripper_width state: {round(normalized_gripper_width, 2)}")
state = np.concatenate([ee_position,ee_rot6d,np.array([normalized_gripper_width])])
state = np.concatenate([ee_position,ee_rot6d,np.array([normalized_gripper_width]), [0]*10, np.array(arm_joint_positions)])
rgb_data = {}
for sensor_name in self.sensor_names:
sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
@@ -137,6 +148,7 @@ class StarvlaPolicy(Policy):
data=payload,
headers={"Content-Type": "application/octet-stream"}
)
self.test_obs = observation["state"] #TODO
self._handle_server_error(response)
result = pickle.loads(response.content)
max_trunk_size = len(result["ee_delta_position_chunks"])
@@ -176,7 +188,9 @@ class StarvlaPolicy(Policy):
def postprocess_action(self, action: dict) -> BenchmarkAction:
benchmark_action = BenchmarkAction()
Log.debug(f"observation: {self.test_obs}")
# import ipdb;ipdb.set_trace()
# get base frame end-effector pose
delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action["ee_delta_rot6d_chunks"][self.current_chunk_id])
curr_state_ee_pose = Pose(position=self.current_ee_position_state, rot6d=self.current_ee_rot6d_state)
@@ -221,7 +235,7 @@ class StarvlaPolicy(Policy):
VisualizeController.create_pose_visualization(
robot_base_world * pose_state_base,
name=f"{self.robot_name}/starvla_state_ee",
simulator=SimulatorType.ISAACSIM,
simulator=SimulatorType.ISAACLAB,
pose_type=PoseVisualType.COORDINATE,
extra_params={"axis_length": 0.08, "thickness": 0.006},
).unwrap()
@@ -229,7 +243,7 @@ class StarvlaPolicy(Policy):
VisualizeController.create_pose_visualization(
robot_base_world * pose_action_base,
name=f"{self.robot_name}/starvla_action_ee",
simulator=SimulatorType.ISAACSIM,
simulator=SimulatorType.ISAACLAB,
pose_type=PoseVisualType.COORDINATE,
extra_params={"axis_length": 0.1, "thickness": 0.006},
).unwrap()
@@ -239,5 +253,5 @@ class StarvlaPolicy(Policy):
return
for target_name in self.visualize_bounding_box_targets:
VisualizeController.visualize_target_bounding_box(
target_name, simulator=SimulatorType.ISAACSIM
target_name, simulator=SimulatorType.ISAACLAB
).unwrap()