Skip to content

Commit f1cca85

Browse files
committed
[SPARK-43943][SQL][PYTHON][CONNECT] Add SQL math functions to Scala and Python
### What changes were proposed in this pull request? Add following functions: * ceiling * e * pi * ln * negative * positive * power * sign * std * width_bucket to: * Scala API * Python API * Spark Connect Scala Client * Spark Connect Python Client This PR also adds `negate` (which already exists in Scala API and SCSC) to Python API and SCPC. ### Why are the changes needed? for parity ### Does this PR introduce _any_ user-facing change? yes, new functions ### How was this patch tested? added ut / doctest Closes apache#41435 from zhengruifeng/sql_func_math. Authored-by: Ruifeng Zheng <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 4ddf83f commit f1cca85

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+839
-1
lines changed

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,14 @@ object functions {
873873
*/
874874
def skewness(columnName: String): Column = skewness(Column(columnName))
875875

876+
/**
877+
* Aggregate function: alias for `stddev_samp`.
878+
*
879+
* @group agg_funcs
880+
* @since 3.5.0
881+
*/
882+
def std(e: Column): Column = stddev(e)
883+
876884
/**
877885
* Aggregate function: alias for `stddev_samp`.
878886
*
@@ -1978,6 +1986,22 @@ object functions {
19781986
*/
19791987
def ceil(columnName: String): Column = ceil(Column(columnName))
19801988

1989+
/**
1990+
* Computes the ceiling of the given value of `e` to `scale` decimal places.
1991+
*
1992+
* @group math_funcs
1993+
* @since 3.5.0
1994+
*/
1995+
def ceiling(e: Column, scale: Column): Column = ceil(e, scale)
1996+
1997+
/**
1998+
* Computes the ceiling of the given value of `e` to 0 decimal places.
1999+
*
2000+
* @group math_funcs
2001+
* @since 3.5.0
2002+
*/
2003+
def ceiling(e: Column): Column = ceil(e)
2004+
19812005
/**
19822006
* Convert a number in a string column from one base to another.
19832007
*
@@ -2053,6 +2077,14 @@ object functions {
20532077
*/
20542078
def csc(e: Column): Column = Column.fn("csc", e)
20552079

2080+
/**
2081+
* Returns Euler's number.
2082+
*
2083+
* @group math_funcs
2084+
* @since 3.5.0
2085+
*/
2086+
def e(): Column = Column.fn("e")
2087+
20562088
/**
20572089
* Computes the exponential of the given value.
20582090
*
@@ -2241,6 +2273,14 @@ object functions {
22412273
def least(columnName: String, columnNames: String*): Column =
22422274
least((columnName +: columnNames).map(Column.apply): _*)
22432275

2276+
/**
2277+
* Computes the natural logarithm of the given value.
2278+
*
2279+
* @group math_funcs
2280+
* @since 3.5.0
2281+
*/
2282+
def ln(e: Column): Column = log(e)
2283+
22442284
/**
22452285
* Computes the natural logarithm of the given value.
22462286
*
@@ -2321,6 +2361,30 @@ object functions {
23212361
*/
23222362
def log2(columnName: String): Column = log2(Column(columnName))
23232363

2364+
/**
2365+
* Returns the negated value.
2366+
*
2367+
* @group math_funcs
2368+
* @since 3.5.0
2369+
*/
2370+
def negative(e: Column): Column = Column.fn("negative", e)
2371+
2372+
/**
2373+
* Returns Pi.
2374+
*
2375+
* @group math_funcs
2376+
* @since 3.5.0
2377+
*/
2378+
def pi(): Column = Column.fn("pi")
2379+
2380+
/**
2381+
* Returns the value.
2382+
*
2383+
* @group math_funcs
2384+
* @since 3.5.0
2385+
*/
2386+
def positive(e: Column): Column = Column.fn("positive", e)
2387+
23242388
/**
23252389
* Returns the value of the first argument raised to the power of the second argument.
23262390
*
@@ -2385,6 +2449,14 @@ object functions {
23852449
*/
23862450
def pow(l: Double, rightName: String): Column = pow(l, Column(rightName))
23872451

2452+
/**
2453+
* Returns the value of the first argument raised to the power of the second argument.
2454+
*
2455+
* @group math_funcs
2456+
* @since 3.5.0
2457+
*/
2458+
def power(l: Column, r: Column): Column = pow(l, r)
2459+
23882460
/**
23892461
* Returns the positive value of dividend mod divisor.
23902462
*
@@ -2514,6 +2586,14 @@ object functions {
25142586
def shiftrightunsigned(e: Column, numBits: Int): Column =
25152587
Column.fn("shiftrightunsigned", e, lit(numBits))
25162588

2589+
/**
2590+
* Computes the signum of the given value.
2591+
*
2592+
* @group math_funcs
2593+
* @since 3.5.0
2594+
*/
2595+
def sign(e: Column): Column = signum(e)
2596+
25172597
/**
25182598
* Computes the signum of the given value.
25192599
*
@@ -2702,6 +2782,27 @@ object functions {
27022782
*/
27032783
def radians(columnName: String): Column = radians(Column(columnName))
27042784

2785+
/**
2786+
* Returns the bucket number into which the value of this expression would fall after being
2787+
* evaluated. Note that input arguments must follow conditions listed below; otherwise, the
2788+
* method will return null.
2789+
*
2790+
* @param v
2791+
* value to compute a bucket number in the histogram
2792+
* @param min
2793+
* minimum value of the histogram
2794+
* @param max
2795+
* maximum value of the histogram
2796+
* @param numBucket
2797+
* the number of buckets
2798+
* @return
2799+
* the bucket number into which the value would fall after being evaluated
2800+
* @group math_funcs
2801+
* @since 3.5.0
2802+
*/
2803+
def width_bucket(v: Column, min: Column, max: Column, numBucket: Column): Column =
2804+
Column.fn("width_bucket", v, min, max, numBucket)
2805+
27052806
//////////////////////////////////////////////////////////////////////////////////////////////
27062807
// Misc functions
27072808
//////////////////////////////////////////////////////////////////////////////////////////////

connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,10 @@ class PlanGenerationTestSuite
10061006
fn.stddev("a")
10071007
}
10081008

1009+
functionTest("std") {
1010+
fn.std(fn.col("a"))
1011+
}
1012+
10091013
functionTest("stddev_samp") {
10101014
fn.stddev_samp("a")
10111015
}
@@ -1182,6 +1186,14 @@ class PlanGenerationTestSuite
11821186
fn.ceil(fn.col("b"), lit(2))
11831187
}
11841188

1189+
functionTest("ceiling") {
1190+
fn.ceiling(fn.col("b"))
1191+
}
1192+
1193+
functionTest("ceiling scale") {
1194+
fn.ceiling(fn.col("b"), lit(2))
1195+
}
1196+
11851197
functionTest("conv") {
11861198
fn.conv(fn.col("b"), 10, 16)
11871199
}
@@ -1202,6 +1214,10 @@ class PlanGenerationTestSuite
12021214
fn.csc(fn.col("b"))
12031215
}
12041216

1217+
functionTest("e") {
1218+
fn.e()
1219+
}
1220+
12051221
functionTest("exp") {
12061222
fn.exp("b")
12071223
}
@@ -1246,6 +1262,10 @@ class PlanGenerationTestSuite
12461262
fn.log("b")
12471263
}
12481264

1265+
functionTest("ln") {
1266+
fn.ln(fn.col("b"))
1267+
}
1268+
12491269
functionTest("log with base") {
12501270
fn.log(2, "b")
12511271
}
@@ -1262,10 +1282,26 @@ class PlanGenerationTestSuite
12621282
fn.log2("a")
12631283
}
12641284

1285+
functionTest("negative") {
1286+
fn.negative(fn.col("a"))
1287+
}
1288+
1289+
functionTest("pi") {
1290+
fn.pi()
1291+
}
1292+
1293+
functionTest("positive") {
1294+
fn.positive(fn.col("a"))
1295+
}
1296+
12651297
functionTest("pow") {
12661298
fn.pow("a", "b")
12671299
}
12681300

1301+
functionTest("power") {
1302+
fn.power(fn.col("a"), fn.col("b"))
1303+
}
1304+
12691305
functionTest("pmod") {
12701306
fn.pmod(fn.col("a"), fn.lit(10))
12711307
}
@@ -1302,6 +1338,10 @@ class PlanGenerationTestSuite
13021338
fn.signum("b")
13031339
}
13041340

1341+
functionTest("sign") {
1342+
fn.sign(fn.col("b"))
1343+
}
1344+
13051345
functionTest("sin") {
13061346
fn.sin("b")
13071347
}
@@ -2132,6 +2172,10 @@ class PlanGenerationTestSuite
21322172
simple.groupBy(Column("id")).pivot("a").agg(functions.count(Column("b")))
21332173
}
21342174

2175+
test("width_bucket") {
2176+
simple.select(fn.width_bucket(fn.col("b"), fn.col("b"), fn.col("b"), fn.col("a")))
2177+
}
2178+
21352179
test("test broadcast") {
21362180
left.join(fn.broadcast(right), "id")
21372181
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Project [CEIL(b#0) AS CEIL(b)#0L]
2+
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Project [ceil(cast(b#0 as decimal(30,15)), 2) AS ceil(b, 2)#0]
2+
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Project [E() AS E()#0]
2+
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Project [LOG(E(), b#0) AS LOG(E(), b)#0]
2+
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Project [-a#0 AS negative(a)#0]
2+
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Project [PI() AS PI()#0]
2+
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Project [positive(a#0) AS (+ a)#0]
2+
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Project [POWER(cast(a#0 as double), b#0) AS POWER(a, b)#0]
2+
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]

0 commit comments

Comments
 (0)