update
This commit is contained in:
@@ -8,8 +8,9 @@ from joysim.annotations.config_class import configclass, field
|
||||
from joysim.annotations.stereotype import stereotype
|
||||
from joysim.controllers.spawnable_controller import SpawnableController
|
||||
from joysim.controllers.visualize_controller import VisualizeController
|
||||
from joysim.unisim.robots.models.modular_robot import ModularRobot
|
||||
from joysim.utils.namespace import PoseVisualType, SimulatorType
|
||||
from joysim.unisim.robots.configs.actuator_configs.grippers import GripperDriveJointConfig
|
||||
from joysim.unisim.robots.actuator_configs.grippers import GripperDriveJointConfig
|
||||
from joysim.extensions.benchmark.action import RobotAction
|
||||
from joysim.extensions.benchmark.benchmark import (
|
||||
BenchmarkAction,
|
||||
@@ -25,6 +26,8 @@ from joysim.utils.pose import Pose
|
||||
class StarvlaPolicyConfig(PolicyConfig):
|
||||
|
||||
robot_name: str = field(default="None", required=True, comment="The name of the robot")
|
||||
arm_name: str = field(default="main_arm", required=True, comment="The name of the arm module to control")
|
||||
drive_name: str = field(default="robotiq_85_left_knuckle_joint", required=True, comment="The name of the drive module to control")
|
||||
gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
|
||||
visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose")
|
||||
visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose")
|
||||
@@ -64,6 +67,8 @@ class StarvlaPolicy(Policy):
|
||||
super().__init__(config)
|
||||
|
||||
self.robot_name = config.robot_name
|
||||
self.arm_name = config.arm_name
|
||||
self.drive_name = config.drive_name
|
||||
self.sensor_names = config.sensor_names
|
||||
self.server_url = config.server_url
|
||||
self.prompt = config.prompt
|
||||
@@ -79,11 +84,15 @@ class StarvlaPolicy(Policy):
|
||||
self.current_chunk_id = 0
|
||||
self.current_chunk_result = None
|
||||
self.run_trunk_size = self.config.run_trunk_size
|
||||
self.drive_joints: dict[str, GripperDriveJointConfig] = SpawnableController.control_robot(self.robot_name, "get_gripper_drive_joints").unwrap()
|
||||
self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap()
|
||||
self.drive_joints: dict[str, GripperDriveJointConfig] = self.robot.get_arm(self.arm_name).get_ee().get_drive_joints()
|
||||
self.robot_drive_name = list(self.drive_joints.keys())[0]
|
||||
|
||||
for joint_name, joint_config in self.drive_joints.items():
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 50}).unwrap()
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 5000}).unwrap()
|
||||
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [self.robot_drive_name], "effort_limit": 5000}).unwrap()
|
||||
self.max_width = float("-inf")
|
||||
self.min_width = float("inf")
|
||||
for entry in self.gripper_width_mapper:
|
||||
@@ -113,11 +122,13 @@ class StarvlaPolicy(Policy):
|
||||
|
||||
def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
|
||||
robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
|
||||
ee_pose_base = robot_obs["ee_pose"]["base_frame"]
|
||||
ee_pose_base = robot_obs["ee_pose"][self.arm_name]["base_frame"]
|
||||
ee_position, ee_rot6d = ee_pose_base["position"],ee_pose_base["rot6d"]
|
||||
normalized_gripper_width = self.__map_joint_position_to_normalized_width(robot_obs["gripper_drive_state"]["position"][0])
|
||||
arm_joint_positions = robot_obs["joint_positions"][:7] # 临时多加了一个drive的位置,现在读的最后一个joint值是drive
|
||||
drive_joint_positions = robot_obs["joint_positions"][-1]
|
||||
normalized_gripper_width = self.__map_joint_position_to_normalized_width(drive_joint_positions)
|
||||
Log.debug(f"input normalized_gripper_width state: {round(normalized_gripper_width, 2)}")
|
||||
state = np.concatenate([ee_position,ee_rot6d,np.array([normalized_gripper_width])])
|
||||
state = np.concatenate([ee_position,ee_rot6d,np.array([normalized_gripper_width]), [0]*10, np.array(arm_joint_positions)])
|
||||
rgb_data = {}
|
||||
for sensor_name in self.sensor_names:
|
||||
sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
|
||||
@@ -137,6 +148,7 @@ class StarvlaPolicy(Policy):
|
||||
data=payload,
|
||||
headers={"Content-Type": "application/octet-stream"}
|
||||
)
|
||||
self.test_obs = observation["state"] #TODO
|
||||
self._handle_server_error(response)
|
||||
result = pickle.loads(response.content)
|
||||
max_trunk_size = len(result["ee_delta_position_chunks"])
|
||||
@@ -176,7 +188,9 @@ class StarvlaPolicy(Policy):
|
||||
|
||||
def postprocess_action(self, action: dict) -> BenchmarkAction:
|
||||
benchmark_action = BenchmarkAction()
|
||||
|
||||
Log.debug(f"observation: {self.test_obs}")
|
||||
# import ipdb;ipdb.set_trace()
|
||||
|
||||
# get base frame end-effector pose
|
||||
delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action["ee_delta_rot6d_chunks"][self.current_chunk_id])
|
||||
curr_state_ee_pose = Pose(position=self.current_ee_position_state, rot6d=self.current_ee_rot6d_state)
|
||||
@@ -221,7 +235,7 @@ class StarvlaPolicy(Policy):
|
||||
VisualizeController.create_pose_visualization(
|
||||
robot_base_world * pose_state_base,
|
||||
name=f"{self.robot_name}/starvla_state_ee",
|
||||
simulator=SimulatorType.ISAACSIM,
|
||||
simulator=SimulatorType.ISAACLAB,
|
||||
pose_type=PoseVisualType.COORDINATE,
|
||||
extra_params={"axis_length": 0.08, "thickness": 0.006},
|
||||
).unwrap()
|
||||
@@ -229,7 +243,7 @@ class StarvlaPolicy(Policy):
|
||||
VisualizeController.create_pose_visualization(
|
||||
robot_base_world * pose_action_base,
|
||||
name=f"{self.robot_name}/starvla_action_ee",
|
||||
simulator=SimulatorType.ISAACSIM,
|
||||
simulator=SimulatorType.ISAACLAB,
|
||||
pose_type=PoseVisualType.COORDINATE,
|
||||
extra_params={"axis_length": 0.1, "thickness": 0.006},
|
||||
).unwrap()
|
||||
@@ -239,5 +253,5 @@ class StarvlaPolicy(Policy):
|
||||
return
|
||||
for target_name in self.visualize_bounding_box_targets:
|
||||
VisualizeController.visualize_target_bounding_box(
|
||||
target_name, simulator=SimulatorType.ISAACSIM
|
||||
target_name, simulator=SimulatorType.ISAACLAB
|
||||
).unwrap()
|
||||
Reference in New Issue
Block a user