adapt for dexterous hands
This commit is contained in:
@@ -4,31 +4,28 @@ import json
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from joysim.annotations.config_class import configclass, field
|
||||
from joysim.annotations.stereotype import stereotype
|
||||
from joysim.controllers.spawnable_controller import SpawnableController
|
||||
from joysim.controllers.visualize_controller import VisualizeController
|
||||
from joysim.unisim.robots.models.modular_robot import ModularRobot
|
||||
from joysim.utils.namespace import PoseVisualType, SimulatorType
|
||||
from joysim.unisim.robots.actuator_configs.grippers import GripperDriveJointConfig
|
||||
from joysim.extensions.benchmark.action import RobotAction
|
||||
from joysim.extensions.benchmark.benchmark import (
|
||||
from fastsim.annotations.config_class import configclass, field
|
||||
from fastsim.annotations.stereotype import stereotype
|
||||
from fastsim.controllers.spawnable_controller import SpawnableController
|
||||
from fastsim.controllers.visualize_controller import VisualizeController
|
||||
from fastsim.unisim.robots.models.modular_robot import ModularRobot
|
||||
from fastsim.utils.namespace import PoseVisualType, SimulatorType
|
||||
from fastsim.unisim.robots.actuator_configs.grippers import GripperDriveJointConfig
|
||||
from fastsim.extensions.benchmark.action import RobotAction
|
||||
from fastsim.extensions.benchmark.benchmark import (
|
||||
BenchmarkAction,
|
||||
BenchmarkObservation,
|
||||
ControlMode,
|
||||
)
|
||||
from joysim.extensions.benchmark.policy import Policy, PolicyConfig
|
||||
from joysim.utils.log import Log
|
||||
from joysim.utils.pose import Pose
|
||||
from fastsim.extensions.benchmark.policy import Policy, PolicyConfig
|
||||
from fastsim.utils.log import Log
|
||||
from fastsim.utils.pose import Pose
|
||||
|
||||
@configclass
|
||||
@stereotype.register_config("starvla")
|
||||
class StarvlaPolicyConfig(PolicyConfig):
|
||||
|
||||
robot_name: str = field(default="None", required=True, comment="The name of the robot")
|
||||
arm_name: str = field(default="main_arm", required=True, comment="The name of the arm module to control")
|
||||
drive_name: str = field(default="robotiq_85_left_knuckle_joint", required=True, comment="The name of the drive module to control")
|
||||
# gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
|
||||
visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose")
|
||||
visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose")
|
||||
visualize_bounding_box_targets: list[str] = field(
|
||||
@@ -67,39 +64,29 @@ class StarvlaPolicy(Policy):
|
||||
super().__init__(config)
|
||||
|
||||
self.robot_name = config.robot_name
|
||||
self.arm_name = config.arm_name
|
||||
self.drive_name = config.drive_name
|
||||
self.sensor_names = config.sensor_names
|
||||
self.server_url = config.server_url
|
||||
self.prompt = config.prompt
|
||||
# self.gripper_width_mapper = json.load(open(config.gripper_width_mapper_file, "r"))
|
||||
self.visualize_action_ee_pose = config.visualize_action_ee_pose
|
||||
self.visualize_state_ee_pose = config.visualize_state_ee_pose
|
||||
self.visualize_bounding_box_targets = list(config.visualize_bounding_box_targets or [])
|
||||
|
||||
def reset(self) -> None:
|
||||
self.current_ee_position_state = None
|
||||
self.current_ee_rot6d_state = None
|
||||
self.current_gripper_width = None
|
||||
self.current_state = {}
|
||||
self.current_chunk_id = 0
|
||||
self.current_chunk_result = None
|
||||
self.run_trunk_size = self.config.run_trunk_size
|
||||
self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap()
|
||||
# self.robot: ModularRobot = SpawnableController.get_spawnable(self.robot_name)
|
||||
# self.drive_joints: dict[str, GripperDriveJointConfig] = self.robot.get_arm(self.arm_name).get_ee().get_drive_joints()
|
||||
# self.robot_drive_name = list(self.drive_joints.keys())[0]
|
||||
self.robot_drive_name = self.drive_name
|
||||
|
||||
# for joint_name, joint_config in self.drive_joints.items():
|
||||
# SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
|
||||
# SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
|
||||
# SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": 5000}).unwrap()
|
||||
# SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [self.robot_drive_name], "effort_limit": 5000}).unwrap()
|
||||
self.max_width = float("-inf")
|
||||
self.min_width = float("inf")
|
||||
# for entry in self.gripper_width_mapper:
|
||||
# self.max_width = max(self.max_width, entry["width"])
|
||||
# self.min_width = min(self.min_width, entry["width"])
|
||||
self.left_hand_joints = SpawnableController.control_robot(
|
||||
self.robot_name,
|
||||
"get_actuator_joint_names",
|
||||
parameters={"actuator_name": "left_hand"},
|
||||
).unwrap()
|
||||
self.right_hand_joints = SpawnableController.control_robot(
|
||||
self.robot_name,
|
||||
"get_actuator_joint_names",
|
||||
parameters={"actuator_name": "right_hand"},
|
||||
).unwrap()
|
||||
|
||||
def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
|
||||
Log.info(f"Waiting for StarVLA inference server to be ready...")
|
||||
@@ -122,28 +109,41 @@ class StarvlaPolicy(Policy):
|
||||
elif response.status_code != 200:
|
||||
Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}", exit=True)
|
||||
|
||||
def split_joints(self, state_or_action, keys=None) -> list[dict]:
|
||||
if keys is None:
|
||||
keys = ["left_arm", "right_arm"]
|
||||
total_dim = 31 * len(keys)
|
||||
assert state_or_action.shape[-1] == total_dim, f"Expected last dimension to be {total_dim}, got {state_or_action.shape[-1]}"
|
||||
joints_all = np.split(state_or_action, [31], axis=-1)
|
||||
return_dict = {}
|
||||
for key, joints in zip(keys, joints_all):
|
||||
ee_pos, ee_rot6d, finger_qpos = np.split(joints, [3, 9], axis=-1)
|
||||
return_dict[key] = {
|
||||
"ee_pos": ee_pos,
|
||||
"ee_rot6d": ee_rot6d,
|
||||
"finger_qpos": finger_qpos
|
||||
}
|
||||
return return_dict
|
||||
|
||||
def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
|
||||
robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
|
||||
ee_pose_base = robot_obs["ee_pose"][self.arm_name]["base_frame"]
|
||||
ee_position, ee_rot6d = ee_pose_base["position"],ee_pose_base["rot6d"]
|
||||
arm_joint_positions = robot_obs["joint_positions"][:7] # 临时多加了一个drive的位置,现在读的最后一个joint值是drive
|
||||
drive_joint_positions = robot_obs["joint_positions"][-1]
|
||||
normalized_gripper_width = self.__map_joint_position_to_normalized_width(drive_joint_positions)
|
||||
Log.debug(f"input normalized_gripper_width state: {round(normalized_gripper_width, 2)}")
|
||||
state = np.concatenate([ee_position,ee_rot6d,np.array([normalized_gripper_width]), [0]*10, np.array(arm_joint_positions)])
|
||||
left_ee_pose_base = robot_obs["ee_pose"]["left_arm"]["base_frame"]
|
||||
left_ee_position, left_ee_rot6d = left_ee_pose_base["position"], left_ee_pose_base["rot6d"]
|
||||
right_ee_pose_base = robot_obs["ee_pose"]["right_arm"]["base_frame"]
|
||||
right_ee_position, right_ee_rot6d = right_ee_pose_base["position"], right_ee_pose_base["rot6d"]
|
||||
finger_positions = robot_obs["joint_positions"] # use finger joints(44) only
|
||||
state = np.concatenate([left_ee_position, left_ee_rot6d, finger_positions[:22],
|
||||
right_ee_position, right_ee_rot6d, finger_positions[22:]], axis=-1) # (62,)
|
||||
rgb_data = {}
|
||||
for sensor_name in self.sensor_names:
|
||||
sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
|
||||
rgb_data[sensor_name] = sensor_obs["rgb"].data.cpu().numpy().astype(np.uint8)
|
||||
obs = {"state": state,"rgb": rgb_data,"prompt": self.prompt}
|
||||
|
||||
return obs
|
||||
|
||||
def compute_action(self, observation: dict) -> dict:
|
||||
if self.current_chunk_result is None:
|
||||
self.current_ee_position_state = np.array(observation["state"][:3]).astype(np.float64)
|
||||
self.current_ee_rot6d_state = np.array(observation["state"][3:9]).astype(np.float64)
|
||||
self.current_gripper_width = np.array([observation["state"][9]])
|
||||
self.current_state.update(self.split_joints(observation["state"]))
|
||||
payload = pickle.dumps(observation)
|
||||
response = requests.post(
|
||||
f"{self.server_url}/inference",
|
||||
@@ -153,7 +153,7 @@ class StarvlaPolicy(Policy):
|
||||
self.test_obs = observation["state"] #TODO
|
||||
self._handle_server_error(response)
|
||||
result = pickle.loads(response.content)
|
||||
max_trunk_size = len(result["ee_delta_position_chunks"])
|
||||
max_trunk_size = len(result["right_arm"]["ee_delta_position_chunks"])
|
||||
if self.run_trunk_size > max_trunk_size:
|
||||
Log.warning(f"Run trunk size {self.run_trunk_size} is greater than the number of chunks {max_trunk_size}. Set run trunk size to {max_trunk_size}.")
|
||||
self.run_trunk_size = max_trunk_size
|
||||
@@ -163,59 +163,33 @@ class StarvlaPolicy(Policy):
|
||||
result = self.current_chunk_result
|
||||
|
||||
return result
|
||||
def __map_joint_position_to_normalized_width(self, joint_position: float) -> float:
|
||||
if joint_position < 0:
|
||||
joint_position = 0
|
||||
if joint_position > 0.8:
|
||||
joint_position = 0.8
|
||||
# for entry in self.gripper_width_mapper:
|
||||
# if round(entry["angel"], 2) == round(joint_position, 2):
|
||||
# return 1-(entry["width"] - self.min_width) / (self.max_width - self.min_width)
|
||||
|
||||
|
||||
|
||||
def __map_gripper_joint_position(self, normalized_gripper_width: float) -> float:
|
||||
|
||||
joint_positions = []
|
||||
joint_names = []
|
||||
if normalized_gripper_width > 0.5:
|
||||
for joint_name, joint_config in self.drive_joints.items():
|
||||
joint_positions.append(joint_config.close_position)
|
||||
joint_names.append(joint_name)
|
||||
else:
|
||||
for joint_name, joint_config in self.drive_joints.items():
|
||||
joint_positions.append(joint_config.open_position)
|
||||
joint_names.append(joint_name)
|
||||
return joint_positions, joint_names
|
||||
|
||||
def postprocess_action(self, action: dict) -> BenchmarkAction:
|
||||
benchmark_action = BenchmarkAction()
|
||||
Log.debug(f"observation: {self.test_obs}")
|
||||
# import ipdb;ipdb.set_trace()
|
||||
|
||||
# get base frame end-effector pose
|
||||
delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action["ee_delta_rot6d_chunks"][self.current_chunk_id])
|
||||
curr_state_ee_pose = Pose(position=self.current_ee_position_state, rot6d=self.current_ee_rot6d_state)
|
||||
curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose # action2base = state2base * action2state
|
||||
curr_action_gripper_width = action["gripper_width_chunks"][self.current_chunk_id]
|
||||
|
||||
gripper_joint_positions, gripper_joint_names = self.__map_gripper_joint_position(curr_action_gripper_width[0])
|
||||
Log.debug(f"action_gripper_joint_positions: {gripper_joint_positions}, action_normalized_gripper_width: {round(curr_action_gripper_width[0], 2)}")
|
||||
benchmark_action.add_robot_action(
|
||||
RobotAction(
|
||||
control_mode=ControlMode.POSITION,
|
||||
robot_name=self.robot_name,
|
||||
joint_names=gripper_joint_names,
|
||||
joint_positions=gripper_joint_positions
|
||||
for arm_key in self.robot['arms'].keys():
|
||||
action_arm = action[arm_key]
|
||||
delta_ee_pose = Pose(position=action_arm["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action_arm["ee_delta_rot6d_chunks"][self.current_chunk_id])
|
||||
curr_state_ee_pose = Pose(position=self.current_state[arm_key]["ee_pos"], rot6d=self.current_state[arm_key]["ee_rot6d"])
|
||||
curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose # action2base = state2base * action2state
|
||||
finger_joint_qpos = action_arm["finger_chunks"][self.current_chunk_id] + self.current_state[arm_key]["finger_qpos"]
|
||||
joint_names = self.left_hand_joints if arm_key == "left_arm" else self.right_hand_joints
|
||||
benchmark_action.add_robot_action(
|
||||
RobotAction(
|
||||
control_mode=ControlMode.POSITION,
|
||||
robot_name=self.robot_name,
|
||||
joint_names=joint_names,
|
||||
joint_positions=finger_joint_qpos
|
||||
)
|
||||
)
|
||||
)
|
||||
benchmark_action.add_robot_action(
|
||||
RobotAction(
|
||||
control_mode=ControlMode.EE_POSE,
|
||||
robot_name=self.robot_name,
|
||||
ee_pose=curr_action_ee_pose
|
||||
benchmark_action.add_robot_action(
|
||||
RobotAction(
|
||||
control_mode=ControlMode.EE_POSE,
|
||||
robot_name=self.robot_name,
|
||||
ee_pose=curr_action_ee_pose,
|
||||
arm_name=arm_key
|
||||
)
|
||||
)
|
||||
)
|
||||
self._visualize_base_frame_ee_poses(curr_state_ee_pose, curr_action_ee_pose)
|
||||
self._visualize_bounding_boxes()
|
||||
self.current_chunk_id += 1
|
||||
|
||||
Reference in New Issue
Block a user