Files
gen_data_agent/data_gen_dependencies/action/base.py
2025-09-05 15:49:00 +08:00

111 lines
4.1 KiB
Python

import copy
from collections import deque
import numpy as np
from data_gen_dependencies.data_utils import pose_difference
def simple_check_completion(goal, objects, last_statement=None, pos_threshold=0.06, angle_threshold=70, is_grasped=False):
    """Return whether the active object has reached the goal pose of a stage.

    Args:
        goal: 6-item sequence ``(active_obj_id, passive_obj_id,
            target_pose_canonical, gripper_action, transform_world,
            motion_type)``. ``target_pose_canonical`` is composed with the
            passive object's pose; a 3-D array is treated as a trajectory
            whose last waypoint is the goal.
        objects: mapping from object id to an object exposing ``obj_pose``.
        last_statement: unused; kept for call compatibility.
        pos_threshold: tolerance compared against the first value returned by
            ``pose_difference``.
        angle_threshold: tolerance compared against the second value returned
            by ``pose_difference`` (presumably degrees — confirm).
        is_grasped: when True, ``transform_world`` is not applied to the
            target pose.

    Returns:
        True if there is no target pose, if the gripper action is 'open', or
        if both the position and angle differences are strictly below their
        thresholds; otherwise False.
    """
    active_obj_id, passive_obj_id, target_pose_canonical, gripper_action, transform_world, _motion_type = goal
    # Stages without a pose target, or that only open the gripper, count as done.
    if target_pose_canonical is None or gripper_action == 'open':
        return True
    current_pose_world = objects[active_obj_id].obj_pose
    if np.ndim(target_pose_canonical) == 3:
        # Trajectory: only the final waypoint decides completion.
        target_pose_canonical = target_pose_canonical[-1]
        # transform_world may be a single transform shared by all waypoints
        # (2-D; solve_target_gripper_pose relies on matmul broadcasting) or one
        # transform per waypoint (3-D). Indexing a 2-D transform would select
        # a single row, so only index the per-waypoint form.
        if np.ndim(transform_world) == 3:
            transform_world = transform_world[-1]
    target_pose_world = objects[passive_obj_id].obj_pose @ target_pose_canonical
    if not is_grasped:
        target_pose_world = transform_world @ target_pose_world
    pos_diff, angle_diff = pose_difference(current_pose_world, target_pose_world)
    return (pos_diff < pos_threshold) and (angle_diff < angle_threshold)
def solve_target_gripper_pose(stage, objects):
    """Compute the world-frame gripper pose that realizes a stage goal.

    The goal pose is composed with the passive (anchor) object's current pose
    and then left-multiplied by ``transform_world``. When the active object is
    not the gripper, the current gripper pose relative to the active object is
    kept fixed and carried over to the goal.

    Args:
        stage: 6-item sequence ``(active_obj_ID, passive_obj_ID,
            target_pose_canonical, gripper_action, transform_world,
            motion_type)``.
        objects: mapping from object id to an object exposing ``obj_pose``;
            must contain the key 'gripper'.

    Returns:
        The target gripper pose, or a stack of poses (leading waypoint axis)
        when the goal is a trajectory.

    Raises:
        AssertionError: if a 'Trajectory' goal is not 3-D, or if 'gripper' is
            missing from ``objects``.
    """
    (active_id, passive_id, pose_in_anchor,
     _gripper_action, world_tf, motion) = stage
    anchor = objects[passive_id].obj_pose
    if motion == 'Trajectory':
        assert len(pose_in_anchor.shape) == 3, 'The target_pose should be a list of poses'
        # Leading axis so the single anchor pose is applied to every waypoint.
        anchor = anchor[np.newaxis, ...]
    goal_pose = world_tf @ (anchor @ pose_in_anchor)

    assert 'gripper' in objects, 'The gripper should be the first one in the object list'
    gripper_now = objects['gripper'].obj_pose
    if active_id == 'gripper':
        return goal_pose

    # Gripper pose expressed in the active object's frame; held fixed so the
    # gripper follows the object to its goal.
    gripper_in_obj = np.linalg.inv(objects[active_id].obj_pose) @ gripper_now
    if goal_pose.ndim == 3:
        gripper_in_obj = gripper_in_obj[np.newaxis, ...]
    return goal_pose @ gripper_in_obj
class StageTemplate:
    """Base class for a manipulation stage made of sequential sub-stages.

    Subclasses implement ``generate_substage`` and fill ``self.sub_stages``.
    Each sub-stage is a 4-item sequence (list or tuple)
    ``(gripper_pose_canonical, gripper_action, transform_world, motion_type)``.
    ``step_id`` indexes the sub-stage currently being executed; earlier
    sub-stages are considered complete.
    """

    def __init__(self, active_obj_id, passive_obj_id, active_element, passive_element):
        self.active_obj_id = active_obj_id
        self.passive_obj_id = passive_obj_id
        self.active_element = active_element
        self.passive_element = passive_element
        # Snapshot recorded by the most recent get_action call:
        # {'objects': deep copy of objects, 'target_gripper_pose': ...}.
        self.last_statement = None
        self.sub_stages = deque()
        # Index of the current sub-stage within sub_stages.
        self.step_id = 0
        # Free-form options; get_action reads 'arm' (default 'right').
        self.extra_params = {}

    def generate_substage(self):
        """Populate ``self.sub_stages``; must be overridden by subclasses."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of sub-stages not yet completed (never negative)."""
        # Clamp so len() does not raise ValueError if step_id ran past the queue.
        return max(0, len(self.sub_stages) - self.step_id)

    def _goal_datapack(self):
        """Return the current sub-stage prefixed with the active/passive ids.

        Iterable unpacking accepts sub-stages stored as tuples as well as
        lists (list concatenation would raise TypeError for tuples).
        """
        return [self.active_obj_id, self.passive_obj_id, *self.sub_stages[self.step_id]]

    def get_action(self, objects):
        """Solve the gripper target for the current sub-stage.

        Args:
            objects: mapping from object id to an object exposing ``obj_pose``;
                'gripper' must be present for 'local_gripper' sub-stages.

        Returns:
            ``None`` if no sub-stage remains; otherwise the tuple
            ``(target_gripper_pose, motion_type, gripper_action, arm)``.
            ``target_gripper_pose`` is ``None`` when the sub-stage has no pose
            (unless it is 'local_gripper'), and 'local_gripper' sub-stages are
            reported with motion type 'Straight'.
        """
        if len(self) == 0:
            return None
        gripper_pose_canonical, gripper_action, _transform_world, motion_type = self.sub_stages[self.step_id]
        if motion_type == 'local_gripper':
            # The stored pose is a delta right-multiplied onto the current
            # gripper pose, i.e. expressed in the gripper's own frame.
            target_gripper_pose = objects['gripper'].obj_pose @ gripper_pose_canonical
            motion_type = 'Straight'
        elif gripper_pose_canonical is None:
            target_gripper_pose = None
        else:
            target_gripper_pose = solve_target_gripper_pose(self._goal_datapack(), objects)
        # Deep copy so later mutation of `objects` cannot alter the snapshot.
        self.last_statement = {'objects': copy.deepcopy(objects), 'target_gripper_pose': target_gripper_pose}
        return target_gripper_pose, motion_type, gripper_action, self.extra_params.get('arm', 'right')

    def check_completion(self, objects):
        """Check the current sub-stage and advance ``step_id`` when it is done.

        Returns:
            True if no sub-stage remains or the current one satisfies
            ``simple_check_completion`` (default thresholds); False otherwise.
        """
        if len(self) == 0:
            return True
        # NOTE(review): for 'local_gripper' sub-stages the stored pose is a
        # delta relative to the gripper (see get_action), but
        # simple_check_completion composes it with the passive object's pose
        # — confirm this is intended.
        succ = simple_check_completion(self._goal_datapack(), objects)
        if succ:
            self.step_id += 1
        return succ