finish task_gen

This commit is contained in:
2025-09-05 11:10:42 +08:00
parent da022d4f83
commit 4e51158215
17 changed files with 82 additions and 50 deletions

View File

@@ -0,0 +1,289 @@
import numpy as np
import pickle
import json
import os
from task_gen_dependencies.transforms import transform_coordinates_3d
from grasp_nms import nms_grasp
class OmniObject:
def __init__(self,
name='obj',
cam_info=None,
type='Active',
mask=None,
pose=np.eye(4),
size=np.array([0.001, 0.001, 0.001])
):
self.name = name
self.cam_info = cam_info
self.type = type
self.obj_pose = pose
self.obj_length = size
self.xyz = np.array([0,0,0])
self.direction = np.array([0,0,0.05])
self.direction_proposals = None
self.elements = {}
if name=='gripper':
self.elements['active'] = {
'push': [
{
"part": "finger edge",
"task": "forward push",
"xyz": np.array([0,0,0]),
"direction": np.array([0,0,0.08]),
},
{
"part": "finger edge",
"task": "side push",
"xyz": np.array([0,0,0]),
"direction": np.array([1,0,0.3]),
},
{
"part": "finger edge",
"task": "side push",
"xyz": np.array([0,0,0]),
"direction": np.array([-1,0,0.3]),
},
],
'click': {
"part": "finger edge",
"task": "click",
"xyz": np.array([0,0,0]),
"direction": np.array([0,0,1]),
},
'touch': {
"part": "finger edge",
"task": "touch",
"xyz": np.array([0,0,0]),
"direction": np.array([0,0,1]),
},
'pull': {
"part": "finger edge",
"task": "pull",
"xyz": np.array([0,0,0]),
"direction": np.array([0,0,0.08]),
},
'rotate': {
"part": "finger edge",
"task": "rotate",
"xyz": np.array([0,0,0]),
"direction": np.array([0,0,0.08]),
}
}
@classmethod
def from_obj_dir(cls, obj_dir, obj_info=None):
if "interaction" in obj_info:
obj_info = obj_info
interaction_info=obj_info["interaction"]
part_joint_limits_info = obj_info.get("part_joint_limits", None)
else:
obj_info_file = '%s/object_parameters.json'%obj_dir
interaction_label_file = '%s/interaction.json'%obj_dir
assert os.path.exists(obj_info_file), 'object_parameters.json not found in %s'%obj_dir
assert os.path.exists(interaction_label_file), 'interaction.json not found in %s'%obj_dir
obj_info = json.load(open(obj_info_file))
interaction_data = json.load(open(interaction_label_file))
interaction_info = interaction_data['interaction']
part_joint_limits_info = interaction_data.get("part_joint_limits", None)
obj = cls(
name=obj_info['object_id'],
size=obj_info['size']
)
mesh_file = '%s/Aligned.obj'%obj_dir
if os.path.exists(mesh_file):
obj_info['mesh_file'] = mesh_file
obj.part_joint_limits = part_joint_limits_info
''' Load interaction labels '''
obj.part_ids = []
for type in ['active', 'passive']:
if type not in interaction_info:
continue
for action in interaction_info[type]:
action_info = interaction_info[type][action]
# import ipdb;ipdb.set_trace()
if action=='grasp' and type=='passive':
for grasp_part in action_info:
grasp_files = action_info[grasp_part]
grasp_data={
"grasp_pose": [],
"width": []
}
if isinstance(grasp_files, str):
grasp_files = [grasp_files]
for grasp_file in grasp_files:
grasp_file = '%s/%s'%(obj_dir, grasp_file)
if not os.path.exists(grasp_file):
print('Grasp file %s not found'%grasp_file)
continue
_data = pickle.load(open(grasp_file, 'rb'))
_data['grasp_pose'] = np.array(_data['grasp_pose'])
_data['width'] = np.array(_data['width'])
if _data['grasp_pose'].shape[0]==0:
print('Empty grasp pose in %s'%grasp_file)
continue
grasp_data['grasp_pose'].append(_data['grasp_pose'])
grasp_data['width'].append(_data['width'])
if len(grasp_data['grasp_pose'])==0:
print('Empty grasp pose in %s'%grasp_file)
continue
grasp_data['grasp_pose'] = np.concatenate(grasp_data['grasp_pose'])
grasp_data['width'] = np.concatenate(grasp_data['width'])
N_grasp = grasp_data['grasp_pose'].shape[0]
max_grasps = 100
use_nms = True if N_grasp>60 else False
if use_nms:
# import ipdb;ipdb.set_trace()
gripper_pose = grasp_data['grasp_pose']
N = gripper_pose.shape[0]
width = grasp_data['width'][:N, np.newaxis]
rotation = gripper_pose[:, :3, :3].reshape(N, -1)
translation = gripper_pose[:, :3, 3]
height = 0.02 * np.ones_like(width)
depth = np.zeros_like(width)
score = np.ones_like(width) * 0.1
obj_id = -1 * np.ones_like(width)
grasp_group_array = np.concatenate([score, width, height, depth, rotation, translation, obj_id], axis=-1)
if True:
nms_info = interaction_info.get('grasp_nms_threshold', {'translation': 0.015, 'rotation': 25.0})
translation_thresh = nms_info['translation']
rotation_thresh = nms_info['rotation'] / 180.0 * np.pi
grasp_group_array = nms_grasp(grasp_group_array, translation_thresh, rotation_thresh)
rotation = grasp_group_array[:, 4:4+9]
translation = grasp_group_array[:, 4+9:4+9+3]
width = grasp_group_array[:, 1]
grasp_data['grasp_pose'] = np.tile(np.eye(4), (grasp_group_array.shape[0], 1, 1))
grasp_data['grasp_pose'][:, :3, :3] = rotation.reshape(-1,3,3)
grasp_data['grasp_pose'][:, :3, 3] = translation
grasp_data['width'] = width
print(f"Grasp num of {obj.name} after NMS: {grasp_data['grasp_pose'].shape[0]}/{N_grasp}")
# downsample
if grasp_data['grasp_pose'].shape[0]>max_grasps:
grasp_num = grasp_data['grasp_pose'].shape[0]
index = np.random.choice(grasp_num, max_grasps, replace=False)
grasp_data['grasp_pose'] = grasp_data['grasp_pose'][index]
grasp_data['width'] = grasp_data['width'][index]
print(f"Grasp num of {obj.name} after downsample: {grasp_data['grasp_pose'].shape[0]}/{N_grasp}")
else:
print(f"Grasp num of {obj.name}: {grasp_data['grasp_pose'].shape[0]}/{N_grasp}")
action_info[grasp_part] = grasp_data
interaction_info[type][action] = action_info
else:
for primitive in action_info:
for primitive_info in action_info[primitive]:
if 'part_id' in primitive_info:
obj.part_ids.append(primitive_info['part_id'])
obj.elements = interaction_info
obj.info = obj_info
# import ipdb;ipdb.set_trace()
obj.is_articulated = True if part_joint_limits_info is not None else False
return obj
def set_element(self, element):
action = element['action']
self.elements[action] = element
def set_mask(self, mask, roi=None):
self.mask = mask
self.roi = roi
def set_pose(self, pose, length):
self.obj_pose = pose
self.obj_length = length
def set_part(self, xyz=None, direction=None, direction_proposals=None, relative=True):
if xyz is not None:
if not isinstance(xyz, np.ndarray):
xyz = np.array(xyz)
if relative:
xyz = xyz * self.obj_length / 2.0
self.xyz = xyz
if direction is not None:
if not isinstance(direction, np.ndarray):
direction = np.array(direction)
direction = direction / np.linalg.norm(direction) * 0.08
self.direction = direction
if direction_proposals is not None:
self.direction_proposals = direction_proposals
def format_object(self, relative=True):
xyz, direction = self.xyz, self.direction
obj_type = self.type.lower()
if obj_type=='active':
xyz_start = xyz
xyz_end = xyz_start + direction
elif obj_type=='passive' or obj_type=='plane':
xyz_end = xyz
xyz_start = xyz_end - direction
arrow_in_obj = np.array([xyz_start, xyz_end]).transpose(1,0)
arrow_in_world = transform_coordinates_3d(arrow_in_obj, self.obj_pose).transpose(1,0)
xyz_start_world, xyz_end_world = arrow_in_world
direction_world = xyz_end_world - xyz_start_world
direction_world = direction_world / np.linalg.norm(direction_world)
part2obj = np.eye(4)
part2obj[:3, 3] = xyz_start
self.obj2part = np.linalg.inv(part2obj)
object_world = {
'pose': self.obj_pose,
'length': self.obj_length,
'xyz_start': xyz_start,
'xyz_end': xyz_end,
'xyz_start_world': xyz_start_world,
'xyz_end_world': xyz_end_world,
'direction': direction_world,
'obj2part': self.obj2part
}
return object_world