110 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			110 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import copy
 | |
| from collections import deque
 | |
| 
 | |
| import numpy as np
 | |
| from data_gen_dependencies.data_utils import pose_difference
 | |
| 
 | |
| 
 | |
def simple_check_completion(goal, objects, last_statement=None, pos_threshold=0.06, angle_threshold=70, is_grasped=False):
    """Return whether the active object has reached the goal pose.

    Args:
        goal: 6-item sequence ``(active_obj_id, passive_obj_id,
            target_pose_canonical, gripper_action, transform_world,
            motion_type)``. ``target_pose_canonical`` is expressed relative to
            the passive object's pose; a 3-D stack is treated as a trajectory
            and only its final pose is checked.
        objects: mapping from object id to an object exposing ``obj_pose``.
        last_statement: unused; kept for interface compatibility.
        pos_threshold: position tolerance, in whatever units
            ``pose_difference`` reports.
        angle_threshold: angular tolerance, in whatever units
            ``pose_difference`` reports (presumably degrees -- confirm).
        is_grasped: when True, ``transform_world`` is not applied to the
            target pose.

    Returns:
        True immediately when there is no target pose or the gripper action
        is ``'open'``; otherwise whether both the position and the angle
        difference are strictly below their thresholds.
    """
    active_obj_id, passive_obj_id, target_pose_canonical, gripper_action, transform_world, motion_type = goal
    # Stages without a target pose, or that only open the gripper, are
    # considered complete without inspecting any pose.
    if target_pose_canonical is None or gripper_action == 'open':
        return True

    current_pose_world = objects[active_obj_id].obj_pose
    if np.ndim(target_pose_canonical) == 3:
        # Trajectory goal: only the final waypoint decides completion.
        target_pose_canonical = target_pose_canonical[-1]
        # Only a per-waypoint stack of transforms is indexed. Indexing a
        # single shared 4x4 transform would select its last row instead of a
        # transform.
        if np.ndim(transform_world) == 3:
            transform_world = transform_world[-1]
    target_pose_world = objects[passive_obj_id].obj_pose @ target_pose_canonical
    if not is_grasped:
        target_pose_world = np.dot(transform_world, target_pose_world)

    pos_diff, angle_diff = pose_difference(current_pose_world, target_pose_world)
    return (pos_diff < pos_threshold) and (angle_diff < angle_threshold)
 | |
| 
 | |
| 
 | |
def solve_target_gripper_pose(stage, objects):
    """Return the world-frame gripper pose(s) that realise *stage*.

    Args:
        stage: 6-item sequence ``(active_obj_ID, passive_obj_ID,
            target_pose_canonical, gripper_action, transform_world,
            motion_type)``. The canonical target is expressed relative to the
            passive object's pose and then mapped by ``transform_world``.
        objects: mapping from object id to an object exposing ``obj_pose``;
            must contain a ``'gripper'`` entry.

    Returns:
        The target gripper pose. For ``motion_type == 'Trajectory'`` this is
        a stack of poses, one per canonical pose. When the active object is
        not the gripper, the current gripper-to-object offset is kept fixed
        and appended to the object's target pose.
    """
    active_id, passive_id, canonical_pose, _gripper_action, world_tf, motion = stage
    anchor = objects[passive_id].obj_pose

    if motion == 'Trajectory':
        assert len(canonical_pose.shape) == 3, 'The target_pose should be a list of poses'
        # Add a leading axis so the anchor multiplies every pose in the stack.
        anchor = anchor[np.newaxis, ...]
    goal_pose = world_tf @ (anchor @ canonical_pose)

    assert 'gripper' in objects, 'The gripper should be the first one in the object list'
    gripper_now = objects['gripper'].obj_pose

    if active_id == 'gripper':
        return goal_pose

    # Gripper pose expressed in the active object's frame; assumed rigid for
    # the duration of the motion.
    obj_now = objects[active_id].obj_pose
    grasp_offset = np.linalg.inv(obj_now) @ gripper_now
    if len(goal_pose.shape) == 3:
        grasp_offset = grasp_offset[np.newaxis, ...]
    return goal_pose @ grasp_offset
 | |
| 
 | |
| 
 | |
| 
 | |
class StageTemplate:
    """Base class for a manipulation stage made of sequential sub-stages.

    Subclasses populate ``sub_stages`` in ``generate_substage``. Each entry is
    a 4-item sequence ``(gripper_pose_canonical, gripper_action,
    transform_world, motion_type)`` (list or tuple). ``step_id`` indexes the
    sub-stage currently being executed.
    """

    def __init__(self, active_obj_id, passive_obj_id, active_element, passive_element):
        self.active_obj_id = active_obj_id
        self.passive_obj_id = passive_obj_id
        self.active_element = active_element
        self.passive_element = passive_element

        # Snapshot of the objects and the solved target from the last
        # non-'local_gripper' get_action call.
        self.last_statement = None
        self.sub_stages = deque()
        self.step_id = 0
        # Free-form options; get_action reads 'arm' (default 'right').
        self.extra_params = {}

    def generate_substage(self):
        """Populate ``sub_stages``; must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of sub-stages not yet completed."""
        return len(self.sub_stages) - self.step_id

    def _goal_datapack(self):
        """Return the current sub-stage prefixed with the active/passive ids.

        Unpacking, rather than ``list + entry``, accepts sub-stages stored as
        tuples as well as lists.
        """
        return [self.active_obj_id, self.passive_obj_id, *self.sub_stages[self.step_id]]

    def get_action(self, objects):
        """Return the action for the current sub-stage, or None when done.

        Args:
            objects: mapping from object id to an object exposing
                ``obj_pose``.

        Returns:
            ``(target_gripper_pose, motion_type, gripper_action, arm)``, or
            None if every sub-stage has been completed. For a
            ``'local_gripper'`` sub-stage, the stored pose is right-multiplied
            onto the current gripper pose and the motion is reported as
            ``'Straight'``.
        """
        if len(self) == 0:
            return None
        gripper_pose_canonical, gripper_action, transform_world, motion_type = self.sub_stages[self.step_id]

        if motion_type == 'local_gripper':
            # Delta pose applied in the gripper's own frame.
            target_gripper_pose = objects['gripper'].obj_pose @ gripper_pose_canonical
            motion_type = 'Straight'
        else:
            if gripper_pose_canonical is None:
                target_gripper_pose = None
            else:
                target_gripper_pose = solve_target_gripper_pose(self._goal_datapack(), objects)
            # Deep copy so later simulation updates do not alter the snapshot.
            self.last_statement = {'objects': copy.deepcopy(objects), 'target_gripper_pose': target_gripper_pose}
        return target_gripper_pose, motion_type, gripper_action, self.extra_params.get('arm', 'right')

    def check_completion(self, objects):
        """Check the current sub-stage and advance ``step_id`` on success.

        Returns:
            True when all sub-stages are done or the current one is satisfied.
        """
        if len(self) == 0:
            return True
        # NOTE(review): for 'local_gripper' sub-stages the stored pose is a
        # gripper-frame delta, yet it is checked here as if it were relative
        # to the passive object -- confirm this is intended.
        succ = simple_check_completion(self._goal_datapack(), objects)
        if succ:
            self.step_id += 1
        return succ
 | |
|         
 | |
| 
 | |
| 
 |