Skip to content

Commit e422cc1

Browse files
Implemented Möller-Trumbore intersection algorithm for differentiable ray triangle intersection
PiperOrigin-RevId: 551483422
1 parent 1b0203e commit e422cc1

File tree

2 files changed

+124
-0
lines changed

2 files changed

+124
-0
lines changed

tensorflow_graphics/geometry/representation/ray.py

+72
Original file line numberDiff line numberDiff line change
@@ -451,5 +451,77 @@ def intersection_ray_sphere(sphere_center,
451451
return intersections_points, normals
452452

453453

454+
def intersection_ray_triangle(
455+
ray_org,
456+
ray_dir,
457+
triangles,
458+
epsilon=1e-8,
459+
name="ray_intersection_ray_triangle",
460+
):
461+
"""Möller-Trumbore intersection algorithm.
462+
463+
Simultaneously computes barycentric coordinates and distance to intersections
464+
of ray to planes defined by triangles. Uses epsilon to detect and ignore
465+
numerically unstable cases, returning all zeros instead. No attempt is made
466+
to ensure that intersections are contained within each triangle.
467+
468+
Note:
469+
In the following, A1 to An are optional batch dimensions.
470+
471+
Args:
472+
ray_org: A tensor of shape `[A1, ..., An, 3]`,
473+
where the last dimension represents the 3D position of the ray origin.
474+
ray_dir: A tensor of shape `[A1, ..., An, 3]`, where
475+
the last dimension represents the normalized 3D direction of the ray.
476+
triangles: A tensor of shape `[A1, ..., An, 3, 3]`, containing batches of
477+
triangles represented using 3 vertices, where the last dimension
478+
represents the 3D position of each vertex.
479+
epsilon: Epsilon value use to detect and ignore degenerate cases.
480+
name: A name for this op that defaults to "ray_intersection_ray_triangle"
481+
482+
Returns:
483+
A tensor of shape `[A1, ..., An, 3]` representing the barycentric
484+
coordinates of each intersection location, and a tensor of shape
485+
`[A1, ..., An]` containing the distance of each ray origin to the
486+
intersection location
487+
"""
488+
with tf.name_scope(name):
489+
ray_org = tf.convert_to_tensor(value=ray_org)
490+
ray_dir = tf.convert_to_tensor(value=ray_dir)
491+
triangles = tf.convert_to_tensor(value=triangles)
492+
493+
shape.check_static(
494+
tensor=ray_org, tensor_name="ray_org", has_dim_equals=(-1, 3))
495+
shape.check_static(
496+
tensor=ray_dir, tensor_name="ray_dir", has_dim_equals=(-1, 3))
497+
shape.check_static(
498+
tensor=triangles,
499+
tensor_name="triangles",
500+
has_dim_equals=[(-2, 3), (-1, 3)],
501+
)
502+
503+
shape.compare_batch_dimensions(
504+
(ray_org, ray_dir, triangles), (-2, -2, -3),
505+
broadcast_compatible=False)
506+
507+
e1 = triangles[..., 1, :] - triangles[..., 0, :]
508+
e2 = triangles[..., 2, :] - triangles[..., 0, :]
509+
s = ray_org - triangles[..., 0, :]
510+
h = tf.linalg.cross(ray_dir, e2)
511+
q = tf.linalg.cross(s, e1)
512+
a = vector.dot(h, e1, keepdims=False)
513+
invalid = tf.abs(a) < epsilon
514+
denom = tf.where(invalid, tf.zeros_like(a), tf.math.divide_no_nan(1.0, a))
515+
516+
t = denom * vector.dot(q, e2, keepdims=False)
517+
b1 = denom * vector.dot(h, s, keepdims=False)
518+
b2 = denom * vector.dot(q, ray_dir, keepdims=False)
519+
b0 = 1 - b1 - b2
520+
barys = tf.stack((b0, b1, b2), axis=-1)
521+
barys = tf.where(invalid[..., tf.newaxis], tf.zeros_like(barys), barys)
522+
t = tf.where(invalid, tf.zeros_like(t), t)
523+
return barys, t
524+
525+
454526
# API contains all public functions and classes.
455527
__all__ = export_api.get_functions_and_classes()

tensorflow_graphics/geometry/representation/tests/ray_test.py

+52
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,58 @@ def test_intersection_ray_sphere_preset(self, test_inputs, test_outputs):
342342
self.assert_output_is_correct(
343343
ray.intersection_ray_sphere, test_inputs, test_outputs, tile=False)
344344

345+
@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
346+
def test_intersection_ray_triangle_random(self):
347+
"""Test the intersection_ray_triangle function."""
348+
tensor_size = np.random.randint(3)
349+
tensor_shape = np.random.randint(1, 10, size=(tensor_size)).tolist()
350+
ray_org = np.random.uniform(size=tensor_shape + [3])
351+
ray_dir = np.random.uniform(size=tensor_shape + [3])
352+
ray_dir /= np.linalg.norm(ray_dir, axis=-1, keepdims=True)
353+
triangles = np.random.uniform(size=tensor_shape + [3, 3])
354+
355+
barys, t = ray.intersection_ray_triangle(ray_org, ray_dir, triangles)
356+
357+
intersections_barys = tf.math.reduce_sum(
358+
barys[..., tf.newaxis] * triangles, axis=-2)
359+
360+
intersections_dists = t[..., tf.newaxis] * ray_dir + ray_org
361+
362+
intersections_barys = tf.where(
363+
tf.abs(t) > 0,
364+
intersections_barys,
365+
tf.zeros_like(intersections_barys))
366+
367+
intersections_dists = tf.where(
368+
tf.abs(t) > 0,
369+
intersections_dists,
370+
tf.zeros_like(intersections_dists))
371+
372+
self.assertAllClose(
373+
intersections_barys, intersections_dists, atol=1e-04, rtol=1e-04)
374+
375+
@parameterized.parameters(
376+
(
377+
(
378+
(0.0, 0.0, 0.0),
379+
(0.0, 0.0, 1.0),
380+
((1.0, 0.0, 1.0), (0.0, 1.0, 1.0), (0.0, 0.0, 1.0)),
381+
),
382+
((0.0, 0.0, 1.0), 1.0),
383+
),
384+
(
385+
(
386+
(0.5, 0.5, 0.0),
387+
(0.0, 0.0, 1.0),
388+
((1.0, 0.0, 0.5), (0.0, 1.0, 0.5), (0.0, 0.0, 0.5)),
389+
),
390+
((0.5, 0.5, 0.0), 0.5),
391+
),
392+
)
393+
def test_intersection_ray_triangle_preset(self, test_inputs, test_outputs):
394+
self.assert_output_is_correct(
395+
ray.intersection_ray_triangle, test_inputs, test_outputs, tile=False)
396+
345397

346398
if __name__ == "__main__":
347399
test_case.main()

0 commit comments

Comments
 (0)