Files
gen_data_agent/task_gen_dependencies/solver_2d.py
2025-09-05 11:10:42 +08:00

363 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from scipy.spatial.transform import Rotation as R
from task_gen_dependencies.layout_2d import DFS_Solver_Floor
from task_gen_dependencies.multi_add_util import *
from task_gen_dependencies.utils import axis_to_quaternion, quaternion_rotate, get_rotation_matrix_from_quaternion, \
get_quaternion_from_rotation_matrix, get_xyz_euler_from_quaternion
from pyboot.utils.log import Log
def rotate_along_axis(target_affine, angle_degrees, rot_axis='z', use_local=True):
    """Rotate the orientation part of a 4x4 affine transform about one axis.

    Args:
        target_affine: 4x4 affine transform; only its upper-left 3x3 rotation
            block and its translation column are read.
        angle_degrees: Rotation angle in degrees.
        rot_axis: Axis to rotate about, one of 'x', 'y' or 'z'.
        use_local: If True the extra rotation is right-multiplied onto the
            current rotation (rotation about the transform's own axis); if
            False it is left-multiplied (rotation about the fixed frame axis).

    Returns:
        A new 4x4 matrix with the rotated orientation and the original
        translation. The input matrix is not modified.

    Raises:
        ValueError: If rot_axis is not 'x', 'y' or 'z'.
    """
    theta = np.deg2rad(angle_degrees)

    axis_names = ('x', 'y', 'z')
    if rot_axis not in axis_names:
        raise ValueError("Invalid rotation axis. Please choose from 'x', 'y', 'z'.")

    # Rotation vector: the angle sits on the chosen axis, zeros elsewhere.
    rotvec = np.array([theta if name == rot_axis else 0 for name in axis_names])
    delta = R.from_rotvec(rotvec).as_matrix()

    current = target_affine[:3, :3]
    rotated = np.dot(current, delta) if use_local else np.dot(delta, current)

    # Reassemble a fresh affine matrix, keeping the original translation.
    result = np.eye(4)
    result[:3, 3] = target_affine[:3, 3]
    result[:3, :3] = rotated
    return result
def quaternion_multiply(q1, q2):
    """Return the Hamilton product q1 * q2 of two [w, x, y, z] quaternions.

    Args:
        q1: Left operand, a 4-element sequence in [w, x, y, z] order.
        q2: Right operand, a 4-element sequence in [w, x, y, z] order.

    Returns:
        numpy array [w, x, y, z] holding the product.
    """
    a_w, a_x, a_y, a_z = q1
    b_w, b_x, b_y, b_z = q2
    product = (
        a_w * b_w - a_x * b_x - a_y * b_y - a_z * b_z,
        a_w * b_x + a_x * b_w + a_y * b_z - a_z * b_y,
        a_w * b_y - a_x * b_z + a_y * b_w + a_z * b_x,
        a_w * b_z + a_x * b_y - a_y * b_x + a_z * b_w,
    )
    return np.array(product)
def quaternion_rotate_z(quaternion, angle):
    """Apply an extra rotation about the global z-axis to a quaternion.

    Args:
        quaternion: Input quaternion as a numpy array [w, x, y, z].
        angle: Rotation angle in degrees.

    Returns:
        numpy array [w, x, y, z]: the z-rotation quaternion left-multiplied
        onto the input quaternion.
    """
    half_angle = np.radians(angle) / 2
    # Unit quaternion of a rotation about z: [cos(a/2), 0, 0, sin(a/2)].
    z_rotation = np.array([np.cos(half_angle), 0, 0, np.sin(half_angle)])
    # Left multiplication applies the rotation in the global frame.
    return quaternion_multiply(z_rotation, quaternion)
def rotate_point_ext(px, py, angle, ox, oy):
    """Rotate the point (px, py) about the pivot (ox, oy).

    Args:
        px, py: Coordinates of the point to rotate.
        angle: Rotation angle in radians, applied with the standard 2D
            rotation matrix [[cos, -sin], [sin, cos]].
        ox, oy: Coordinates of the pivot.

    Returns:
        Tuple (x, y) with the rotated coordinates.
    """
    sin_a = math.sin(angle)
    cos_a = math.cos(angle)
    # Work relative to the pivot, rotate, then shift back.
    dx = px - ox
    dy = py - oy
    return dx * cos_a - dy * sin_a + ox, dx * sin_a + dy * cos_a + oy
def get_corners(pose, size, angle):
    """Return the four corners of a rectangle rotated about its centre.

    Args:
        pose: (cx, cy) centre of the rectangle.
        size: (w, h) full width and height of the rectangle.
        angle: Rotation angle, passed unchanged to rotate_point_ext.

    Returns:
        List of four (x, y) tuples. Before rotation the order is
        (-w/2, -h/2), (+w/2, -h/2), (+w/2, +h/2), (-w/2, +h/2) relative to
        the centre.
    """
    cx, cy = pose
    w, h = size
    half_w, half_h = w / 2, h / 2
    left, right = cx - half_w, cx + half_w
    bottom, top = cy - half_h, cy + half_h
    unrotated = ((left, bottom), (right, bottom), (right, top), (left, top))
    return [rotate_point_ext(x, y, angle, cx, cy) for x, y in unrotated]
def compute_bounding_box(objects, expansion=40):
    """Compute the axis-aligned bounding box enclosing several rotated rectangles.

    Args:
        objects: Iterable of (pose, size, angle) triples, each forwarded to
            get_corners.
        expansion: Amount added to both the width and the height of the box.

    Returns:
        Tuple (center_x, center_y, width, height, 0); the trailing 0 is the
        box angle.

    Raises:
        ValueError: If objects is empty (min/max of an empty sequence).
    """
    xs, ys = [], []
    for pose, size, angle in objects:
        for corner_x, corner_y in get_corners(pose, size, angle):
            xs.append(corner_x)
            ys.append(corner_y)

    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)

    return (
        (min_x + max_x) / 2,
        (min_y + max_y) / 2,
        max_x - min_x + expansion,
        max_y - min_y + expansion,
        0,
    )
def compute_intersection(bbox, plane_center, plane_width, plane_height):
    """Intersect an axis-aligned bounding box with an axis-aligned plane rectangle.

    Args:
        bbox: Tuple (center_x, center_y, width, height, angle); the angle
            entry is ignored.
        plane_center: Indexable whose first two entries are the plane centre.
        plane_width: Full extent of the plane along x.
        plane_height: Full extent of the plane along y.

    Returns:
        Tuple (center_x, center_y, width, height, 0) describing the overlap,
        or None when the overlap has no positive area.
    """
    box_cx, box_cy, box_w, box_h, _ = bbox
    plane_cx, plane_cy = plane_center[0], plane_center[1]

    # Overlap edges: the innermost of the two rectangles' edges on each side.
    left = max(box_cx - box_w / 2, plane_cx - plane_width / 2)
    right = min(box_cx + box_w / 2, plane_cx + plane_width / 2)
    bottom = max(box_cy - box_h / 2, plane_cy - plane_height / 2)
    top = min(box_cy + box_h / 2, plane_cy + plane_height / 2)

    # Touching edges or disjoint rectangles count as "no intersection".
    if not (left < right and bottom < top):
        return None

    return ((left + right) / 2, (bottom + top) / 2, right - left, top - bottom, 0)
class LayoutSolver2D:
    """Place objects on a rectangular 2D workspace plane.

    Intended pipeline (from the original notes):
    1. Get the room vertices.
    2. Generate constraints with an LLM.
    3. Generate a layout that meets the constraints.
    """

    def __init__(self, workspace_xyz, workspace_size, objects=None, fix_obj_ids=None, obj_infos=None, angle_step=15):
        """Store the workspace geometry and the objects to lay out.

        Args:
            workspace_xyz: Workspace centre (x, y, z); every coordinate is
                multiplied by 1000 (presumably metres to millimetres --
                confirm with callers).
            workspace_size: Array of three extents supporting element-wise
                division (e.g. a numpy array). It is used without scaling, so
                it is presumably already in the scaled unit -- confirm.
            objects: Mapping from object id to an object exposing ``size``,
                ``up_axis`` and ``obj_pose``.
            fix_obj_ids: Ids of objects ignored when building the obstacle
                region in ``__call__``. Defaults to a fresh empty list.
            obj_infos: Mapping from object id to a dict that may carry an
                ``"extent"`` entry. Defaults to a fresh empty dict.
            angle_step: Step in degrees between candidate rotation angles.
        """
        self.angle_step = angle_step
        x_half, y_half, z_half = workspace_size / 2
        # Workspace footprint centred on the origin, listed corner by corner.
        self.room_vertices = [
            [-x_half, -y_half],
            [x_half, -y_half],
            [x_half, y_half],
            [-x_half, y_half],
        ]
        self.plane_width = workspace_size[0]
        self.plane_height = workspace_size[1]
        self.objects = objects
        # Fresh containers per instance: mutable defaults in the signature
        # would be shared by every solver created without these arguments.
        self.obj_infos = {} if obj_infos is None else obj_infos
        self.fix_obj_ids = [] if fix_obj_ids is None else fix_obj_ids
        self.cx, self.cy, self.cz = (coord * 1000 for coord in workspace_xyz)
        self.workspace_Z_half = z_half
        # Constant added to the z coordinate of every placed object.
        self.z_offset = 20

    def parse_solution(self, solutions, obj_id):
        """Convert a 2D solution into a 4x4 pose and store it on the object.

        Args:
            solutions: Mapping from object id to a sequence whose first three
                entries are ``[center_x, center_y]``, the rotation, and an
                ignored value.
            obj_id: Id of the object whose pose is written to
                ``self.objects[obj_id].obj_pose``.
        """
        [obj_cx, obj_cy], rotation, _ = solutions[obj_id][:3]
        target = self.objects[obj_id]
        # x/y are offsets from the workspace centre; z puts the object's
        # centre half its height above the workspace bottom, plus z_offset.
        obj_xyz = np.array([
            self.cx + obj_cx,
            self.cy + obj_cy,
            self.cz - self.workspace_Z_half + target.size[2] / 2.0 + self.z_offset,
        ])
        # NOTE(review): presumably aligns the object's up axis with world z and
        # then spins it about that axis by -rotation -- confirm in utils.
        init_quat = axis_to_quaternion(target.up_axis, "z")
        obj_quat = quaternion_rotate(init_quat, target.up_axis, -rotation)
        obj_pose = np.eye(4)
        obj_pose[:3, :3] = get_rotation_matrix_from_quaternion(obj_quat)
        obj_pose[:3, 3] = obj_xyz
        target.obj_pose = obj_pose

    def old_solution(self,
                     opt_obj_ids,
                     exist_obj_ids,
                     object_extent=50,  # 5 cm outward padding
                     start_with_edge=False,
                     grid_size=0.01  # 1 cm
                     ):
        """Legacy placement routine built on ``DFS_Solver_Floor``.

        Args:
            opt_obj_ids: Ids of the objects to place, in placement order.
            exist_obj_ids: Unused.
            object_extent: Padding added to each object's x/y size.
            start_with_edge: If True, candidate solutions are passed through
                ``solver.place_edge`` first.
            grid_size: Grid spacing; multiplied by 1000 and truncated to int
                before being handed to the solver.

        Returns:
            List of ids of the objects that were placed.
        """
        solver = DFS_Solver_Floor(grid_size=int(grid_size * 1000))
        room_poly = Polygon(self.room_vertices)
        grid_points = solver.create_grids(room_poly)
        objs_succ = []
        saved_solutions = {}  # TODO: store the poses of exist_obj_ids in saved_solutions
        for obj_id in opt_obj_ids:
            size = self.objects[obj_id].size
            size_extent = size[:2] + object_extent
            solutions = solver.get_all_solutions(room_poly, grid_points, size_extent)
            if len(solutions) == 0:
                print(f"No valid solutions for apply {obj_id}.")
                continue
            if start_with_edge:
                solutions = solver.place_edge(room_poly, solutions, size_extent)
            if len(saved_solutions) != 0:
                # Later objects must not collide with the ones already placed.
                solutions = solver.filter_collision(saved_solutions, solutions)
                if len(solutions) == 0:
                    print(f"No valid solutions for apply {obj_id}.")
                    continue
            saved_solutions[obj_id] = random.choice(solutions)
            self.parse_solution(saved_solutions, obj_id)
            objs_succ.append(obj_id)
        return objs_succ

    def __call__(self,
                 opt_obj_ids,
                 exist_obj_ids,
                 object_extent=50,  # 5 cm outward padding
                 start_with_edge=False,
                 key_obj=True,
                 grid_size=0.01,  # 1 cm
                 initial_angle=0
                 ):
        """Place the requested objects by random sampling on a grid.

        For each object up to ``attempts`` candidate poses are tried. A
        candidate is usable when ``is_within_bounds`` accepts it and
        ``is_collision`` reports no collision with the placed objects. When
        nothing has been placed yet the first usable candidate is taken;
        otherwise the usable candidate with the largest minimum distance to
        the placed objects (and above the initial threshold of 1) wins.
        Each placed object finally receives a random z rotation within
        +/- ``angle_step / 2`` degrees.

        Args:
            opt_obj_ids: Ids of the objects to place, in placement order.
            exist_obj_ids: Unused.
            object_extent: Default padding added to an object's x/y size when
                its ``obj_infos`` entry has no ``"extent"``.
            start_with_edge: Unused.
            key_obj: If False, the objects that are neither in
                ``opt_obj_ids`` nor in ``fix_obj_ids`` are merged into one
                bounding box that is treated as an already placed obstacle,
                and the default padding becomes 0.
            grid_size: Grid spacing; multiplied by 1000 and truncated to int.
            initial_angle: Unused.

        Returns:
            List of ids of the objects that were placed. Their poses are
            written to ``self.objects[obj_id].obj_pose``.
        """
        objs_succ = []
        fail_objs = []
        placed_objects = []

        if not key_obj:
            main_objects = []
            for obj_id in self.objects:
                if obj_id in opt_obj_ids or obj_id in self.fix_obj_ids:
                    continue
                size = self.objects[obj_id].size
                pose = self.objects[obj_id].obj_pose
                pose_quat = get_quaternion_from_rotation_matrix(pose[:3, :3])
                # Third euler component, i.e. the rotation about z.
                angle_pa = get_xyz_euler_from_quaternion(pose_quat)[2]
                main_objects.append((
                    (pose[0, 3] - self.cx, pose[1, 3] - self.cy),
                    (size[0], size[1]),
                    angle_pa,
                ))
            # NOTE(review): compute_bounding_box raises ValueError when
            # main_objects is empty -- confirm callers never hit that case.
            main_bounding_box_info = compute_bounding_box(main_objects, expansion=object_extent)
            intersection = compute_intersection(main_bounding_box_info, (0, 0), self.plane_width, self.plane_height)
            if intersection is None:
                return objs_succ
            # The part of the merged box inside the plane acts as an obstacle.
            placed_objects.append((intersection, get_rotated_corners(*intersection)))
            object_extent = 0

        grid_points = create_grid(self.plane_width, self.plane_height, int(grid_size * 1000))
        saved_solutions = {}
        single_object = len(opt_obj_ids) == 1
        attempts = 800 if single_object else 400

        for obj_id in opt_obj_ids:
            size = self.objects[obj_id].size
            # Objects without an info entry fall back to the default padding
            # instead of raising KeyError.
            _extent = self.obj_infos.get(obj_id, {}).get("extent", object_extent)
            size_extent = size[:2] + _extent
            width, height = size_extent[0], size_extent[1]
            area_ratio = (width * height) / (self.plane_width * self.plane_height)

            valid_position = False
            best_position = None
            max_distance = 1
            available_grid_points = filter_occupied_grids(grid_points, placed_objects)
            for idx in range(attempts):
                if not available_grid_points:
                    break
                if single_object and area_ratio > 0.5:
                    # A single large object: scan the free grid points in
                    # order and only try right-angle rotations.
                    if idx >= len(available_grid_points):
                        break
                    x, y = available_grid_points[idx]
                    angle = random.choice([0, 90, 180, 270])
                else:
                    x, y = random.choice(available_grid_points)
                    angle = np.random.choice(np.arange(0, 360, self.angle_step))

                new_corners = get_rotated_corners(x, y, width, height, angle)
                if not is_within_bounds(new_corners, self.plane_width, self.plane_height):
                    continue
                if is_collision(new_corners, placed_objects):
                    continue
                if not placed_objects:
                    # Nothing to keep away from: take the first valid pose.
                    best_position = (x, y, width, height, angle)
                    valid_position = True
                    break
                min_distance = min(calculate_distance(new_corners, obj[1]) for obj in placed_objects)
                if min_distance > max_distance:
                    # Keep sampling; prefer the pose farthest from the others.
                    max_distance = min_distance
                    best_position = (x, y, width, height, angle)
                    valid_position = True

            if valid_position:
                placed_objects.append((best_position, get_rotated_corners(*best_position)))
                saved_solutions[obj_id] = [[best_position[0], best_position[1]], best_position[4], None]
                self.parse_solution(saved_solutions, obj_id)
                objs_succ.append(obj_id)
            else:
                fail_objs.append(obj_id)

        if fail_objs:
            Log.warning("failed objects in layout 2d: " + str(fail_objs))

        # Jitter each placed object about the global z-axis so the final
        # orientations are not locked to multiples of angle_step.
        half_step = self.angle_step / 2.0
        for obj_id in objs_succ:
            self.objects[obj_id].obj_pose = rotate_along_axis(
                self.objects[obj_id].obj_pose,
                random.uniform(-half_step, half_step),
                rot_axis='z',
                use_local=False,
            )
        return objs_succ