# Source snapshot: gen_data_agent/data_gen_dependencies/omniagent.py
# Last modified 2025-09-05 19:18:37 +08:00 (426 lines, 22 KiB, Python).
#
# NOTE: the repository viewer flagged this file as containing ambiguous Unicode
# characters (non-ASCII text in comments). If that is intentional, the warning
# can be ignored.

# -*- coding: utf-8 -*-
import json
import os
import pickle
import time
import random
import numpy as np
from pyboot.utils.log import Log
from data_gen_dependencies.base_agent import BaseAgent
from data_gen_dependencies.manip_solver import load_task_solution, generate_action_stages, split_grasp_stages
from data_gen_dependencies.omni_robot import IsaacSimRpcRobot
from data_gen_dependencies.transforms import calculate_rotation_matrix
class Agent(BaseAgent):
    """Agent that lays out a task scene in Isaac Sim and executes its manipulation stages.

    All simulator access goes through ``self.robot`` (an ``IsaacSimRpcRobot``) and its
    RPC ``client``. ``generate_layout`` populates the per-task attributes ``task_file``,
    ``articulated_objs`` and ``arm``.
    """

    # Upper bound on random samples when relocating an object away from the gripper;
    # prevents an endless loop when the workspace is too small to satisfy the distance.
    _MAX_RELOCATION_ATTEMPTS = 1000
    # Minimum x/y distance between a relocated object and the gripper reference pose.
    _MIN_RELOCATION_GRIPPER_DIST = 0.2

    def __init__(self, robot: IsaacSimRpcRobot):
        super().__init__(robot)
        # Id of the object currently treated as held by the gripper; None when empty-handed.
        self.attached_obj_id = None

    def start_recording(self, task_name, camera_prim_list, fps, render_semantic=False):
        """Ask the simulator to start recording an episode.

        Args:
            task_name: Name under which the recording is stored.
            camera_prim_list: Camera prim paths to record.
            fps: Recording frame rate.
            render_semantic: Whether to also render semantic images (depth is always off).
        """
        self.robot.client.start_recording(
            task_name=task_name,
            fps=fps,
            data_keys={
                "camera": {
                    "camera_prim_list": camera_prim_list,
                    "render_depth": False,
                    "render_semantic": render_semantic,
                },
                "pose": ["/World/Raise_A2/gripper_center"],
                "joint_position": True,
                "gripper": True,
            },
        )

    def generate_layout(self, task_file):
        """Load a task JSON file and build its scene in the simulator.

        Spawns every object except the virtual ``fix_pose``, aims the camera at the
        task-related objects, and applies optional materials, lights and cameras.

        Args:
            task_file: Path to the task description JSON.
        """
        self.task_file = task_file
        with open(task_file, "r") as f:
            task_info = json.load(f)

        # Give the passive object of place/insert/pour stages a large mass for
        # stable manipulation (only the first object with a matching id).
        for stage in task_info['stages']:
            if stage['action'] not in ('place', 'insert', 'pour'):
                continue
            obj_id = stage['passive']['object_id']
            for object_info in task_info['objects']:
                if object_info['object_id'] == obj_id:
                    object_info['mass'] = 10
                    break

        self.articulated_objs = []
        for object_info in task_info["objects"]:
            if object_info['object_id'] == 'fix_pose':
                continue  # virtual target pose, nothing to spawn
            if object_info.get('is_articulated', False):
                self.articulated_objs.append(object_info['object_id'])
            object_info['material'] = 'general'
            self.add_object(object_info)
        time.sleep(2)
        self.arm = task_info["robot"]["arm"]

        # For A2D: point the camera at the mean position of all task-related objects.
        task_related_objs = []
        for stage in task_info['stages']:
            for role in ('active', 'passive'):
                obj_id = stage[role]['object_id']
                if obj_id == 'gripper' or obj_id in task_related_objs:
                    continue
                task_related_objs.append(obj_id)
        lookat_points = [obj['position'] for obj in task_info['objects']
                         if obj['object_id'] in task_related_objs]
        if lookat_points:
            target_lookat_point = np.mean(np.stack(lookat_points), axis=0)
            self.robot.client.SetTargetPoint(target_lookat_point.tolist())
        else:
            # np.stack([]) raises ValueError; leave the camera target unchanged instead.
            Log.error('No task-related objects to aim the camera at: %s' % task_file)

        # Materials: concatenate every list stored under 'object_with_material'.
        material_infos = []
        for infos in task_info.get('object_with_material', {}).values():
            material_infos += infos
        if material_infos:
            self.robot.client.SetMaterial(material_infos)
            time.sleep(0.3)

        # Lights: concatenate every list stored under 'lights'.
        light_infos = []
        for infos in task_info.get('lights', {}).values():
            light_infos += infos
        if light_infos:
            self.robot.client.SetLight(light_infos)
            time.sleep(0.3)

        # Cameras: one AddCamera call per entry.
        for cam_id, cam_info in task_info.get('cameras', {}).items():
            self.robot.client.AddCamera(
                cam_id, cam_info['position'], cam_info['quaternion'],
                cam_info['width'], cam_info['height'],
                cam_info['focal_length'], cam_info['horizontal_aperture'], cam_info['vertical_aperture'],
                cam_info['is_local']
            )

    def update_objects(self, objects, arm='right'):
        """Refresh the gripper and object state in ``objects`` from the simulator.

        Args:
            objects: Mapping of object id -> object state (from ``load_task_solution``).
                It must contain a 'gripper' entry and is mutated in place.
            arm: Which arm's gripper pose to read ('right' or 'left').

        Returns:
            The same ``objects`` mapping, after the update.
        """
        objects['gripper'].obj_pose = self.robot.get_ee_pose(ee_type='gripper', id=arm)

        for obj_id, obj in objects.items():
            if obj_id == 'gripper':
                continue
            if obj_id == 'fix_pose':
                # A virtual target given as a bare 3-vector position is promoted to a 4x4
                # pose. Its rotation comes from calculate_rotation_matrix(direction, [0, 0, 1]).
                if len(obj.obj_pose) == 3:
                    rotation_matrix = calculate_rotation_matrix(obj.direction, [0, 0, 1])
                    pose = np.eye(4)
                    # asarray accepts plain lists as well as (3,) / (3, 1) arrays.
                    pose[:3, 3] = np.asarray(obj.obj_pose).reshape(-1)
                    pose[:3, :3] = rotation_matrix
                    obj.obj_pose = pose
                continue

            # TODO(unify part_name and obj_name)
            if '/' in obj_id:
                obj_name, part_name = obj_id.split('/')[:2]
                joint_state = self.robot.client.get_object_joint('/World/Objects/%s' % obj_name)
                for joint_name, joint_position, joint_velocity in zip(joint_state.joint_names,
                                                                      joint_state.joint_positions,
                                                                      joint_state.joint_velocities):
                    # Matches on the last character only; ambiguous once an object has >= 10 joints.
                    if joint_name[-1] == part_name[-1]:
                        obj.joint_position = joint_position
                        obj.joint_velocity = joint_velocity

            obj.obj_pose = self.robot.get_prim_world_pose('/World/Objects/%s' % obj_id)
            if hasattr(obj, 'info') and 'simple_place' in obj.info and obj.info['simple_place']:
                # NOTE(review): the homogeneous coordinate is 1, so the object's translation
                # enters this result. Expressing the pure world "down" direction in the object
                # frame would use [0, 0, -1, 0]. Confirm which one is intended.
                down_direction = (np.linalg.inv(obj.obj_pose) @ np.array([0, 0, -1, 1]))[:3]
                down_direction = down_direction / np.linalg.norm(down_direction) * 0.08
                obj.elements['active']['place']['direction'] = down_direction
        return objects

    def check_task_file(self, task_file):
        """Validate that a task file references declared objects and usable grasp data.

        Args:
            task_file: Path to the task description JSON.

        Returns:
            True when every stage's active/passive objects are declared, and every
            grasp/pick stage resolves to an existing, non-empty grasp-pose file.
            False otherwise; the reason is logged.
        """
        with open(task_file, "r") as f:
            task_info = json.load(f)

        objs_dir = {}
        objs_interaction = {}
        for obj_info in task_info["objects"]:
            obj_id = obj_info["object_id"]
            if obj_id == 'fix_pose':
                continue
            objs_dir[obj_id] = obj_info["data_info_dir"]
            if "interaction" in obj_info:
                objs_interaction[obj_id] = obj_info["interaction"]
            else:
                # NOTE(review): unlike the grasp file below, this path is not joined with
                # data_root, so a relative data_info_dir resolves against the CWD here.
                with open(obj_info["data_info_dir"] + '/interaction.json') as f:
                    objs_interaction[obj_id] = json.load(f)['interaction']

        data_root = os.path.dirname(os.path.dirname(__file__)) + "/assets"
        for stage in task_info['stages']:
            active_obj_id = stage['active']['object_id']
            passive_obj_id = stage['passive']['object_id']
            if active_obj_id != 'gripper' and active_obj_id not in objs_dir:
                Log.error('Active obj not in objs_dir: %s' % active_obj_id)
                return False
            if passive_obj_id not in ('gripper', 'fix_pose') and passive_obj_id not in objs_dir:
                Log.error('Passive obj not in objs_dir: %s' % passive_obj_id)
                return False

            if stage['action'] not in ('grasp', 'pick'):
                continue
            primitive = stage['passive']['primitive']
            if primitive is None:
                grasp_rel_path = 'grasp_pose/grasp_pose.pkl'
            else:
                try:
                    grasp_rel_path = objs_interaction[passive_obj_id]['passive']['grasp'][primitive]
                except KeyError:
                    Log.error('-- Grasp primitive not found: %s/%s' % (passive_obj_id, primitive))
                    return False
                if isinstance(grasp_rel_path, list):
                    grasp_rel_path = grasp_rel_path[0]
            grasp_file = os.path.join(data_root, objs_dir[passive_obj_id], grasp_rel_path)
            if not os.path.exists(grasp_file):
                Log.error('-- Grasp file not exist: %s' % grasp_file)
                return False
            # NOTE: pickle.load can execute arbitrary code; only use it on trusted asset files.
            with open(grasp_file, 'rb') as f:
                grasp_data = pickle.load(f)
            if len(grasp_data['grasp_pose']) == 0:
                Log.error('-- Grasp file empty: %s' % grasp_file)
                return False
        return True

    def run(self, task_list, camera_list, use_recording, workspaces, fps=10, render_semantic=False):
        """Execute every task file in ``task_list`` end to end.

        For each valid task: reset, build the layout, optionally start recording, run
        all action stages, then report the outcome with ``SendTaskStatus``. Task files
        that fail ``check_task_file`` are logged and skipped.

        Args:
            task_list: Paths of task JSON files.
            camera_list: Camera prim paths to record.
            use_recording: Whether to start a recording for each task.
            workspaces: Mapping of workspace name -> dict with 'position' and 'size'
                (indices 0 and 1 are read), used when relocating objects.
            fps: Recording frame rate.
            render_semantic: Whether to record semantic images.

        Returns:
            Always True.
        """
        for task_file in task_list:
            if not self.check_task_file(task_file):
                Log.error("Task file bad: %s" % task_file)
                continue
            Log.info("start task: " + task_file)
            self.reset()
            self.attached_obj_id = None
            self.generate_layout(task_file)
            self.robot.open_gripper(id='right')
            self.robot.open_gripper(id='left')
            # Initial end-effector poses; 'reset' action stages move back to these.
            self.robot.reset_pose = {
                'right': self.robot.get_ee_pose(ee_type='gripper', id='right'),
                'left': self.robot.get_ee_pose(ee_type='gripper', id='left'),
            }

            with open(task_file, 'rb') as f:
                task_info = json.load(f)
            stages, objects = load_task_solution(task_info)
            objects = self.update_objects(objects)
            split_stages = split_grasp_stages(stages)

            if use_recording:
                # TODO: recording condition check
                self.start_recording(task_name="[%s]" % (os.path.basename(task_file).split(".")[0]),
                                     camera_prim_list=camera_list, fps=fps,
                                     render_semantic=render_semantic)

            success, stage_id = self._run_stage_groups(split_stages, objects, workspaces)

            time.sleep(0.5)
            if self.attached_obj_id is None:
                self.robot.client.DetachObj()
            # NOTE(review): called even when use_recording is False; confirm the client tolerates it.
            self.robot.client.stop_recording()
            # Failure report: [global action-stage index, sub-step]. The sub-step is not tracked (-1).
            fail_stage_step = [-1, -1] if success else [stage_id, -1]
            self.robot.client.SendTaskStatus(success, fail_stage_step)
            if success:
                print(">>>>>>>>>>>>>>>> Success!!!!!!!!!!!!!!!!")
        return True

    def _run_stage_groups(self, split_stages, objects, workspaces):
        """Execute grouped task stages in order, stopping at the first failure.

        Args:
            split_stages: Stage groups as returned by ``split_grasp_stages``.
            objects: Object-state mapping from ``load_task_solution``; refreshed in place.
            workspaces: Workspace definitions forwarded to the relocation logic.

        Returns:
            ``(success, stage_id)``: the final success value, and the global index of
            the last action stage started (-1 if none started).
        """
        stage_id = -1
        success = False
        for _stages in split_stages:
            extra_params = _stages[0].get('extra_params', {})
            arm = extra_params.get('arm', 'right')
            active_id = _stages[0]['active']['object_id']
            passive_id = _stages[0]['passive']['object_id']
            action_stages = generate_action_stages(objects, _stages, self.robot)
            if len(action_stages) == 0:
                Log.error('No action stage generated.')
                return False, stage_id

            success = True
            for local_idx, (action, substages) in enumerate(action_stages):
                stage_id += 1
                Log.info(f'start action stage: {action} ({stage_id}/{len(action_stages)})')
                # stage_id counts across all groups, so this group's stages are indexed
                # locally. Indexing _stages with stage_id overruns it from the second group on.
                if local_idx < len(_stages):
                    active_id = _stages[local_idx]['active']['object_id']
                    passive_id = _stages[local_idx]['passive']['object_id']

                if action == 'reset':
                    success = self.robot.move_pose(self.robot.reset_pose[arm], type='AvoidObs', arm=arm,
                                                   block=True)
                    continue

                while len(substages):
                    objects = self.update_objects(objects, arm=arm)
                    # get_action also returns an arm, but the group's configured arm takes precedence.
                    target_gripper_pose, motion_type, gripper_action, _ = substages.get_action(objects)
                    self._set_frame_state(action, substages.step_id, active_id, passive_id, arm,
                                          target_gripper_pose)
                    if target_gripper_pose is not None:
                        self.robot.move_pose(target_gripper_pose, motion_type, arm=arm, block=True)

                    self._set_frame_state(action, substages.step_id, active_id, passive_id, arm,
                                          target_gripper_pose,
                                          gripper_action == 'open', gripper_action == 'close')
                    self.robot.set_gripper_action(gripper_action, arm=arm)
                    if gripper_action == 'open':
                        time.sleep(1)
                    self._set_frame_state(action, substages.step_id, active_id, passive_id, arm,
                                          target_gripper_pose)

                    # Check sub-stage completion against freshly queried poses
                    # (update_objects also refreshes the gripper pose).
                    objects = self.update_objects(objects, arm=arm)
                    success = substages.check_completion(objects)
                    self._set_frame_state(action, substages.step_id, active_id, passive_id, arm,
                                          target_gripper_pose)
                    # `== False` (rather than `not success`) keeps the original handling of
                    # non-bool completion results.
                    if success == False:  # noqa: E712
                        self.attached_obj_id = None
                        Log.error('Failed at sub-stage %d' % substages.step_id)
                        break

                    # Track which object is held.
                    # TODO: avoid articulated objects; confirm a 'close' really is a grasp.
                    if gripper_action == 'close':
                        self.attached_obj_id = substages.passive_obj_id
                    elif gripper_action == 'open':
                        self.attached_obj_id = None
                    self._set_frame_state(action, substages.step_id, active_id, passive_id, arm,
                                          target_gripper_pose)

                    self._relocate_against_objects(action, gripper_action, substages.extra_params,
                                                   passive_id, objects, workspaces, target_gripper_pose)

                if success == False:  # noqa: E712
                    self.attached_obj_id = None
                    break
                # Re-attach the held object to the gripper through the client,
                # except for parts of articulated objects.
                if (self.attached_obj_id is not None
                        and self.attached_obj_id.split('/')[0] not in self.articulated_objs):
                    self.robot.client.DetachObj()
                    self.robot.client.AttachObj(prim_paths=['/World/Objects/' + self.attached_obj_id],
                                                is_right=arm == 'right')
            if success == False:  # noqa: E712
                break
        return success, stage_id

    def _set_frame_state(self, action, step_id, active_id, passive_id, arm, target_pose, *gripper_flags):
        """Forward the current execution state to ``client.set_frame_state``.

        The "object attached" flag is read from ``self.attached_obj_id`` at call time.
        ``gripper_flags`` are the optional (open, close) booleans, passed positionally.
        """
        self.robot.client.set_frame_state(action, step_id, active_id, passive_id,
                                          self.attached_obj_id is not None, *gripper_flags,
                                          arm=arm, target_pose=target_pose)

    def _relocate_against_objects(self, action, gripper_action, extra_params, passive_id, objects,
                                  workspaces, target_gripper_pose):
        """Teleport objects after a sub-step when the stage requests 'against' placement.

        Does nothing unless ``extra_params['against'] > 0``. There are two modes:

        * Pick release with ``against_type`` set: shift the passive object by one shared
          random +y offset in [0, 0.2]. If ``against_type`` is ``move_with_<tag>``, every
          object whose id contains ``<tag>`` is shifted by the same offset.
        * Otherwise, on a pick release or a place grasp: move the passive object to a
          random x/y inside a randomly chosen workspace from ``against_range``. The point
          is kept at least ``_MIN_RELOCATION_GRIPPER_DIST`` from the gripper in x/y.
        """
        num_against = extra_params.get('against', 0)
        if not num_against > 0:
            return
        against_area = extra_params.get('against_range', [])
        against_type = extra_params.get('against_type', None)
        is_pick_release = gripper_action == 'open' and action == 'pick'
        is_place_grasp = action == 'place' and gripper_action == 'close'

        if is_pick_release and against_type is not None:
            parts = against_type.split('_')
            need_move_objects = [passive_id]
            if len(parts) >= 3 and parts[0] == 'move' and parts[1] == 'with':
                need_move_objects += [obj_key for obj_key in objects
                                      if parts[2] in obj_key and obj_key not in need_move_objects]
            # For now all selected objects move together in one direction.
            offset_y = random.uniform(0, 0.2)
            poses = []
            for obj_key in need_move_objects:
                pos, quat_wxyz = self._query_object_pose(obj_key)
                pos[1] += offset_y
                poses.append(self._make_pose_entry(obj_key, pos, quat_wxyz))
            Log.info('Relocating objects: %s' % poses)
            self.robot.client.SetObjectPose(poses, [])
            self.robot.client.DetachObj()
            return

        if not (is_pick_release or is_place_grasp):
            return
        if not against_area:
            Log.error("against_range not set")
            return
        selected_against_area = random.choice(against_area)
        if selected_against_area not in workspaces:
            # Previously the +-999 sentinel bounds were used here, flinging the object out of the scene.
            Log.error('against_range workspace not found: %s' % selected_against_area)
            return

        position = workspaces[selected_against_area]['position']
        size = workspaces[selected_against_area]['size']
        # Shrink the workspace by 1.5x the half-diagonal of the object's footprint (the two
        # axes orthogonal to its up axis). Presumably info['size'] is in millimetres, so
        # /2000 gives half-extents in metres. TODO confirm units.
        x_size, y_size, z_size = objects[passive_id].info['size']
        up_axis = objects[passive_id].info['upAxis']
        axis_mapping = {
            'x': (y_size / 2000.0, z_size / 2000.0),
            'y': (x_size / 2000.0, z_size / 2000.0),
            'z': (x_size / 2000.0, y_size / 2000.0),
        }
        margin = np.linalg.norm(axis_mapping[up_axis[0]]) * 1.5
        x_min = position[0] - size[0] / 2 + margin
        x_max = position[0] + size[0] / 2 - margin
        y_min = position[1] - size[1] / 2 + margin
        y_max = position[1] + size[1] / 2 - margin

        pos, quat_wxyz = self._query_object_pose(passive_id)
        pos[2] += 0.02  # respawn slightly above the current height

        # Keep the new spot away from the gripper. Fall back to the measured gripper pose
        # when no target pose was given.
        ref_pose = target_gripper_pose if target_gripper_pose is not None else objects['gripper'].obj_pose
        for _ in range(self._MAX_RELOCATION_ATTEMPTS):
            pos[0] = random.uniform(x_min, x_max)
            pos[1] = random.uniform(y_min, y_max)
            distance = np.linalg.norm([pos[0] - ref_pose[0][3], pos[1] - ref_pose[1][3]])
            if distance >= self._MIN_RELOCATION_GRIPPER_DIST:
                break
        else:
            Log.error('No relocation point far enough from the gripper; using the last sample')

        poses = [self._make_pose_entry(passive_id, pos, quat_wxyz)]
        Log.info('Relocating objects: %s' % poses)
        self.robot.client.SetObjectPose(poses, [])
        self.robot.client.DetachObj()

    def _query_object_pose(self, obj_id):
        """Return ``([x, y, z], np.array([w, x, y, z]))`` for prim ``/World/Objects/<obj_id>``."""
        response = self.robot.client.get_object_pose(f'/World/Objects/{obj_id}')
        position = response.object_pose.position
        # Despite its name, the ``rpy`` field carries a quaternion (rw, rx, ry, rz).
        rot = response.object_pose.rpy
        return [position.x, position.y, position.z], np.array([rot.rw, rot.rx, rot.ry, rot.rz])

    @staticmethod
    def _make_pose_entry(obj_id, position, quat_wxyz):
        """Build one ``SetObjectPose`` entry for prim ``/World/Objects/<obj_id>``."""
        return {
            "prim_path": f'/World/Objects/{obj_id}',
            "position": position,
            "rotation": quat_wxyz,
        }