add quad mesh support to usd parser

This commit is contained in:
Balakumar Sundaralingam
2025-08-04 12:24:44 -07:00
parent cca894de9e
commit 1c6f3a0742

View File

@@ -11,7 +11,7 @@
# Standard Library
import math
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Tuple
# Third Party
import numpy as np
@@ -55,6 +55,44 @@ except ImportError:
+ " NOTE: Do not install this if using with ISAAC SIM."
)
# Quad mesh triangulation warp kernel
try:
import warp as wp # type: ignore
_WARP_AVAILABLE = True
@wp.kernel
def _triangulate_quads_kernel(
verts: wp.array1d(dtype=wp.vec3),
quads: wp.array1d(dtype=wp.vec4i),
tris_out: wp.array1d(dtype=wp.vec3i),
):
ind = wp.tid()
ind_out = 2 * ind # each quad as 2 triangles in the output buffer
quad_inds = quads[ind]
q0, q1, q2, q3 = quad_inds[0], quad_inds[1], quad_inds[2], quad_inds[3]
v0 = verts[q0]
v1 = verts[q1]
v2 = verts[q2]
v3 = verts[q3]
# Two possible triangulations
nA1 = wp.normalize(wp.cross(v1 - v0, v2 - v0))
nA2 = wp.normalize(wp.cross(v2 - v0, v3 - v0))
nB1 = wp.normalize(wp.cross(v3 - v1, v0 - v1))
nB2 = wp.normalize(wp.cross(v2 - v1, v3 - v1))
if wp.dot(nA1, nA2) > wp.dot(nB1, nB2):
tris_out[ind_out] = wp.vec3i(q0, q1, q2)
tris_out[ind_out + 1] = wp.vec3i(q0, q2, q3)
else:
tris_out[ind_out] = wp.vec3i(q1, q3, q0)
tris_out[ind_out + 1] = wp.vec3i(q1, q2, q3)
except ImportError: # pragma: no cover Warp not available
_WARP_AVAILABLE = False
def set_prim_translate(prim, translation):
UsdGeom.Xformable(prim).AddTranslateOp().Set(Gf.Vec3d(translation))
@@ -348,6 +386,84 @@ def get_sphere_attrs(prim, cache=None, transform=None) -> Sphere:
return Sphere(name=str(prim.GetPath()), pose=pose, radius=radius, position=pose[:3])
def triangulate_mesh_faces(
vertices: List[List[float]],
faces: List[int],
face_count: List[int],
) -> Tuple[List[int], List[int]]:
# Triangulate mesh faces. Returns a flattened index buffer and a face-count list (all 3s).
if not faces or not face_count:
raise ValueError("Empty face data provided")
expected_vertices = sum(face_count)
if len(faces) != expected_vertices:
raise ValueError(
f"Face data inconsistent: {len(faces)} vertices but face_count sums to {expected_vertices}"
)
if any(c < 3 for c in face_count):
raise ValueError("Found a face with < 3 vertices")
tri_count = sum(1 for c in face_count if c == 3)
quad_count = sum(1 for c in face_count if c == 4)
other_count = len(face_count) - tri_count - quad_count
log_info(f"[Triangulation] triangles={tri_count}, quads={quad_count}, others={other_count}")
if other_count == 0 and quad_count == 0:
return faces, face_count
new_faces: List[int] = []
quad_list: List[List[int]] = []
face_idx = 0
for count in face_count:
if count == 3:
new_faces.extend(faces[face_idx : face_idx + 3])
elif count == 4:
quad_list.append(faces[face_idx : face_idx + 4])
else:
v0 = faces[face_idx]
for i in range(1, count - 1):
new_faces.extend([v0, faces[face_idx + i], faces[face_idx + i + 1]])
face_idx += count
if quad_list:
if _WARP_AVAILABLE:
log_info(
f"[Triangulation] Using Warp GPU kernel on {len(quad_list)} quads",
)
import numpy as _np
verts_np = _np.asarray(vertices, dtype=_np.float32)
quads_np = _np.asarray(quad_list, dtype=_np.int32)
verts_wp = wp.array(verts_np, dtype=wp.vec3, device="cuda")
quads_wp = wp.array(quads_np, dtype=wp.vec4i, device="cuda")
tris_wp = wp.empty(shape=(quads_wp.shape[0] * 2,), dtype=wp.vec3i, device="cuda")
wp.launch(
_triangulate_quads_kernel,
dim=quads_wp.shape[0],
inputs=[verts_wp, quads_wp],
outputs=[tris_wp],
)
tris_np = tris_wp.numpy().astype(_np.int64).reshape(-1, 3)
new_faces.extend(tris_np.flatten().tolist())
else:
log_warn(
f"[Triangulation] Warp not available, CPU splitting {len(quad_list)} quads",
)
# CPU fallback split along fixed diagonal (v0-v2)
for q in quad_list:
new_faces.extend([q[0], q[1], q[2], q[0], q[2], q[3]])
triangle_count = len(new_faces) // 3
return new_faces, [3] * triangle_count
def get_mesh_attrs(prim, cache=None, transform=None) -> Mesh:
# read cube information
# scale = prim.GetAttribute("size").Get()
@@ -360,13 +476,13 @@ def get_mesh_attrs(prim, cache=None, transform=None) -> Mesh:
face_count = list(prim.GetAttribute("faceVertexCounts").Get())
# assume faces are 3:
if len(faces) / 3 != len(face_count):
log_warn(
"Mesh faces "
+ str(len(faces) / 3)
+ " are not matching faceVertexCounts "
+ str(len(face_count))
)
return None
log_warn(f"Mesh {prim.GetPath()} has non-triangular faces, triangulating...")
try:
faces, face_count = triangulate_mesh_faces(points, faces, face_count)
log_info(f"Triangulated mesh {prim.GetPath()} with {len(faces)} faces")
except ValueError as e:
log_error(f"Failed to triangulate mesh {prim.GetPath()}: {e}")
return None
faces = np.array(faces).reshape(len(face_count), 3).tolist()
if prim.GetAttribute("xformOp:scale").IsValid():
scale = list(prim.GetAttribute("xformOp:scale").Get())