solve dependencies problem
This commit is contained in:
		
							
								
								
									
										110
									
								
								data_gen_dependencies/action/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								data_gen_dependencies/action/base.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,110 @@ | ||||
| import copy | ||||
| from collections import deque | ||||
|  | ||||
| import numpy as np | ||||
| from data_gen_dependencies.data_utils import pose_difference | ||||
|  | ||||
|  | ||||
| def simple_check_completion(goal, objects, last_statement=None, pos_threshold=0.06, angle_threshold=70, is_grasped=False): | ||||
|     active_obj_id, passive_obj_id, target_pose_canonical, gripper_action, transform_world, motion_type = goal | ||||
|     if target_pose_canonical is None: | ||||
|         return True | ||||
|     if gripper_action=='open': | ||||
|         return True | ||||
|  | ||||
|     current_pose_world = objects[active_obj_id].obj_pose | ||||
|     if len(target_pose_canonical.shape)==3: | ||||
|         target_pose_canonical = target_pose_canonical[-1] | ||||
|         transform_world = transform_world[-1] | ||||
|     target_pose_world = objects[passive_obj_id].obj_pose @ target_pose_canonical | ||||
|     if not is_grasped: | ||||
|         target_pose_world = np.dot(transform_world, target_pose_world) | ||||
|      | ||||
|     pos_diff, angle_diff = pose_difference(current_pose_world, target_pose_world) | ||||
|     success = (pos_diff < pos_threshold) and (angle_diff < angle_threshold) | ||||
|     return success | ||||
|  | ||||
|  | ||||
| def solve_target_gripper_pose(stage, objects): | ||||
|     active_obj_ID, passive_obj_ID, target_pose_canonical, gripper_action, transform_world, motion_type = stage | ||||
|      | ||||
|     anchor_pose = objects[passive_obj_ID].obj_pose | ||||
|      | ||||
|      | ||||
|     if motion_type=='Trajectory': | ||||
|         assert len(target_pose_canonical.shape)==3, 'The target_pose should be a list of poses' | ||||
|         target_pose = anchor_pose[np.newaxis, ...] @ target_pose_canonical | ||||
|         target_pose = transform_world @ target_pose | ||||
|     else: | ||||
|         target_pose = anchor_pose @ target_pose_canonical | ||||
|         target_pose = transform_world @ target_pose | ||||
|     assert 'gripper' in objects, 'The gripper should be the first one in the object list' | ||||
|     current_gripper_pose = objects['gripper'].obj_pose | ||||
|      | ||||
|     if active_obj_ID=='gripper': | ||||
|         target_gripper_pose = target_pose | ||||
|     else: | ||||
|         current_obj_pose = objects[active_obj_ID].obj_pose | ||||
|         gripper2obj = np.linalg.inv(current_obj_pose) @ current_gripper_pose | ||||
|         if len(target_pose.shape)==3: | ||||
|             gripper2obj = gripper2obj[np.newaxis, ...] | ||||
|          | ||||
|         target_obj_pose = target_pose | ||||
|         target_gripper_pose = target_obj_pose @ gripper2obj | ||||
|  | ||||
|     return target_gripper_pose | ||||
|  | ||||
|  | ||||
|  | ||||
| class StageTemplate: | ||||
|     def __init__(self, active_obj_id, passive_obj_id, active_element, passive_element): | ||||
|         self.active_obj_id = active_obj_id | ||||
|         self.passive_obj_id = passive_obj_id | ||||
|         self.active_element = active_element | ||||
|         self.passive_element = passive_element | ||||
|  | ||||
|         self.last_statement = None | ||||
|         self.sub_stages = deque() | ||||
|         self.step_id = 0 | ||||
|         self.extra_params = {} | ||||
|  | ||||
|     def generate_substage(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __len__(self) -> int: | ||||
|         return len(self.sub_stages) - self.step_id | ||||
|  | ||||
|     def get_action(self, objects): | ||||
|         if self.__len__()==0: | ||||
|             return None | ||||
|         gripper_pose_canonical, gripper_action, transform_world, motion_type = self.sub_stages[self.step_id] | ||||
|  | ||||
|         if motion_type == 'local_gripper': | ||||
|             delta_pose = gripper_pose_canonical | ||||
|             gripper_pose = objects['gripper'].obj_pose | ||||
|             target_gripper_pose = gripper_pose @ delta_pose | ||||
|             motion_type = 'Straight' | ||||
|         else: | ||||
|  | ||||
|             if gripper_pose_canonical is None: | ||||
|                 target_gripper_pose = None | ||||
|             else:  | ||||
|                 goal_datapack = [self.active_obj_id, self.passive_obj_id] + self.sub_stages[self.step_id] | ||||
|                 target_gripper_pose = solve_target_gripper_pose(goal_datapack, objects) | ||||
|              | ||||
|             last_statement = {'objects': copy.deepcopy(objects), 'target_gripper_pose': target_gripper_pose} | ||||
|             self.last_statement = last_statement | ||||
|         return target_gripper_pose, motion_type, gripper_action, self.extra_params.get('arm', 'right') | ||||
|      | ||||
|  | ||||
|     def check_completion(self, objects): | ||||
|         if self.__len__()==0: | ||||
|             return True | ||||
|         goal_datapack = [self.active_obj_id, self.passive_obj_id] + self.sub_stages[self.step_id] | ||||
|         succ = simple_check_completion(goal_datapack, objects) | ||||
|         if succ: | ||||
|             self.step_id += 1 | ||||
|         return succ | ||||
|          | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user