Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented Möller-Trumbore intersection algorithm for differentiable ray triangle intersection #721

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions tensorflow_graphics/geometry/representation/ray.py
Original file line number Diff line number Diff line change
@@ -451,5 +451,77 @@ def intersection_ray_sphere(sphere_center,
return intersections_points, normals


def intersection_ray_triangle(
ray_org,
ray_dir,
triangles,
epsilon=1e-8,
name="ray_intersection_ray_triangle",
):
"""Möller-Trumbore intersection algorithm.

Simultaneously computes barycentric coordinates and distance to intersections
of ray to planes defined by triangles. Uses epsilon to detect and ignore
numerically unstable cases, returning all zeros instead. No attempt is made
to ensure that intersections are contained within each triangle.

Note:
In the following, A1 to An are optional batch dimensions.

Args:
ray_org: A tensor of shape `[A1, ..., An, 3]`,
where the last dimension represents the 3D position of the ray origin.
ray_dir: A tensor of shape `[A1, ..., An, 3]`, where
the last dimension represents the normalized 3D direction of the ray.
triangles: A tensor of shape `[A1, ..., An, 3, 3]`, containing batches of
triangles represented using 3 vertices, where the last dimension
represents the 3D position of each vertex.
epsilon: Epsilon value use to detect and ignore degenerate cases.
name: A name for this op that defaults to "ray_intersection_ray_triangle"

Returns:
A tensor of shape `[A1, ..., An, 3]` representing the barycentric
coordinates of each intersection location, and a tensor of shape
`[A1, ..., An]` containing the distance of each ray origin to the
intersection location
"""
with tf.name_scope(name):
ray_org = tf.convert_to_tensor(value=ray_org)
ray_dir = tf.convert_to_tensor(value=ray_dir)
triangles = tf.convert_to_tensor(value=triangles)

shape.check_static(
tensor=ray_org, tensor_name="ray_org", has_dim_equals=(-1, 3))
shape.check_static(
tensor=ray_dir, tensor_name="ray_dir", has_dim_equals=(-1, 3))
shape.check_static(
tensor=triangles,
tensor_name="triangles",
has_dim_equals=[(-2, 3), (-1, 3)],
)

shape.compare_batch_dimensions(
(ray_org, ray_dir, triangles), (-2, -2, -3),
broadcast_compatible=False)

e1 = triangles[..., 1, :] - triangles[..., 0, :]
e2 = triangles[..., 2, :] - triangles[..., 0, :]
s = ray_org - triangles[..., 0, :]
h = tf.linalg.cross(ray_dir, e2)
q = tf.linalg.cross(s, e1)
a = vector.dot(h, e1, keepdims=False)
invalid = tf.abs(a) < epsilon
denom = tf.where(invalid, tf.zeros_like(a), tf.math.divide_no_nan(1.0, a))

t = denom * vector.dot(q, e2, keepdims=False)
b1 = denom * vector.dot(h, s, keepdims=False)
b2 = denom * vector.dot(q, ray_dir, keepdims=False)
b0 = 1 - b1 - b2
barys = tf.stack((b0, b1, b2), axis=-1)
barys = tf.where(invalid[..., tf.newaxis], tf.zeros_like(barys), barys)
t = tf.where(invalid, tf.zeros_like(t), t)
return barys, t


# API contains all public functions and classes.
__all__ = export_api.get_functions_and_classes()
52 changes: 52 additions & 0 deletions tensorflow_graphics/geometry/representation/tests/ray_test.py
Original file line number Diff line number Diff line change
@@ -342,6 +342,58 @@ def test_intersection_ray_sphere_preset(self, test_inputs, test_outputs):
self.assert_output_is_correct(
ray.intersection_ray_sphere, test_inputs, test_outputs, tile=False)

@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
def test_intersection_ray_triangle_random(self):
"""Test the intersection_ray_triangle function."""
tensor_size = np.random.randint(3)
tensor_shape = np.random.randint(1, 10, size=(tensor_size)).tolist()
ray_org = np.random.uniform(size=tensor_shape + [3])
ray_dir = np.random.uniform(size=tensor_shape + [3])
ray_dir /= np.linalg.norm(ray_dir, axis=-1, keepdims=True)
triangles = np.random.uniform(size=tensor_shape + [3, 3])

barys, t = ray.intersection_ray_triangle(ray_org, ray_dir, triangles)

intersections_barys = tf.math.reduce_sum(
barys[..., tf.newaxis] * triangles, axis=-2)

intersections_dists = t[..., tf.newaxis] * ray_dir + ray_org

intersections_barys = tf.where(
tf.abs(t) > 0,
intersections_barys,
tf.zeros_like(intersections_barys))

intersections_dists = tf.where(
tf.abs(t) > 0,
intersections_dists,
tf.zeros_like(intersections_dists))

self.assertAllClose(
intersections_barys, intersections_dists, atol=1e-04, rtol=1e-04)

@parameterized.parameters(
(
(
(0.0, 0.0, 0.0),
(0.0, 0.0, 1.0),
((1.0, 0.0, 1.0), (0.0, 1.0, 1.0), (0.0, 0.0, 1.0)),
),
((0.0, 0.0, 1.0), 1.0),
),
(
(
(0.5, 0.5, 0.0),
(0.0, 0.0, 1.0),
((1.0, 0.0, 0.5), (0.0, 1.0, 0.5), (0.0, 0.0, 0.5)),
),
((0.5, 0.5, 0.0), 0.5),
),
)
def test_intersection_ray_triangle_preset(self, test_inputs, test_outputs):
self.assert_output_is_correct(
ray.intersection_ray_triangle, test_inputs, test_outputs, tile=False)


if __name__ == "__main__":
test_case.main()