Skip to content

Commit 84d55ae

Browse files
committed
Add classic PySpark implementation
1 parent 3928491 commit 84d55ae

File tree

7 files changed

+66
-7
lines changed

7 files changed

+66
-7
lines changed

python/pyspark/sql/connect/expressions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,10 @@ def __repr__(self) -> str:
509509
dt = DateType().fromInternal(self._value)
510510
if dt is not None and isinstance(dt, datetime.date):
511511
return dt.strftime("%Y-%m-%d")
512+
elif isinstance(self._dataType, TimeType):
513+
t = TimeType().fromInternal(self._value)
514+
if t is not None and isinstance(t, datetime.time):
515+
return t.strftime("%H:%M:%S.%f")
512516
elif isinstance(self._dataType, TimestampType):
513517
ts = TimestampType().fromInternal(self._value)
514518
if ts is not None and isinstance(ts, datetime.datetime):

python/pyspark/sql/pandas/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
StringType,
3838
BinaryType,
3939
DateType,
40+
TimeType,
4041
TimestampType,
4142
TimestampNTZType,
4243
DayTimeIntervalType,
@@ -302,6 +303,8 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da
302303
spark_type = BinaryType()
303304
elif types.is_date32(at):
304305
spark_type = DateType()
306+
elif types.is_time(at):
307+
spark_type = TimeType()
305308
elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
306309
spark_type = TimestampNTZType()
307310
elif types.is_timestamp(at):

python/pyspark/sql/tests/test_functions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,6 +1344,11 @@ def test_shiftrightunsigned(self):
13441344
)
13451345
).collect()
13461346

1347+
def test_lit_time(self):
1348+
t = datetime.time(12, 34, 56)
1349+
actual = self.spark.range(1).select(F.lit(t)).first()[0]
1350+
self.assertEqual(actual, t)
1351+
13471352
def test_lit_day_time_interval(self):
13481353
td = datetime.timedelta(days=1, hours=12, milliseconds=123)
13491354
actual = self.spark.range(1).select(F.lit(td)).first()[0]

python/pyspark/sql/tests/test_sql.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ def test_nested_dataframe(self):
168168
self.assertEqual(df3.take(1), [Row(id=4)])
169169
self.assertEqual(df3.tail(1), [Row(id=9)])
170170

171+
def test_lit_time(self):
172+
import datetime
173+
actual = self.spark.sql("select TIME '12:34:56'").first()[0]
174+
self.assertEqual(actual, datetime.time(12, 34, 56))
171175

172176
class SQLTests(SQLTestsMixin, ReusedSQLTestCase):
173177
pass

python/pyspark/sql/tests/test_types.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
IntegerType,
4343
FloatType,
4444
DateType,
45+
TimeType,
4546
TimestampType,
4647
TimestampNTZType,
4748
DayTimeIntervalType,
@@ -525,7 +526,7 @@ def test_create_dataframe_from_objects(self):
525526
self.assertEqual(df.first(), Row(key=1, value="1"))
526527

527528
def test_apply_schema(self):
528-
from datetime import date, datetime, timedelta
529+
from datetime import date, time, datetime, timedelta
529530

530531
rdd = self.sc.parallelize(
531532
[
@@ -537,6 +538,7 @@ def test_apply_schema(self):
537538
2147483647,
538539
1.0,
539540
date(2010, 1, 1),
541+
time(23, 23, 59, 999999),
540542
datetime(2010, 1, 1, 1, 1, 1),
541543
timedelta(days=1),
542544
{"a": 1},
@@ -555,6 +557,7 @@ def test_apply_schema(self):
555557
StructField("int1", IntegerType(), False),
556558
StructField("float1", FloatType(), False),
557559
StructField("date1", DateType(), False),
560+
StructField("time", TimeType(), False),
558561
StructField("time1", TimestampType(), False),
559562
StructField("daytime1", DayTimeIntervalType(), False),
560563
StructField("map1", MapType(StringType(), IntegerType(), False), False),
@@ -573,6 +576,7 @@ def test_apply_schema(self):
573576
x.int1,
574577
x.float1,
575578
x.date1,
579+
x.time,
576580
x.time1,
577581
x.daytime1,
578582
x.map1["a"],
@@ -589,6 +593,7 @@ def test_apply_schema(self):
589593
2147483647,
590594
1.0,
591595
date(2010, 1, 1),
596+
time(23, 23, 59, 999999),
592597
datetime(2010, 1, 1, 1, 1, 1),
593598
timedelta(days=1),
594599
1,
@@ -1241,6 +1246,7 @@ def test_parse_datatype_json_string(self):
12411246
IntegerType(),
12421247
LongType(),
12431248
DateType(),
1249+
TimeType(),
12441250
TimestampType(),
12451251
TimestampNTZType(),
12461252
NullType(),
@@ -1291,6 +1297,7 @@ def test_parse_datatype_string(self):
12911297
_parse_datatype_string("a INT, c DOUBLE"),
12921298
)
12931299
self.assertEqual(VariantType(), _parse_datatype_string("variant"))
1300+
self.assertEqual(TimeType(5), _parse_datatype_string("time(5)"))
12941301

12951302
def test_tree_string(self):
12961303
schema1 = DataType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
@@ -1543,6 +1550,7 @@ def test_tree_string_for_builtin_types(self):
15431550
.add("bin", BinaryType())
15441551
.add("bool", BooleanType())
15451552
.add("date", DateType())
1553+
.add("time", TimeType())
15461554
.add("ts", TimestampType())
15471555
.add("ts_ntz", TimestampNTZType())
15481556
.add("dec", DecimalType(10, 2))
@@ -1578,6 +1586,7 @@ def test_tree_string_for_builtin_types(self):
15781586
" |-- bin: binary (nullable = true)",
15791587
" |-- bool: boolean (nullable = true)",
15801588
" |-- date: date (nullable = true)",
1589+
" |-- time: time(6) (nullable = true)",
15811590
" |-- ts: timestamp (nullable = true)",
15821591
" |-- ts_ntz: timestamp_ntz (nullable = true)",
15831592
" |-- dec: decimal(10,2) (nullable = true)",
@@ -1925,6 +1934,7 @@ def test_repr(self):
19251934
BinaryType(),
19261935
BooleanType(),
19271936
DateType(),
1937+
TimeType(),
19281938
TimestampType(),
19291939
DecimalType(),
19301940
DoubleType(),
@@ -2332,8 +2342,8 @@ def test_to_ddl(self):
23322342
schema = StructType().add("a", ArrayType(DoubleType()), False).add("b", DateType())
23332343
self.assertEqual(schema.toDDL(), "a ARRAY<DOUBLE> NOT NULL,b DATE")
23342344

2335-
schema = StructType().add("a", TimestampType()).add("b", TimestampNTZType())
2336-
self.assertEqual(schema.toDDL(), "a TIMESTAMP,b TIMESTAMP_NTZ")
2345+
schema = StructType().add("a", TimestampType()).add("b", TimestampNTZType()).add("c", TimeType())
2346+
self.assertEqual(schema.toDDL(), "a TIMESTAMP,b TIMESTAMP_NTZ,c TIME(6)")
23372347

23382348
def test_from_ddl(self):
23392349
self.assertEqual(DataType.fromDDL("long"), LongType())
@@ -2349,6 +2359,10 @@ def test_from_ddl(self):
23492359
DataType.fromDDL("a int, v variant"),
23502360
StructType([StructField("a", IntegerType()), StructField("v", VariantType())]),
23512361
)
2362+
self.assertEqual(
2363+
DataType.fromDDL("a time(6)"),
2364+
StructType([StructField("a", TimeType(6))]),
2365+
)
23522366

23532367
# Ensures that changing the implementation of `DataType.fromDDL` in PR #47253 does not change
23542368
# `fromDDL`'s behavior.
@@ -2602,8 +2616,9 @@ def __init__(self, **kwargs):
26022616
(decimal.Decimal("1.0"), DecimalType()),
26032617
# Binary
26042618
(bytearray([1, 2]), BinaryType()),
2605-
# Date/Timestamp
2619+
# Date/Time/Timestamp
26062620
(datetime.date(2000, 1, 2), DateType()),
2621+
(datetime.time(1, 0, 0), TimeType()),
26072622
(datetime.datetime(2000, 1, 2, 3, 4), DateType()),
26082623
(datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),
26092624
# Array
@@ -2666,8 +2681,9 @@ def __init__(self, **kwargs):
26662681
("1.0", DecimalType(), TypeError),
26672682
# Binary
26682683
(1, BinaryType(), TypeError),
2669-
# Date/Timestamp
2684+
# Date/Time/Timestamp
26702685
("2000-01-02", DateType(), TypeError),
2686+
("23:59:59", TimeType(), TypeError),
26712687
(946811040, TimestampType(), TypeError),
26722688
# Array
26732689
(["1", None], ArrayType(StringType(), containsNull=False), ValueError),

python/pyspark/sql/types.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def _get_jvm_type_name(cls, dataType: "DataType") -> str:
216216
VarcharType,
217217
DayTimeIntervalType,
218218
YearMonthIntervalType,
219+
TimeType,
219220
),
220221
):
221222
return dataType.simpleString()
@@ -411,6 +412,14 @@ def fromInternal(self, nano: int) -> datetime.time:
411412
microseconds = remainder // 1_000
412413
return datetime.time(hours, minutes, seconds, microseconds)
413414

415+
def simpleString(self) -> str:
416+
return "time(%d)" % (self.precision)
417+
418+
def jsonValue(self) -> str:
419+
return "time(%d)" % (self.precision)
420+
421+
def __repr__(self) -> str:
422+
return "TimeType(%d)" % (self.precision)
414423

415424
class TimestampType(AtomicType, metaclass=DataTypeSingleton):
416425
"""Timestamp (datetime.datetime) data type."""
@@ -2635,6 +2644,7 @@ def convert_struct(obj: Any) -> Optional[Tuple]:
26352644
VarcharType: (str,),
26362645
BinaryType: (bytearray, bytes),
26372646
DateType: (datetime.date, datetime.datetime),
2647+
TimeType: (datetime.time,),
26382648
TimestampType: (datetime.datetime,),
26392649
TimestampNTZType: (datetime.datetime,),
26402650
DayTimeIntervalType: (datetime.timedelta,),
@@ -3240,6 +3250,21 @@ def convert(self, obj: datetime.date, gateway_client: "GatewayClient") -> "JavaG
32403250
Date = JavaClass("java.sql.Date", gateway_client)
32413251
return Date.valueOf(obj.strftime("%Y-%m-%d"))
32423252

3253+
class TimeConverter:
3254+
def can_convert(self, obj: Any) -> bool:
3255+
return isinstance(obj, datetime.time)
3256+
3257+
def convert(self, obj: datetime.time, gateway_client: "GatewayClient") -> "JavaGateway":
3258+
from py4j.java_gateway import JavaClass
3259+
3260+
LocalTime = JavaClass("java.time.LocalTime", gateway_client)
3261+
return LocalTime.of(
3262+
obj.hour,
3263+
obj.minute,
3264+
obj.second,
3265+
obj.microsecond * 1000
3266+
)
3267+
32433268

32443269
class DatetimeConverter:
32453270
def can_convert(self, obj: Any) -> bool:
@@ -3369,6 +3394,7 @@ def convert(self, obj: "np.ndarray", gateway_client: "GatewayClient") -> "JavaGa
33693394
register_input_converter(DatetimeNTZConverter())
33703395
register_input_converter(DatetimeConverter())
33713396
register_input_converter(DateConverter())
3397+
register_input_converter(TimeConverter())
33723398
register_input_converter(DayTimeIntervalTypeConverter())
33733399
register_input_converter(NumpyScalarConverter())
33743400
# NumPy array satisfies py4j.java_collections.ListConverter,

sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
3636
object EvaluatePython {
3737

3838
def needConversionInPython(dt: DataType): Boolean = dt match {
39-
case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType => true
39+
case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType
40+
| _: TimeType => true
4041
case _: StructType => true
4142
case _: UserDefinedType[_] => true
4243
case ArrayType(elementType, _) => needConversionInPython(elementType)
@@ -138,7 +139,7 @@ object EvaluatePython {
138139
case c: Int => c
139140
}
140141

141-
case TimestampType | TimestampNTZType | _: DayTimeIntervalType => (obj: Any) =>
142+
case TimestampType | TimestampNTZType | _: DayTimeIntervalType | _: TimeType => (obj: Any) =>
142143
nullSafeConvert(obj) {
143144
case c: Long => c
144145
// Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs

0 commit comments

Comments (0)