@@ -451,5 +451,77 @@ def intersection_ray_sphere(sphere_center,
451
451
return intersections_points , normals
452
452
453
453
454
+ def intersection_ray_triangle (
455
+ ray_org ,
456
+ ray_dir ,
457
+ triangles ,
458
+ epsilon = 1e-8 ,
459
+ name = "ray_intersection_ray_triangle" ,
460
+ ):
461
+ """Möller-Trumbore intersection algorithm.
462
+
463
+ Simultaneously computes barycentric coordinates and distance to intersections
464
+ of ray to planes defined by triangles. Uses epsilon to detect and ignore
465
+ numerically unstable cases, returning all zeros instead. No attempt is made
466
+ to ensure that intersections are contained within each triangle.
467
+
468
+ Note:
469
+ In the following, A1 to An are optional batch dimensions.
470
+
471
+ Args:
472
+ ray_org: A tensor of shape `[A1, ..., An, 3]`,
473
+ where the last dimension represents the 3D position of the ray origin.
474
+ ray_dir: A tensor of shape `[A1, ..., An, 3]`, where
475
+ the last dimension represents the normalized 3D direction of the ray.
476
+ triangles: A tensor of shape `[A1, ..., An, 3, 3]`, containing batches of
477
+ triangles represented using 3 vertices, where the last dimension
478
+ represents the 3D position of each vertex.
479
+ epsilon: Epsilon value use to detect and ignore degenerate cases.
480
+ name: A name for this op that defaults to "ray_intersection_ray_triangle"
481
+
482
+ Returns:
483
+ A tensor of shape `[A1, ..., An, 3]` representing the barycentric
484
+ coordinates of each intersection location, and a tensor of shape
485
+ `[A1, ..., An]` containing the distance of each ray origin to the
486
+ intersection location
487
+ """
488
+ with tf .name_scope (name ):
489
+ ray_org = tf .convert_to_tensor (value = ray_org )
490
+ ray_dir = tf .convert_to_tensor (value = ray_dir )
491
+ triangles = tf .convert_to_tensor (value = triangles )
492
+
493
+ shape .check_static (
494
+ tensor = ray_org , tensor_name = "ray_org" , has_dim_equals = (- 1 , 3 ))
495
+ shape .check_static (
496
+ tensor = ray_dir , tensor_name = "ray_dir" , has_dim_equals = (- 1 , 3 ))
497
+ shape .check_static (
498
+ tensor = triangles ,
499
+ tensor_name = "triangles" ,
500
+ has_dim_equals = [(- 2 , 3 ), (- 1 , 3 )],
501
+ )
502
+
503
+ shape .compare_batch_dimensions (
504
+ (ray_org , ray_dir , triangles ), (- 2 , - 2 , - 3 ),
505
+ broadcast_compatible = False )
506
+
507
+ e1 = triangles [..., 1 , :] - triangles [..., 0 , :]
508
+ e2 = triangles [..., 2 , :] - triangles [..., 0 , :]
509
+ s = ray_org - triangles [..., 0 , :]
510
+ h = tf .linalg .cross (ray_dir , e2 )
511
+ q = tf .linalg .cross (s , e1 )
512
+ a = vector .dot (h , e1 , keepdims = False )
513
+ invalid = tf .abs (a ) < epsilon
514
+ denom = tf .where (invalid , tf .zeros_like (a ), tf .math .divide_no_nan (1.0 , a ))
515
+
516
+ t = denom * vector .dot (q , e2 , keepdims = False )
517
+ b1 = denom * vector .dot (h , s , keepdims = False )
518
+ b2 = denom * vector .dot (q , ray_dir , keepdims = False )
519
+ b0 = 1 - b1 - b2
520
+ barys = tf .stack ((b0 , b1 , b2 ), axis = - 1 )
521
+ barys = tf .where (invalid [..., tf .newaxis ], tf .zeros_like (barys ), barys )
522
+ t = tf .where (invalid , tf .zeros_like (t ), t )
523
+ return barys , t
524
+
525
+
454
526
# API contains all public functions and classes.
455
527
__all__ = export_api .get_functions_and_classes ()
0 commit comments