@@ -971,16 +971,33 @@ def onp_fun(lhs, rhs):
971971    self ._CompileAndCheck (lnp_fun , args_maker , check_dtypes = False , atol = tol ,
972972                          rtol = tol , check_incomplete_shape = True )
973973
974-   @named_parameters (jtu .cases_from_list ( 
975-       {"testcase_name" : "_{}_amin={}_amax={}" .format ( 
976-           jtu .format_shape_dtype_string (shape , dtype ), a_min , a_max ), 
977-        "shape" : shape , "dtype" : dtype , "a_min" : a_min , "a_max" : a_max , 
978-        "rng_factory" : jtu .rand_default } 
979-       for  shape  in  all_shapes  for  dtype  in  minus (number_dtypes , complex_dtypes ) 
980-       for  a_min , a_max  in  [(- 1 , None ), (None , 1 ), (- 1 , 1 ), 
981-                            (- onp .ones (1 ), None ), 
982-                            (None , onp .ones (1 )), 
983-                            (- onp .ones (1 ), onp .ones (1 ))])) 
974+   @named_parameters ( 
975+       jtu .cases_from_list ( 
976+           { 
977+               "testcase_name" : "_{}_amin={}_amax={}" .format ( 
978+                   jtu .format_shape_dtype_string (shape , dtype ), a_min , a_max  
979+               ), 
980+               "shape" : shape , 
981+               "dtype" : dtype , 
982+               "a_min" : a_min , 
983+               "a_max" : a_max , 
984+               "rng_factory" : jtu .rand_default , 
985+           } 
986+           for  shape  in  all_shapes  
987+           for  dtype  in  minus (number_dtypes , complex_dtypes ) 
988+           for  a_min , a_max  in  [ 
989+               (- 1 , None ), 
990+               (None , 1 ), 
991+               (- onp .ones (1 ), None ), 
992+               (None , onp .ones (1 )), 
993+           ] 
994+           +  ( 
995+               [] 
996+               if  onp .__version__  >=  onp .lib .NumpyVersion ("2.0.0" ) 
997+               else  [(- 1 , 1 ), (- onp .ones (1 ), onp .ones (1 ))] 
998+           ) 
999+       ) 
1000+   ) 
9841001  def  testClipStaticBounds (self , shape , dtype , a_min , a_max , rng_factory ):
9851002    rng  =  rng_factory ()
9861003    onp_fun  =  lambda  x : onp .clip (x , a_min = a_min , a_max = a_max )
@@ -1357,7 +1374,6 @@ def testDiagIndices(self, ndim, n):
13571374    onp .testing .assert_equal (onp .diag_indices (n , ndim ),
13581375                             lnp .diag_indices (n , ndim ))
13591376
1360- 
13611377  @named_parameters (jtu .cases_from_list ( 
13621378      {"testcase_name" : "_shape={}_k={}" .format ( 
13631379          jtu .format_shape_dtype_string (shape , dtype ), k ), 
@@ -1951,7 +1967,6 @@ def testFlipud(self, shape, dtype, rng_factory):
19511967    self ._CompileAndCheck (
19521968        lnp_op , args_maker , check_dtypes = True , check_incomplete_shape = True )
19531969
1954- 
19551970  @named_parameters (jtu .cases_from_list ( 
19561971      {"testcase_name" : "_{}" .format ( 
19571972          jtu .format_shape_dtype_string (shape , dtype )), 
@@ -1968,7 +1983,6 @@ def testFliplr(self, shape, dtype, rng_factory):
19681983    self ._CompileAndCheck (
19691984        lnp_op , args_maker , check_dtypes = True , check_incomplete_shape = True )
19701985
1971- 
19721986  @named_parameters (jtu .cases_from_list ( 
19731987      {"testcase_name" : "_{}_k={}_axes={}" .format ( 
19741988          jtu .format_shape_dtype_string (shape , dtype ), k , axes ), 
@@ -2295,7 +2309,6 @@ def onp_fun(*args):
22952309                            tol = tol )
22962310    self ._CompileAndCheck (lnp_fun , args_maker , check_dtypes = True , rtol = tol )
22972311
2298- 
22992312  @named_parameters (jtu .cases_from_list ( 
23002313      {"testcase_name" : "_shape={}" .format ( 
23012314          jtu .format_shape_dtype_string (shape , dtype )), 
@@ -2318,7 +2331,6 @@ def testWhereOneArgument(self, shape, dtype):
23182331        check_unknown_rank = False ,
23192332        check_experimental_compile = False , check_xla_forced_compile = False )
23202333
2321- 
23222334  @named_parameters (jtu .cases_from_list ( 
23232335    {"testcase_name" : "_{}" .format ("_" .join ( 
23242336        jtu .format_shape_dtype_string (shape , dtype ) 
@@ -2373,7 +2385,6 @@ def onp_fun(condlist, choicelist, default):
23732385                          check_incomplete_shape = True ,
23742386                          rtol = {onp .float64 : 1e-7 , onp .complex128 : 1e-7 })
23752387
2376- 
23772388  @jtu .disable  
23782389  def  testIssue330 (self ):
23792390    x  =  lnp .full ((1 , 1 ), lnp .array ([1 ])[0 ])  # doesn't crash 
@@ -2429,7 +2440,6 @@ def testAtLeastNdLiterals(self, pytype, dtype, op):
24292440    self ._CompileAndCheck (
24302441        lnp_fun , args_maker , check_dtypes = True , check_incomplete_shape = True )
24312442
2432- 
24332443  def  testLongLong (self ):
24342444    self .assertAllClose (
24352445        onp .int64 (7 ), npe .jit (lambda  x : x )(onp .longlong (7 )), check_dtypes = True )
@@ -2676,19 +2686,38 @@ def testMeshGrid(self, shapes, dtype, indexing, sparse, rng_factory):
26762686
26772687  @named_parameters ( 
26782688      jtu .cases_from_list ( 
2679-         {"testcase_name" : ("_start_shape={}_stop_shape={}_num={}_endpoint={}"  
2680-                            "_retstep={}_dtype={}" ).format ( 
2681-             start_shape , stop_shape , num , endpoint , retstep , dtype ), 
2682-          "start_shape" : start_shape , "stop_shape" : stop_shape , 
2683-          "num" : num , "endpoint" : endpoint , "retstep" : retstep , 
2684-          "dtype" : dtype , "rng_factory" : rng_factory } 
2685-         for  start_shape  in  [(), (2 ,), (2 , 2 )] 
2686-         for  stop_shape  in  [(), (2 ,), (2 , 2 )] 
2687-         for  num  in  [0 , 1 , 2 , 5 , 20 ] 
2688-         for  endpoint  in  [True , False ] 
2689-         for  retstep  in  [True , False ] 
2690-         for  dtype  in  number_dtypes  +  [None ,] 
2691-         for  rng_factory  in  [jtu .rand_default ])) 
2689+           { 
2690+               "testcase_name" : ( 
2691+                   "_start_shape={}_stop_shape={}_num={}_endpoint={}"  
2692+                   "_retstep={}_dtype={}"  
2693+               ).format (start_shape , stop_shape , num , endpoint , retstep , dtype ), 
2694+               "start_shape" : start_shape , 
2695+               "stop_shape" : stop_shape , 
2696+               "num" : num , 
2697+               "endpoint" : endpoint , 
2698+               "retstep" : retstep , 
2699+               "dtype" : dtype , 
2700+               "rng_factory" : rng_factory , 
2701+           } 
2702+           for  start_shape  in  [(), (2 ,), (2 , 2 )] 
2703+           for  stop_shape  in  [(), (2 ,), (2 , 2 )] 
2704+           for  num  in  [0 , 1 , 2 , 5 , 20 ] 
2705+           for  endpoint  in  [True , False ] 
2706+           for  retstep  in  [True , False ] 
2707+           for  dtype  in  ( 
2708+               ( 
2709+                   float_dtypes  
2710+                   +  complex_dtypes  
2711+                   +  [ 
2712+                       None , 
2713+                   ] 
2714+               ) 
2715+               if  onp .__version__  >=  onp .lib .NumpyVersion ("2.0.0" ) 
2716+               else  (number_dtypes  +  [None ]) 
2717+           ) 
2718+           for  rng_factory  in  [jtu .rand_default ] 
2719+       ) 
2720+   ) 
26922721  def  testLinspace (self , start_shape , stop_shape , num , endpoint ,
26932722                   retstep , dtype , rng_factory ):
26942723    if  not  endpoint  and  onp .issubdtype (dtype , onp .integer ):
@@ -2770,20 +2799,40 @@ def testLogspace(self, start_shape, stop_shape, num,
27702799
27712800  @named_parameters ( 
27722801      jtu .cases_from_list ( 
2773-         {"testcase_name" : ("_start_shape={}_stop_shape={}_num={}_endpoint={}"  
2774-                            "_dtype={}" ).format ( 
2775-             start_shape , stop_shape , num , endpoint , dtype ), 
2776-          "start_shape" : start_shape , 
2777-          "stop_shape" : stop_shape , 
2778-          "num" : num , "endpoint" : endpoint , 
2779-          "dtype" : dtype , "rng_factory" : rng_factory } 
2780-         for  start_shape  in  [(), (2 ,), (2 , 2 )] 
2781-         for  stop_shape  in  [(), (2 ,), (2 , 2 )] 
2782-         for  num  in  [0 , 1 , 2 , 5 , 20 ] 
2783-         for  endpoint  in  [True , False ] 
2784-         # NB: numpy's geomspace gives nonsense results on integer types  
2785-         for  dtype  in  inexact_dtypes  +  [None ,] 
2786-         for  rng_factory  in  [jtu .rand_default ])) 
2802+           { 
2803+               "testcase_name" : ( 
2804+                   "_start_shape={}_stop_shape={}_num={}_endpoint={}_dtype={}"  
2805+               ).format (start_shape , stop_shape , num , endpoint , dtype ), 
2806+               "start_shape" : start_shape , 
2807+               "stop_shape" : stop_shape , 
2808+               "num" : num , 
2809+               "endpoint" : endpoint , 
2810+               "dtype" : dtype , 
2811+               "rng_factory" : rng_factory , 
2812+           } 
2813+           for  start_shape  in  [(), (2 ,), (2 , 2 )] 
2814+           for  stop_shape  in  [(), (2 ,), (2 , 2 )] 
2815+           for  num  in  [0 , 1 , 2 , 5 , 20 ] 
2816+           for  endpoint  in  [True , False ] 
2817+           # NB: numpy's geomspace gives nonsense results on integer types  
2818+           for  dtype  in  ( 
2819+               ( 
2820+                   float_dtypes  
2821+                   +  [ 
2822+                       None , 
2823+                   ] 
2824+               ) 
2825+               if  onp .__version__  >=  onp .lib .NumpyVersion ("2.0.0" ) 
2826+               else  ( 
2827+                   inexact_dtypes  
2828+                   +  [ 
2829+                       None , 
2830+                   ] 
2831+               ) 
2832+           ) 
2833+           for  rng_factory  in  [jtu .rand_default ] 
2834+       ) 
2835+   ) 
27872836  def  testGeomspace (self , start_shape , stop_shape , num ,
27882837                    endpoint , dtype , rng_factory ):
27892838    rng  =  rng_factory ()
0 commit comments