111 lines
4.1 KiB
Python
111 lines
4.1 KiB
Python
import copy
|
|
from collections import deque
|
|
|
|
import numpy as np
|
|
from data_gen_dependencies.data_utils import pose_difference
|
|
|
|
|
|
def simple_check_completion(goal, objects, last_statement=None, pos_threshold=0.06, angle_threshold=70, is_grasped=False):
    """Report whether the active object has reached the pose requested by ``goal``.

    Args:
        goal: Six-item sequence unpacked as ``(active_obj_id, passive_obj_id,
            target_pose_canonical, gripper_action, transform_world,
            motion_type)``. ``motion_type`` is unpacked but not used here.
        objects: Mapping from object id to an object exposing ``obj_pose``.
        last_statement: Accepted for interface compatibility; not used.
        pos_threshold: Exclusive upper bound on the position difference
            returned by ``pose_difference``.
        angle_threshold: Exclusive upper bound on the angle difference
            returned by ``pose_difference``.
        is_grasped: When true, ``transform_world`` is NOT applied to the
            target pose before comparing.

    Returns:
        ``True`` when there is no target pose or the gripper action is
        ``'open'``; otherwise the truth value of "both differences are below
        their thresholds".
    """
    (active_obj_id, passive_obj_id, target_pose_canonical,
     gripper_action, transform_world, motion_type) = goal

    # Stages without a pose target, and gripper-opening stages, are treated
    # as complete without looking at any object pose.
    if target_pose_canonical is None or gripper_action == 'open':
        return True

    observed_pose = objects[active_obj_id].obj_pose

    # A rank-3 target is a stack of poses: only its last entry (and the
    # matching last world transform) is compared against.
    if len(target_pose_canonical.shape) == 3:
        target_pose_canonical, transform_world = target_pose_canonical[-1], transform_world[-1]

    # Express the target relative to the passive object's current pose.
    expected_pose = objects[passive_obj_id].obj_pose @ target_pose_canonical
    if not is_grasped:
        expected_pose = np.dot(transform_world, expected_pose)

    pos_diff, angle_diff = pose_difference(observed_pose, expected_pose)
    return (pos_diff < pos_threshold) and (angle_diff < angle_threshold)
|
|
|
|
|
|
def solve_target_gripper_pose(stage, objects):
    """Compute the gripper pose that realises the target pose of ``stage``.

    Args:
        stage: Six-item sequence unpacked as ``(active_obj_ID, passive_obj_ID,
            target_pose_canonical, gripper_action, transform_world,
            motion_type)``. ``gripper_action`` is unpacked but not used here.
        objects: Mapping from object id to an object exposing ``obj_pose``;
            it must contain the key ``'gripper'``.

    Returns:
        The target pose itself when the active object is ``'gripper'``.
        Otherwise the target pose right-multiplied by the gripper's current
        pose expressed relative to the active object, i.e.
        ``inv(active_pose) @ gripper_pose``.

    Raises:
        AssertionError: If ``motion_type`` is ``'Trajectory'`` but the target
            is not rank-3, or if ``'gripper'`` is missing from ``objects``.
    """
    (active_obj_ID, passive_obj_ID, target_pose_canonical,
     gripper_action, transform_world, motion_type) = stage

    passive_pose = objects[passive_obj_ID].obj_pose

    # Place the canonical target relative to the passive object's pose.
    if motion_type == 'Trajectory':
        assert len(target_pose_canonical.shape)==3, 'The target_pose should be a list of poses'
        # Leading axis added so the single passive pose applies to every
        # pose in the stack.
        anchored_target = passive_pose[np.newaxis, ...] @ target_pose_canonical
    else:
        anchored_target = passive_pose @ target_pose_canonical

    # The world transform is applied the same way for both motion kinds.
    target_pose = transform_world @ anchored_target

    assert 'gripper' in objects, 'The gripper should be the first one in the object list'
    gripper_pose_now = objects['gripper'].obj_pose

    # When the gripper itself is the active object the target is already a
    # gripper pose.
    if active_obj_ID == 'gripper':
        return target_pose

    # Current gripper pose relative to the active object; applying it to the
    # object's target pose yields the gripper's target pose.
    gripper_in_obj = np.linalg.inv(objects[active_obj_ID].obj_pose) @ gripper_pose_now
    if len(target_pose.shape) == 3:
        gripper_in_obj = gripper_in_obj[np.newaxis, ...]

    return target_pose @ gripper_in_obj
|
|
|
|
|
|
|
|
class StageTemplate:
    """Base class holding an ordered list of sub-stages and a cursor into it.

    Each entry of ``sub_stages`` is unpacked as ``(gripper_pose_canonical,
    gripper_action, transform_world, motion_type)``. ``step_id`` indexes the
    sub-stage currently being executed; completed sub-stages are not removed.
    ``generate_substage`` is left unimplemented here (presumably subclasses
    use it to fill ``sub_stages`` -- confirm against the subclasses).
    """

    def __init__(self, active_obj_id, passive_obj_id, active_element, passive_element):
        """Store the object ids/elements and start with no sub-stages."""
        # Progress bookkeeping.
        self.step_id = 0
        self.sub_stages = deque()
        self.last_statement = None
        # Optional settings; ``get_action`` reads the 'arm' key from here.
        self.extra_params = {}

        # Participants of this stage.
        self.active_obj_id = active_obj_id
        self.passive_obj_id = passive_obj_id
        self.active_element = active_element
        self.passive_element = passive_element

    def generate_substage(self):
        """Not implemented in the base class."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return how many sub-stages are still pending from ``step_id`` on."""
        pending = len(self.sub_stages) - self.step_id
        return pending

    def get_action(self, objects):
        """Build the action for the current sub-stage.

        Args:
            objects: Mapping from object id to an object exposing ``obj_pose``.

        Returns:
            ``None`` when no sub-stage is pending. Otherwise the tuple
            ``(target_gripper_pose, motion_type, gripper_action, arm)`` where
            ``arm`` is ``extra_params['arm']`` or ``'right'`` when unset.
            A deep copy of ``objects`` together with the target pose is
            recorded in ``self.last_statement`` as a side effect.
        """
        if not self.__len__():
            return None

        current = self.sub_stages[self.step_id]
        gripper_pose_canonical, gripper_action, transform_world, motion_type = current

        if motion_type == 'local_gripper':
            # The stored pose is a delta applied to the gripper's current
            # pose; the motion is then reported as a 'Straight' move.
            target_gripper_pose = objects['gripper'].obj_pose @ gripper_pose_canonical
            motion_type = 'Straight'
        elif gripper_pose_canonical is None:
            # Sub-stage carries no pose target.
            target_gripper_pose = None
        else:
            target_gripper_pose = solve_target_gripper_pose(
                [self.active_obj_id, self.passive_obj_id] + current, objects
            )

        # Snapshot of the scene this action was computed from.
        self.last_statement = {
            'objects': copy.deepcopy(objects),
            'target_gripper_pose': target_gripper_pose,
        }
        arm = self.extra_params.get('arm', 'right')
        return target_gripper_pose, motion_type, gripper_action, arm

    def check_completion(self, objects):
        """Check the current sub-stage and advance the cursor on success.

        Returns ``True`` immediately when nothing is pending; otherwise
        returns the result of ``simple_check_completion`` for the current
        sub-stage, incrementing ``step_id`` when that result is truthy.
        """
        if not self.__len__():
            return True

        goal = [self.active_obj_id, self.passive_obj_id] + self.sub_stages[self.step_id]
        finished = simple_check_completion(goal, objects)
        if finished:
            self.step_id += 1
        return finished
|
|
|
|
|
|
|