Skip to content

Commit 2ad1338

Browse files
committed
Add classic pyspark implementation
1 parent 55a119d commit 2ad1338

File tree

8 files changed: 70 additions and 9 deletions

python/pyspark/sql/connect/expressions.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -509,6 +509,10 @@ def __repr__(self) -> str:
509509
dt = DateType().fromInternal(self._value)
510510
if dt is not None and isinstance(dt, datetime.date):
511511
return dt.strftime("%Y-%m-%d")
512+
elif isinstance(self._dataType, TimeType):
513+
t = TimeType().fromInternal(self._value)
514+
if t is not None and isinstance(t, datetime.time):
515+
return t.strftime("%H:%M:%S.%f")
512516
elif isinstance(self._dataType, TimestampType):
513517
ts = TimestampType().fromInternal(self._value)
514518
if ts is not None and isinstance(ts, datetime.datetime):

python/pyspark/sql/connect/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
153153
elif isinstance(data_type, DateType):
154154
ret.date.CopyFrom(pb2.DataType.Date())
155155
elif isinstance(data_type, TimeType):
156-
ret.time.CopyFrom(pb2.DataType.Time())
156+
ret.time.precision = data_type.precision
157157
elif isinstance(data_type, TimestampType):
158158
ret.timestamp.CopyFrom(pb2.DataType.Timestamp())
159159
elif isinstance(data_type, TimestampNTZType):
@@ -241,7 +241,7 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
241241
elif schema.HasField("date"):
242242
return DateType()
243243
elif schema.HasField("time"):
244-
return TimeType()
244+
return TimeType(schema.time.precision) if schema.time.HasField("precision") else TimeType()
245245
elif schema.HasField("timestamp"):
246246
return TimestampType()
247247
elif schema.HasField("timestamp_ntz"):

python/pyspark/sql/pandas/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
StringType,
3838
BinaryType,
3939
DateType,
40+
TimeType,
4041
TimestampType,
4142
TimestampNTZType,
4243
DayTimeIntervalType,
@@ -302,6 +303,8 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da
302303
spark_type = BinaryType()
303304
elif types.is_date32(at):
304305
spark_type = DateType()
306+
elif types.is_time(at):
307+
spark_type = TimeType()
305308
elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
306309
spark_type = TimestampNTZType()
307310
elif types.is_timestamp(at):

python/pyspark/sql/tests/test_functions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,6 +1344,11 @@ def test_shiftrightunsigned(self):
13441344
)
13451345
).collect()
13461346

1347+
def test_lit_time(self):
1348+
t = datetime.time(12, 34, 56)
1349+
actual = self.spark.range(1).select(F.lit(t)).first()[0]
1350+
self.assertEqual(actual, t)
1351+
13471352
def test_lit_day_time_interval(self):
13481353
td = datetime.timedelta(days=1, hours=12, milliseconds=123)
13491354
actual = self.spark.range(1).select(F.lit(td)).first()[0]

python/pyspark/sql/tests/test_sql.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,12 @@ def test_nested_dataframe(self):
168168
self.assertEqual(df3.take(1), [Row(id=4)])
169169
self.assertEqual(df3.tail(1), [Row(id=9)])
170170

171+
def test_lit_time(self):
172+
import datetime
173+
174+
actual = self.spark.sql("select TIME '12:34:56'").first()[0]
175+
self.assertEqual(actual, datetime.time(12, 34, 56))
176+
171177

172178
class SQLTests(SQLTestsMixin, ReusedSQLTestCase):
173179
pass

python/pyspark/sql/tests/test_types.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
IntegerType,
4343
FloatType,
4444
DateType,
45+
TimeType,
4546
TimestampType,
4647
TimestampNTZType,
4748
DayTimeIntervalType,
@@ -525,7 +526,7 @@ def test_create_dataframe_from_objects(self):
525526
self.assertEqual(df.first(), Row(key=1, value="1"))
526527

527528
def test_apply_schema(self):
528-
from datetime import date, datetime, timedelta
529+
from datetime import date, time, datetime, timedelta
529530

530531
rdd = self.sc.parallelize(
531532
[
@@ -537,6 +538,7 @@ def test_apply_schema(self):
537538
2147483647,
538539
1.0,
539540
date(2010, 1, 1),
541+
time(23, 23, 59, 999999),
540542
datetime(2010, 1, 1, 1, 1, 1),
541543
timedelta(days=1),
542544
{"a": 1},
@@ -555,6 +557,7 @@ def test_apply_schema(self):
555557
StructField("int1", IntegerType(), False),
556558
StructField("float1", FloatType(), False),
557559
StructField("date1", DateType(), False),
560+
StructField("time", TimeType(), False),
558561
StructField("time1", TimestampType(), False),
559562
StructField("daytime1", DayTimeIntervalType(), False),
560563
StructField("map1", MapType(StringType(), IntegerType(), False), False),
@@ -573,6 +576,7 @@ def test_apply_schema(self):
573576
x.int1,
574577
x.float1,
575578
x.date1,
579+
x.time,
576580
x.time1,
577581
x.daytime1,
578582
x.map1["a"],
@@ -589,6 +593,7 @@ def test_apply_schema(self):
589593
2147483647,
590594
1.0,
591595
date(2010, 1, 1),
596+
time(23, 23, 59, 999999),
592597
datetime(2010, 1, 1, 1, 1, 1),
593598
timedelta(days=1),
594599
1,
@@ -1241,6 +1246,7 @@ def test_parse_datatype_json_string(self):
12411246
IntegerType(),
12421247
LongType(),
12431248
DateType(),
1249+
TimeType(5),
12441250
TimestampType(),
12451251
TimestampNTZType(),
12461252
NullType(),
@@ -1291,6 +1297,8 @@ def test_parse_datatype_string(self):
12911297
_parse_datatype_string("a INT, c DOUBLE"),
12921298
)
12931299
self.assertEqual(VariantType(), _parse_datatype_string("variant"))
1300+
self.assertEqual(TimeType(5), _parse_datatype_string("time(5)"))
1301+
self.assertEqual(TimeType(), _parse_datatype_string("time( 6 )"))
12941302

12951303
def test_tree_string(self):
12961304
schema1 = DataType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
@@ -1543,6 +1551,7 @@ def test_tree_string_for_builtin_types(self):
15431551
.add("bin", BinaryType())
15441552
.add("bool", BooleanType())
15451553
.add("date", DateType())
1554+
.add("time", TimeType())
15461555
.add("ts", TimestampType())
15471556
.add("ts_ntz", TimestampNTZType())
15481557
.add("dec", DecimalType(10, 2))
@@ -1578,6 +1587,7 @@ def test_tree_string_for_builtin_types(self):
15781587
" |-- bin: binary (nullable = true)",
15791588
" |-- bool: boolean (nullable = true)",
15801589
" |-- date: date (nullable = true)",
1590+
" |-- time: time(6) (nullable = true)",
15811591
" |-- ts: timestamp (nullable = true)",
15821592
" |-- ts_ntz: timestamp_ntz (nullable = true)",
15831593
" |-- dec: decimal(10,2) (nullable = true)",
@@ -1925,6 +1935,7 @@ def test_repr(self):
19251935
BinaryType(),
19261936
BooleanType(),
19271937
DateType(),
1938+
TimeType(),
19281939
TimestampType(),
19291940
DecimalType(),
19301941
DoubleType(),
@@ -2332,8 +2343,10 @@ def test_to_ddl(self):
23322343
schema = StructType().add("a", ArrayType(DoubleType()), False).add("b", DateType())
23332344
self.assertEqual(schema.toDDL(), "a ARRAY<DOUBLE> NOT NULL,b DATE")
23342345

2335-
schema = StructType().add("a", TimestampType()).add("b", TimestampNTZType())
2336-
self.assertEqual(schema.toDDL(), "a TIMESTAMP,b TIMESTAMP_NTZ")
2346+
schema = (
2347+
StructType().add("a", TimestampType()).add("b", TimestampNTZType()).add("c", TimeType())
2348+
)
2349+
self.assertEqual(schema.toDDL(), "a TIMESTAMP,b TIMESTAMP_NTZ,c TIME(6)")
23372350

23382351
def test_from_ddl(self):
23392352
self.assertEqual(DataType.fromDDL("long"), LongType())
@@ -2349,6 +2362,10 @@ def test_from_ddl(self):
23492362
DataType.fromDDL("a int, v variant"),
23502363
StructType([StructField("a", IntegerType()), StructField("v", VariantType())]),
23512364
)
2365+
self.assertEqual(
2366+
DataType.fromDDL("a time(6)"),
2367+
StructType([StructField("a", TimeType(6))]),
2368+
)
23522369

23532370
# Ensures that changing the implementation of `DataType.fromDDL` in PR #47253 does not change
23542371
# `fromDDL`'s behavior.
@@ -2602,8 +2619,9 @@ def __init__(self, **kwargs):
26022619
(decimal.Decimal("1.0"), DecimalType()),
26032620
# Binary
26042621
(bytearray([1, 2]), BinaryType()),
2605-
# Date/Timestamp
2622+
# Date/Time/Timestamp
26062623
(datetime.date(2000, 1, 2), DateType()),
2624+
(datetime.time(1, 0, 0), TimeType()),
26072625
(datetime.datetime(2000, 1, 2, 3, 4), DateType()),
26082626
(datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),
26092627
# Array
@@ -2666,8 +2684,9 @@ def __init__(self, **kwargs):
26662684
("1.0", DecimalType(), TypeError),
26672685
# Binary
26682686
(1, BinaryType(), TypeError),
2669-
# Date/Timestamp
2687+
# Date/Time/Timestamp
26702688
("2000-01-02", DateType(), TypeError),
2689+
("23:59:59", TimeType(), TypeError),
26712690
(946811040, TimestampType(), TypeError),
26722691
# Array
26732692
(["1", None], ArrayType(StringType(), containsNull=False), ValueError),

python/pyspark/sql/types.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def _get_jvm_type_name(cls, dataType: "DataType") -> str:
216216
VarcharType,
217217
DayTimeIntervalType,
218218
YearMonthIntervalType,
219+
TimeType,
219220
),
220221
):
221222
return dataType.simpleString()
@@ -411,6 +412,15 @@ def fromInternal(self, nano: int) -> datetime.time:
411412
microseconds = remainder // 1_000
412413
return datetime.time(hours, minutes, seconds, microseconds)
413414

415+
def simpleString(self) -> str:
416+
return "time(%d)" % (self.precision)
417+
418+
def jsonValue(self) -> str:
419+
return "time(%d)" % (self.precision)
420+
421+
def __repr__(self) -> str:
422+
return "TimeType(%d)" % (self.precision)
423+
414424

415425
class TimestampType(AtomicType, metaclass=DataTypeSingleton):
416426
"""Timestamp (datetime.datetime) data type."""
@@ -2635,6 +2645,7 @@ def convert_struct(obj: Any) -> Optional[Tuple]:
26352645
VarcharType: (str,),
26362646
BinaryType: (bytearray, bytes),
26372647
DateType: (datetime.date, datetime.datetime),
2648+
TimeType: (datetime.time,),
26382649
TimestampType: (datetime.datetime,),
26392650
TimestampNTZType: (datetime.datetime,),
26402651
DayTimeIntervalType: (datetime.timedelta,),
@@ -3241,6 +3252,17 @@ def convert(self, obj: datetime.date, gateway_client: "GatewayClient") -> "JavaG
32413252
return Date.valueOf(obj.strftime("%Y-%m-%d"))
32423253

32433254

3255+
class TimeConverter:
3256+
def can_convert(self, obj: Any) -> bool:
3257+
return isinstance(obj, datetime.time)
3258+
3259+
def convert(self, obj: datetime.time, gateway_client: "GatewayClient") -> "JavaGateway":
3260+
from py4j.java_gateway import JavaClass
3261+
3262+
LocalTime = JavaClass("java.time.LocalTime", gateway_client)
3263+
return LocalTime.of(obj.hour, obj.minute, obj.second, obj.microsecond * 1000)
3264+
3265+
32443266
class DatetimeConverter:
32453267
def can_convert(self, obj: Any) -> bool:
32463268
return isinstance(obj, datetime.datetime)
@@ -3369,6 +3391,7 @@ def convert(self, obj: "np.ndarray", gateway_client: "GatewayClient") -> "JavaGa
33693391
register_input_converter(DatetimeNTZConverter())
33703392
register_input_converter(DatetimeConverter())
33713393
register_input_converter(DateConverter())
3394+
register_input_converter(TimeConverter())
33723395
register_input_converter(DayTimeIntervalTypeConverter())
33733396
register_input_converter(NumpyScalarConverter())
33743397
# NumPy array satisfies py4j.java_collections.ListConverter,

sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
3636
object EvaluatePython {
3737

3838
def needConversionInPython(dt: DataType): Boolean = dt match {
39-
case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType => true
39+
case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType
40+
| _: TimeType => true
4041
case _: StructType => true
4142
case _: UserDefinedType[_] => true
4243
case ArrayType(elementType, _) => needConversionInPython(elementType)
@@ -138,7 +139,7 @@ object EvaluatePython {
138139
case c: Int => c
139140
}
140141

141-
case TimestampType | TimestampNTZType | _: DayTimeIntervalType => (obj: Any) =>
142+
case TimestampType | TimestampNTZType | _: DayTimeIntervalType | _: TimeType => (obj: Any) =>
142143
nullSafeConvert(obj) {
143144
case c: Long => c
144145
// Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs

0 commit comments

Comments
 (0)