finshi pull revolute action
This commit is contained in:
@@ -5,20 +5,24 @@ from data_gen_dependencies.action.base import StageTemplate
|
||||
|
||||
class TwistStage(StageTemplate):
|
||||
DELTA_DISTANCE = 0.01 # meter
|
||||
def __init__(self, active_obj_id, passive_obj_id, active_element=None, passive_element=None, target_pose=np.eye(4), extra_params=None, active_obj=None, passive_obj=None, **kwargs):
|
||||
def __init__(self, active_obj_id, passive_obj_id, active_element=None, passive_element=None, target_pose=np.eye(4), extra_params=None, objects = [], **kwargs):
|
||||
super().__init__(active_obj_id, passive_obj_id, active_element, passive_element)
|
||||
self.passive_obj = passive_obj
|
||||
self.active_obj = active_obj
|
||||
self.passive_obj = objects[passive_obj_id]
|
||||
self.active_obj = objects[active_obj_id]
|
||||
self.joint_position_threshold = passive_element.get('joint_position_threshold', 0.7)
|
||||
if self.joint_position_threshold is None:
|
||||
self.joint_position_threshold = 0.7
|
||||
self.correspond_joint_id = passive_element.get('correspond_joint_id', None)
|
||||
correspond_joint_info = passive_obj.joints_info[self.correspond_joint_id]
|
||||
correspond_joint_info = self.passive_obj.joints_info[self.correspond_joint_id]
|
||||
self.joint_lower_limit = correspond_joint_info["lower_bound"]
|
||||
self.joint_upper_limit = correspond_joint_info["upper_bound"]
|
||||
if self.joint_lower_limit is None or self.joint_upper_limit is None:
|
||||
self.twist_degree_range = np.pi
|
||||
else:
|
||||
self.twist_degree_range = abs(self.joint_upper_limit - self.joint_lower_limit)
|
||||
self.joint_axis = correspond_joint_info["joint_axis"]
|
||||
self.joint_type = correspond_joint_info["joint_type"]
|
||||
assert self.joint_type == 'revolute', "joint_type must be revolute for pull_revolute action"
|
||||
self.revolute_radius = passive_element.get('revolute_radius', None)
|
||||
assert self.revolute_radius is not None, "revolute_radius is required for pull_revolute action"
|
||||
assert self.joint_type == 'continuous', "joint_type must be continuous for twist action"
|
||||
if self.joint_position_threshold is None:
|
||||
self.joint_position_threshold = 0.7
|
||||
assert self.joint_position_threshold >= 0 and self.joint_position_threshold <= 1
|
||||
@@ -34,9 +38,20 @@ class TwistStage(StageTemplate):
|
||||
|
||||
def generate_substage(self, target_pose, vector_direction):
|
||||
vector_direction = vector_direction / np.linalg.norm(vector_direction)
|
||||
free_delta_pose = np.eye(4)
|
||||
free_delta_pose[2,3] = -0.03
|
||||
self.sub_stages.append([free_delta_pose, None, np.eye(4), 'local_gripper'])
|
||||
delta_twist_pose = self.axis_angle_to_matrix(np.asarray([0,0,1]), self.twist_degree_range * self.joint_position_threshold * self.joint_direction)
|
||||
import ipdb;ipdb.set_trace()
|
||||
self.sub_stages.append([delta_twist_pose, None, np.eye(4), 'local_gripper'])
|
||||
|
||||
def axis_angle_to_matrix(self, axis, angle):
|
||||
|
||||
axis = axis / np.linalg.norm(axis)
|
||||
K = np.array([[0, -axis[2], axis[1]],
|
||||
[axis[2], 0, -axis[0]],
|
||||
[-axis[1], axis[0], 0]])
|
||||
R = np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)
|
||||
T = np.eye(4)
|
||||
T[:3, :3] = R
|
||||
return T
|
||||
|
||||
def check_completion(self, objects):
|
||||
if self.__len__() == 0:
|
||||
|
||||
Reference in New Issue
Block a user