-
Notifications
You must be signed in to change notification settings - Fork 2.8k
/
Copy pathfeatures.py
2278 lines (1951 loc) · 90.5 KB
/
features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""This class handle features definition in datasets and some utilities to display table type."""
import copy
import json
import re
import sys
from collections.abc import Iterable, Mapping
from collections.abc import Sequence as SequenceABC
from dataclasses import InitVar, dataclass, field, fields
from functools import reduce, wraps
from operator import mul
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
from typing import Sequence as Sequence_
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.types
from pandas.api.extensions import ExtensionArray as PandasExtensionArray
from pandas.api.extensions import ExtensionDtype as PandasExtensionDtype
from .. import config
from ..naming import camelcase_to_snakecase, snakecase_to_camelcase
from ..table import array_cast
from ..utils import experimental, logging
from ..utils.py_utils import asdict, first_non_null_value, zip_dict
from .audio import Audio
from .image import Image, encode_pil_image
from .translation import Translation, TranslationVariableLanguages
# Module-level logger named after this module (obtained from the project's `..utils.logging` helpers).
logger = logging.get_logger(__name__)
def _arrow_to_datasets_dtype(arrow_type: pa.DataType) -> str:
    """
    Convert a `pyarrow.DataType` into the matching datasets string dtype.

    This is the inverse of `string_to_arrow`, i.e. `dt == string_to_arrow(_arrow_to_datasets_dtype(dt))`.
    Dictionary-encoded types are described by the dtype of their values.

    Raises:
        ValueError: if `arrow_type` has no datasets dtype equivalent, or if it is a timestamp whose
            time zone is an empty string.
    """
    pat = pyarrow.types

    # Types without parameters map to a fixed name. Note the naming differences with pyarrow:
    # float16/float32/float64 are called "halffloat"/"float"/"double" by pyarrow, and
    # date32/date64 are displayed as "date32[day]"/"date64[ms]".
    fixed_names = (
        (pat.is_null, "null"),
        (pat.is_boolean, "bool"),
        (pat.is_int8, "int8"),
        (pat.is_int16, "int16"),
        (pat.is_int32, "int32"),
        (pat.is_int64, "int64"),
        (pat.is_uint8, "uint8"),
        (pat.is_uint16, "uint16"),
        (pat.is_uint32, "uint32"),
        (pat.is_uint64, "uint64"),
        (pat.is_float16, "float16"),
        (pat.is_float32, "float32"),
        (pat.is_float64, "float64"),
        (pat.is_date32, "date32"),
        (pat.is_date64, "date64"),
        (pat.is_binary, "binary"),
        (pat.is_large_binary, "large_binary"),
        (pat.is_string, "string"),
        (pat.is_large_string, "large_string"),
    )
    for matches, name in fixed_names:
        if matches(arrow_type):
            return name

    # Parameterized types embed their unit / precision / time zone in the dtype string.
    if pat.is_time32(arrow_type):
        return f"time32[{pa.type_for_alias(str(arrow_type)).unit}]"
    if pat.is_time64(arrow_type):
        return f"time64[{pa.type_for_alias(str(arrow_type)).unit}]"
    if pat.is_timestamp(arrow_type):
        if arrow_type.tz is None:
            return f"timestamp[{arrow_type.unit}]"
        if arrow_type.tz:
            return f"timestamp[{arrow_type.unit}, tz={arrow_type.tz}]"
        # A time zone that is set but empty has no string representation.
        raise ValueError(f"Unexpected timestamp object {arrow_type}.")
    if pat.is_duration(arrow_type):
        return f"duration[{arrow_type.unit}]"
    if pat.is_decimal128(arrow_type):
        return f"decimal128({arrow_type.precision}, {arrow_type.scale})"
    if pat.is_decimal256(arrow_type):
        return f"decimal256({arrow_type.precision}, {arrow_type.scale})"
    if pat.is_dictionary(arrow_type):
        return _arrow_to_datasets_dtype(arrow_type.value_type)

    raise ValueError(f"Arrow type {arrow_type} does not have a datasets dtype equivalent.")
def string_to_arrow(datasets_dtype: str) -> pa.DataType:
    """
    string_to_arrow takes a datasets string dtype and converts it to a pyarrow.DataType.

    In effect, `dt == string_to_arrow(_arrow_to_datasets_dtype(dt))`

    This is necessary because the datasets.Value() primitive type is constructed using a string dtype

    Value(dtype=str)

    But Features.type (via `get_nested_type()`) expects to resolve Features into a pyarrow Schema,
    which means that each Value() must be able to resolve into a corresponding pyarrow.DataType, which is the
    purpose of this function.

    Raises:
        ValueError: if `datasets_dtype` is neither the name of a pyarrow type factory (optionally with a
            trailing "_"), nor a well-formed `timestamp[...]`, `duration[...]`, `time32[...]`/`time64[...]`
            or `decimal128(...)`/`decimal256(...)` string.
    """

    def _dtype_error_msg(dtype, pa_dtype, examples=None, urls=None):
        # Builds a uniform error message listing valid examples and documentation links.
        msg = f"{dtype} is not a validly formatted string representation of the pyarrow {pa_dtype} type."
        if examples:
            examples = ", ".join(examples[:-1]) + " or " + examples[-1] if len(examples) > 1 else examples[0]
            msg += f"\nValid examples include: {examples}."
        if urls:
            urls = ", ".join(urls[:-1]) + " and " + urls[-1] if len(urls) > 1 else urls[0]
            msg += f"\nFor more information, see: {urls}."
        return msg

    # Plain factory names ("int32", "string", ...) map directly to a pyarrow function; some names
    # (e.g. "bool") only exist in pyarrow with a trailing underscore ("bool_").
    if datasets_dtype in pa.__dict__:
        return pa.__dict__[datasets_dtype]()

    if (datasets_dtype + "_") in pa.__dict__:
        return pa.__dict__[datasets_dtype + "_"]()

    timestamp_matches = re.search(r"^timestamp\[(.*)\]$", datasets_dtype)
    if timestamp_matches:
        timestamp_internals = timestamp_matches.group(1)
        internals_matches = re.search(r"^(s|ms|us|ns),\s*tz=([a-zA-Z0-9/_+\-:]*)$", timestamp_internals)
        if timestamp_internals in ["s", "ms", "us", "ns"]:
            return pa.timestamp(timestamp_internals)
        elif internals_matches:
            return pa.timestamp(internals_matches.group(1), internals_matches.group(2))
        else:
            raise ValueError(
                _dtype_error_msg(
                    datasets_dtype,
                    "timestamp",
                    examples=["timestamp[us]", "timestamp[us, tz=America/New_York]"],
                    urls=["https://arrow.apache.org/docs/python/generated/pyarrow.timestamp.html"],
                )
            )

    duration_matches = re.search(r"^duration\[(.*)\]$", datasets_dtype)
    if duration_matches:
        duration_internals = duration_matches.group(1)
        if duration_internals in ["s", "ms", "us", "ns"]:
            return pa.duration(duration_internals)
        else:
            raise ValueError(
                _dtype_error_msg(
                    datasets_dtype,
                    "duration",
                    examples=["duration[s]", "duration[us]"],
                    urls=["https://arrow.apache.org/docs/python/generated/pyarrow.duration.html"],
                )
            )

    time_matches = re.search(r"^time(.*)\[(.*)\]$", datasets_dtype)
    if time_matches:
        time_internals_bits = time_matches.group(1)
        if time_internals_bits == "32":
            time_internals_unit = time_matches.group(2)
            if time_internals_unit in ["s", "ms"]:
                return pa.time32(time_internals_unit)
            else:
                raise ValueError(
                    f"{time_internals_unit} is not a valid unit for the pyarrow time32 type. Supported units: s (second) and ms (millisecond)."
                )
        elif time_internals_bits == "64":
            time_internals_unit = time_matches.group(2)
            if time_internals_unit in ["us", "ns"]:
                return pa.time64(time_internals_unit)
            else:
                raise ValueError(
                    f"{time_internals_unit} is not a valid unit for the pyarrow time64 type. Supported units: us (microsecond) and ns (nanosecond)."
                )
        else:
            raise ValueError(
                _dtype_error_msg(
                    datasets_dtype,
                    "time",
                    examples=["time32[s]", "time64[us]"],
                    urls=[
                        "https://arrow.apache.org/docs/python/generated/pyarrow.time32.html",
                        "https://arrow.apache.org/docs/python/generated/pyarrow.time64.html",
                    ],
                )
            )

    decimal_matches = re.search(r"^decimal(.*)\((.*)\)$", datasets_dtype)
    if decimal_matches:
        decimal_internals_bits = decimal_matches.group(1)
        # Error-message examples for each supported bit width; the constructor is pa.decimal<bits>.
        decimal_examples = {
            "128": ["decimal128(10, 2)", "decimal128(4, -2)"],
            "256": ["decimal256(30, 2)", "decimal256(38, -4)"],
        }
        if decimal_internals_bits in decimal_examples:
            pa_dtype = f"decimal{decimal_internals_bits}"
            decimal_internals_precision_and_scale = re.search(r"^(\d+),\s*(-?\d+)$", decimal_matches.group(2))
            if decimal_internals_precision_and_scale:
                precision, scale = decimal_internals_precision_and_scale.groups()
                # Resolved lazily so pa.decimal256 is only looked up when actually requested.
                return getattr(pa, pa_dtype)(int(precision), int(scale))
            raise ValueError(
                _dtype_error_msg(
                    datasets_dtype,
                    pa_dtype,
                    examples=decimal_examples[decimal_internals_bits],
                    urls=[f"https://arrow.apache.org/docs/python/generated/pyarrow.{pa_dtype}.html"],
                )
            )
        raise ValueError(
            _dtype_error_msg(
                datasets_dtype,
                "decimal",
                examples=["decimal128(12, 3)", "decimal256(40, 6)"],
                urls=[
                    "https://arrow.apache.org/docs/python/generated/pyarrow.decimal128.html",
                    "https://arrow.apache.org/docs/python/generated/pyarrow.decimal256.html",
                ],
            )
        )

    raise ValueError(
        f"Neither {datasets_dtype} nor {datasets_dtype + '_'} seems to be a pyarrow data type. "
        f"Please make sure to use a correct data type, see: "
        f"https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions"
    )
def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_casting: bool) -> Tuple[Any, bool]:
    """
    Cast pytorch/tensorflow/jax/pandas/PIL objects to python objects, numpy arrays and lists.

    It works recursively.

    If `optimize_list_casting` is True, to avoid iterating over possibly long lists, it first checks (recursively)
    if the first element that is not None or empty (if it is a sequence) has to be casted.
    If the first element needs to be casted, then all the elements of the list will be casted, otherwise they'll
    stay the same.
    This trick allows to cast objects that contain tokenizers outputs without iterating over every single token
    for example.

    Args:
        obj: the object (nested struct) to cast.
        only_1d_for_numpy (bool): whether to keep the full multi-dim tensors as multi-dim numpy arrays, or convert
            them to nested lists of 1-dimensional numpy arrays. This can be useful to keep only 1-d arrays to
            instantiate Arrow arrays. Indeed Arrow only support converting 1-dimensional array values.
        optimize_list_casting (bool): whether to optimize list casting by checking the first non-null element to
            see if it needs to be casted and if it doesn't, not checking the rest of the list elements.

    Returns:
        casted_obj: the casted object
        has_changed (bool): True if the object has been changed, False if it is identical
    """
    # Framework modules are only imported when the user has already imported them: an object can only be an
    # instance of their types if the module is loaded, and this avoids paying their import cost otherwise.
    if config.TF_AVAILABLE and "tensorflow" in sys.modules:
        import tensorflow as tf

    if config.TORCH_AVAILABLE and "torch" in sys.modules:
        import torch

    if config.JAX_AVAILABLE and "jax" in sys.modules:
        import jax.numpy as jnp

    if config.PIL_AVAILABLE and "PIL" in sys.modules:
        import PIL.Image

    if isinstance(obj, np.ndarray):
        if obj.ndim == 0:
            return obj[()], True
        elif not only_1d_for_numpy or obj.ndim == 1:
            return obj, False
        else:
            return (
                [
                    _cast_to_python_objects(
                        x, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
                    )[0]
                    for x in obj
                ],
                True,
            )
    elif config.TORCH_AVAILABLE and "torch" in sys.modules and isinstance(obj, torch.Tensor):
        if obj.dtype == torch.bfloat16:
            # numpy has no bfloat16: upcast to float32 first, then reuse the numpy rules.
            return _cast_to_python_objects(
                obj.detach().to(torch.float).cpu().numpy(),
                only_1d_for_numpy=only_1d_for_numpy,
                optimize_list_casting=optimize_list_casting,
            )[0], True
        if obj.ndim == 0:
            return obj.detach().cpu().numpy()[()], True
        elif not only_1d_for_numpy or obj.ndim == 1:
            return obj.detach().cpu().numpy(), True
        else:
            return (
                [
                    _cast_to_python_objects(
                        x, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
                    )[0]
                    for x in obj.detach().cpu().numpy()
                ],
                True,
            )
    elif config.TF_AVAILABLE and "tensorflow" in sys.modules and isinstance(obj, tf.Tensor):
        if obj.ndim == 0:
            return obj.numpy()[()], True
        elif not only_1d_for_numpy or obj.ndim == 1:
            return obj.numpy(), True
        else:
            return (
                [
                    _cast_to_python_objects(
                        x, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
                    )[0]
                    for x in obj.numpy()
                ],
                True,
            )
    elif config.JAX_AVAILABLE and "jax" in sys.modules and isinstance(obj, jnp.ndarray):
        if obj.ndim == 0:
            return np.asarray(obj)[()], True
        elif not only_1d_for_numpy or obj.ndim == 1:
            return np.asarray(obj), True
        else:
            return (
                [
                    _cast_to_python_objects(
                        x, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
                    )[0]
                    for x in np.asarray(obj)
                ],
                True,
            )
    elif config.PIL_AVAILABLE and "PIL" in sys.modules and isinstance(obj, PIL.Image.Image):
        return encode_pil_image(obj), True
    elif isinstance(obj, pd.Series):
        return (
            _cast_to_python_objects(
                obj.tolist(), only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
            )[0],
            True,
        )
    elif isinstance(obj, pd.DataFrame):
        return (
            {
                key: _cast_to_python_objects(
                    value, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
                )[0]
                for key, value in obj.to_dict("series").items()
            },
            True,
        )
    elif isinstance(obj, pd.Timestamp):
        return obj.to_pydatetime(), True
    elif isinstance(obj, pd.Timedelta):
        return obj.to_pytimedelta(), True
    elif isinstance(obj, Mapping):
        # Non-dict mappings are always rebuilt as plain dicts, hence "changed" even if no value changed.
        has_changed = not isinstance(obj, dict)
        output = {}
        for k, v in obj.items():
            casted_v, has_changed_v = _cast_to_python_objects(
                v, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
            )
            has_changed |= has_changed_v
            output[k] = casted_v
        return output if has_changed else obj, has_changed
    elif hasattr(obj, "__array__"):
        # Any other array-like: convert to numpy first, then apply the numpy rules.
        return (
            _cast_to_python_objects(
                obj.__array__(), only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
            )[0],
            True,
        )
    elif isinstance(obj, (list, tuple)):
        if len(obj) == 0:
            return obj, False
        # Pick the first element that is neither None nor empty; if there is none, the loop
        # leaves `first_elmt` bound to the last element.
        for first_elmt in obj:
            if _check_non_null_non_empty_recursive(first_elmt):
                break
        _, has_changed_first_elmt = _cast_to_python_objects(
            first_elmt, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
        )
        if has_changed_first_elmt or not optimize_list_casting:
            return (
                [
                    _cast_to_python_objects(
                        elmt, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
                    )[0]
                    for elmt in obj
                ],
                True,
            )
        # The sampled element needs no casting: assume the others don't either and keep the list/tuple as-is.
        return obj, False
    else:
        return obj, False
def cast_to_python_objects(obj: Any, only_1d_for_numpy=False, optimize_list_casting=True) -> Any:
    """
    Recursively cast numpy/pytorch/tensorflow/pandas objects to python lists.

    If `optimize_list_casting` is True, to avoid iterating over possibly long lists, it first checks (recursively)
    whether the first element that is not None or empty (if it is a sequence) has to be casted. If it does, every
    element of the list is casted; otherwise the list is kept as-is. This keeps the casting of objects such as
    tokenizer outputs cheap, since it avoids iterating over every single token.

    Args:
        obj: the object (nested struct) to cast
        only_1d_for_numpy (bool, default ``False``): whether to keep the full multi-dim tensors as multi-dim numpy
            arrays, or convert them to nested lists of 1-dimensional numpy arrays. This can be useful to keep only
            1-d arrays to instantiate Arrow arrays, since Arrow only supports converting 1-dimensional array values.
        optimize_list_casting (bool, default ``True``): whether to optimize list casting by checking the first
            non-null element to see if it needs to be casted and, if it doesn't, not checking the rest of the list
            elements.

    Returns:
        casted_obj: the casted object
    """
    # The private helper also reports whether anything changed; the public API only exposes the result.
    casted_obj, _has_changed = _cast_to_python_objects(
        obj, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
    )
    return casted_obj
@dataclass
class Value:
    """
    Scalar feature value of a particular data type.

    The possible dtypes of `Value` are as follows:
    - `null`
    - `bool`
    - `int8`
    - `int16`
    - `int32`
    - `int64`
    - `uint8`
    - `uint16`
    - `uint32`
    - `uint64`
    - `float16`
    - `float32` (alias float)
    - `float64` (alias double)
    - `time32[(s|ms)]`
    - `time64[(us|ns)]`
    - `timestamp[(s|ms|us|ns)]`
    - `timestamp[(s|ms|us|ns), tz=(tzstring)]`
    - `date32`
    - `date64`
    - `duration[(s|ms|us|ns)]`
    - `decimal128(precision, scale)`
    - `decimal256(precision, scale)`
    - `binary`
    - `large_binary`
    - `string`
    - `large_string`

    Args:
        dtype (`str`):
            Name of the data type. The aliases "float" and "double" are normalized to "float32" and "float64".
        id (`str`, *optional*):
            Optional identifier of the feature, `None` by default.

    Example:

    ```py
    >>> from datasets import Features
    >>> features = Features({'stars': Value(dtype='int32')})
    >>> features
    {'stars': Value(dtype='int32', id=None)}
    ```
    """

    dtype: str
    id: Optional[str] = None
    # Automatically constructed
    pa_type: ClassVar[Any] = None
    _type: str = field(default="Value", init=False, repr=False)

    def __post_init__(self):
        # Inferred types may use pyarrow's names for floats; store the canonical datasets names instead.
        if self.dtype == "double":
            self.dtype = "float64"
        elif self.dtype == "float":
            self.dtype = "float32"
        self.pa_type = string_to_arrow(self.dtype)

    def __call__(self):
        """Return the pyarrow type resolved from `dtype`."""
        return self.pa_type

    def encode_example(self, value):
        """Coerce `value` to the python type matching `pa_type` (bool, int, float or str); other types pass through."""
        casters = (
            (pa.types.is_boolean, bool),
            (pa.types.is_integer, int),
            (pa.types.is_floating, float),
            (pa.types.is_string, str),
        )
        for matches, cast in casters:
            if matches(self.pa_type):
                return cast(value)
        return value
class _ArrayXD:
def __post_init__(self):
self.shape = tuple(self.shape)
def __call__(self):
pa_type = globals()[self.__class__.__name__ + "ExtensionType"](self.shape, self.dtype)
return pa_type
def encode_example(self, value):
return value
@dataclass
class Array2D(_ArrayXD):
    """Feature type for a two-dimensional array.

    Args:
        shape (`tuple`):
            Size of each of the two dimensions. Only the first one may be `None` (variable size);
            this is validated when the Arrow extension type is built.
        dtype (`str`):
            Name of the element data type.
        id (`str`, *optional*):
            Optional identifier of the feature, `None` by default.

    Example:

    ```py
    >>> from datasets import Features
    >>> features = Features({'x': Array2D(shape=(1, 3), dtype='int32')})
    ```
    """

    shape: tuple
    dtype: str
    id: Optional[str] = None
    # Filled in automatically: not an __init__ argument and hidden from repr.
    _type: str = field(init=False, repr=False, default="Array2D")
@dataclass
class Array3D(_ArrayXD):
    """Feature type for a three-dimensional array.

    Args:
        shape (`tuple`):
            Size of each of the three dimensions. Only the first one may be `None` (variable size);
            this is validated when the Arrow extension type is built.
        dtype (`str`):
            Name of the element data type.
        id (`str`, *optional*):
            Optional identifier of the feature, `None` by default.

    Example:

    ```py
    >>> from datasets import Features
    >>> features = Features({'x': Array3D(shape=(1, 2, 3), dtype='int32')})
    ```
    """

    shape: tuple
    dtype: str
    id: Optional[str] = None
    # Filled in automatically: not an __init__ argument and hidden from repr.
    _type: str = field(init=False, repr=False, default="Array3D")
@dataclass
class Array4D(_ArrayXD):
    """Feature type for a four-dimensional array.

    Args:
        shape (`tuple`):
            Size of each of the four dimensions. Only the first one may be `None` (variable size);
            this is validated when the Arrow extension type is built.
        dtype (`str`):
            Name of the element data type.
        id (`str`, *optional*):
            Optional identifier of the feature, `None` by default.

    Example:

    ```py
    >>> from datasets import Features
    >>> features = Features({'x': Array4D(shape=(1, 2, 2, 3), dtype='int32')})
    ```
    """

    shape: tuple
    dtype: str
    id: Optional[str] = None
    # Filled in automatically: not an __init__ argument and hidden from repr.
    _type: str = field(init=False, repr=False, default="Array4D")
@dataclass
class Array5D(_ArrayXD):
    """Feature type for a five-dimensional array.

    Args:
        shape (`tuple`):
            Size of each of the five dimensions. Only the first one may be `None` (variable size);
            this is validated when the Arrow extension type is built.
        dtype (`str`):
            Name of the element data type.
        id (`str`, *optional*):
            Optional identifier of the feature, `None` by default.

    Example:

    ```py
    >>> from datasets import Features
    >>> features = Features({'x': Array5D(shape=(1, 2, 2, 3, 3), dtype='int32')})
    ```
    """

    shape: tuple
    dtype: str
    id: Optional[str] = None
    # Filled in automatically: not an __init__ argument and hidden from repr.
    _type: str = field(init=False, repr=False, default="Array5D")
class _ArrayXDExtensionType(pa.ExtensionType):
    """
    Base pyarrow extension type behind the `ArrayXD` features.

    Subclasses set `ndims`. The storage type is the element type wrapped in one variable-size list per
    dimension, and only the first entry of `shape` may be `None` (dynamic size).
    """

    ndims: Optional[int] = None

    def __init__(self, shape: tuple, dtype: str):
        if self.ndims is None or self.ndims <= 1:
            raise ValueError("You must instantiate an array type with a value for dim that is > 1")
        if len(shape) != self.ndims:
            raise ValueError(f"shape={shape} and ndims={self.ndims} don't match")
        if any(shape[dim] is None for dim in range(1, self.ndims)):
            raise ValueError(f"Support only dynamic size on first dimension. Got: {shape}")

        self.shape = tuple(shape)
        self.value_type = dtype
        self.storage_dtype = self._generate_dtype(self.value_type)
        ext_cls = self.__class__
        extension_name = f"{ext_cls.__module__}.{ext_cls.__name__}"
        pa.ExtensionType.__init__(self, self.storage_dtype, extension_name)

    def __arrow_ext_serialize__(self):
        # The type parameters are stored as a JSON-encoded [shape, dtype] pair.
        params = (self.shape, self.value_type)
        return json.dumps(params).encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # storage_type is not needed: it is regenerated from the deserialized shape and dtype.
        params = json.loads(serialized)
        return cls(*params)

    # This was added to pa.ExtensionType in pyarrow >= 13.0.0
    def __reduce__(self):
        return self.__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())

    def __hash__(self):
        return hash((self.__class__, self.shape, self.value_type))

    def __arrow_ext_class__(self):
        return ArrayExtensionArray

    def _generate_dtype(self, dtype):
        arrow_type = string_to_arrow(dtype)
        # One pa.list_ level per dimension. Don't specify the size of the list, since fixed length list arrays
        # have issues being validated after slicing in pyarrow 0.17.1
        for _ in self.shape:
            arrow_type = pa.list_(arrow_type)
        return arrow_type

    def to_pandas_dtype(self):
        return PandasArrayExtensionDtype(self.value_type)
class Array2DExtensionType(_ArrayXDExtensionType):
    # Extension type built by `Array2D.__call__` (looked up by name as "Array2D" + "ExtensionType").
    ndims = 2


class Array3DExtensionType(_ArrayXDExtensionType):
    # Extension type built by `Array3D.__call__`.
    ndims = 3


class Array4DExtensionType(_ArrayXDExtensionType):
    # Extension type built by `Array4D.__call__`.
    ndims = 4


class Array5DExtensionType(_ArrayXDExtensionType):
    # Extension type built by `Array5D.__call__`.
    ndims = 5


# Register the extension types for deserialization
# One instance of each class is registered with pyarrow at import time, under its extension name
# ("<module>.<ClassName>", set in `_ArrayXDExtensionType.__init__`). NOTE(review): the shapes and the
# "int64" dtype look like placeholders — presumably pyarrow rebuilds the real type from the serialized
# metadata via `__arrow_ext_deserialize__`; confirm against the pyarrow registration docs.
pa.register_extension_type(Array2DExtensionType((1, 2), "int64"))
pa.register_extension_type(Array3DExtensionType((1, 2, 3), "int64"))
pa.register_extension_type(Array4DExtensionType((1, 2, 3, 4), "int64"))
pa.register_extension_type(Array5DExtensionType((1, 2, 3, 4, 5), "int64"))
def _is_zero_copy_only(pa_type: pa.DataType, unnest: bool = False) -> bool:
    """
    Return the ``zero_copy_only`` value to pass to ``.to_numpy()`` for a pyarrow array of type ``pa_type``.

    Zero copy is available for all primitive types except booleans and temporal types
    (date, time, timestamp or duration). With ``unnest=True``, list types are first unwrapped
    down to their innermost value type.

    See https://github.com/wesm/arrow/blob/c07b9b48cf3e0bbbab493992a492ae47e5b04cad/python/pyarrow/types.pxi#L821,
    https://arrow.apache.org/docs/python/generated/pyarrow.Array.html#pyarrow.Array.to_numpy
    and https://issues.apache.org/jira/browse/ARROW-2871?jql=text%20~%20%22boolean%20to_numpy%22
    """
    if unnest:
        while pa.types.is_list(pa_type):
            pa_type = pa_type.value_type
    if not pa.types.is_primitive(pa_type):
        return False
    return not (pa.types.is_boolean(pa_type) or pa.types.is_temporal(pa_type))
class ArrayExtensionArray(pa.ExtensionArray):
    """
    pyarrow extension array for the `ArrayXD` extension types.

    The storage is a list array nested once per dimension; this class converts it back to numpy arrays of the
    declared shape.
    """

    def __array__(self, dtype=None, copy=None):
        # numpy passes `dtype` for e.g. `np.asarray(arr, dtype=...)` and, since NumPy 2, a `copy` flag.
        # `copy=True` forces a fresh array; otherwise the converted array is returned as-is when possible.
        zero_copy_only = _is_zero_copy_only(self.storage.type, unnest=True)
        numpy_arr = self.to_numpy(zero_copy_only=zero_copy_only)
        if dtype is not None:
            return numpy_arr.astype(dtype, copy=bool(copy))
        return numpy_arr.copy() if copy else numpy_arr

    def __getitem__(self, i):
        return self.storage[i]

    def to_numpy(self, zero_copy_only=True):
        """
        Convert to a numpy array whose first axis indexes the rows.

        - Fixed first dimension: the result has shape ``(len(self), *shape)``. If any row is null, the result is
          cast to float64 and the null rows are filled with NaN.
        - Variable first dimension: each row is converted separately and null rows become a float NaN. The result
          is a 1-D object array when the rows' first-dimension sizes (read from the list offsets) differ,
          otherwise a regular numpy array.
        """
        storage: pa.ListArray = self.storage
        null_mask = storage.is_null().to_numpy(zero_copy_only=False)

        if self.type.shape[0] is not None:
            # Insertion positions in the null-free result: the k-th null (0-based) at row p goes before index p - k.
            null_indices = np.arange(len(storage))[null_mask] - np.arange(np.sum(null_mask))
            for _ in range(self.type.ndims):
                storage = storage.flatten()
            numpy_arr = storage.to_numpy(zero_copy_only=zero_copy_only)
            # flatten() skips the null rows, hence the reduced number of rows before re-inserting them.
            numpy_arr = numpy_arr.reshape(len(self) - len(null_indices), *self.type.shape)

            if len(null_indices):
                numpy_arr = np.insert(numpy_arr.astype(np.float64), null_indices, np.nan, axis=0)

        else:
            shape = self.type.shape
            ndims = self.type.ndims
            arrays = []
            first_dim_offsets = storage.offsets.to_numpy(zero_copy_only=False)
            for i, is_null in enumerate(null_mask):
                if is_null:
                    arrays.append(np.nan)
                else:
                    storage_el = storage[i : i + 1]
                    first_dim = first_dim_offsets[i + 1] - first_dim_offsets[i]
                    # flatten storage
                    for _ in range(ndims):
                        storage_el = storage_el.flatten()

                    numpy_arr = storage_el.to_numpy(zero_copy_only=zero_copy_only)
                    arrays.append(numpy_arr.reshape(first_dim, *shape[1:]))

            if len(np.unique(np.diff(first_dim_offsets))) > 1:
                # ragged: store one object per row. Assign element by element so numpy never tries to
                # broadcast the row arrays as a single multi-dimensional array.
                numpy_arr = np.empty(len(arrays), dtype=object)
                for i, row in enumerate(arrays):
                    numpy_arr[i] = row
            else:
                numpy_arr = np.array(arrays)

        return numpy_arr

    def to_pylist(self):
        zero_copy_only = _is_zero_copy_only(self.storage.type, unnest=True)
        numpy_arr = self.to_numpy(zero_copy_only=zero_copy_only)
        if self.type.shape[0] is None and numpy_arr.dtype == object:
            # Rows are ndarrays, except null rows, which hold the float NaN placeholder set by to_numpy and
            # have no .tolist(): pass those through unchanged instead of crashing.
            return [row.tolist() if isinstance(row, np.ndarray) else row for row in numpy_arr]
        else:
            return numpy_arr.tolist()
class PandasArrayExtensionDtype(PandasExtensionDtype):
    """pandas extension dtype for columns whose rows are numpy arrays (the `ArrayXD` features)."""

    # pandas expects a *tuple* of attribute names: ExtensionDtype.__eq__/__hash__ iterate over it and call
    # getattr() on each entry, so a bare string would be iterated character by character ("v", "a", ...).
    _metadata = ("value_type",)

    def __init__(self, value_type: Union["PandasArrayExtensionDtype", np.dtype]):
        self._value_type = value_type

    def __from_arrow__(self, array: "Union[pa.Array, pa.ChunkedArray]"):
        """Convert an `ArrayXD` pyarrow (chunked) extension array to a `PandasArrayExtensionArray`."""
        if isinstance(array, pa.ChunkedArray):
            # Merge the chunks' storage into a single extension array before converting.
            array = array.type.wrap_array(pa.concat_arrays([chunk.storage for chunk in array.chunks]))
        zero_copy_only = _is_zero_copy_only(array.storage.type, unnest=True)
        numpy_arr = array.to_numpy(zero_copy_only=zero_copy_only)
        return PandasArrayExtensionArray(numpy_arr)

    @classmethod
    def construct_array_type(cls):
        return PandasArrayExtensionArray

    @property
    def type(self) -> type:
        # Scalar type of the column: each row is a numpy array.
        return np.ndarray

    @property
    def kind(self) -> str:
        return "O"

    @property
    def name(self) -> str:
        return f"array[{self.value_type}]"

    @property
    def value_type(self) -> np.dtype:
        return self._value_type
class PandasArrayExtensionArray(PandasExtensionArray):
    """
    pandas ExtensionArray whose rows are numpy arrays (the `ArrayXD` features).

    `_data` stores the rows along its first axis: either a regular numpy array when all rows share the same
    shape and dtype, or a 1-D object array holding one value per row.
    """

    def __init__(self, data: np.ndarray, copy: bool = False):
        self._data = data if not copy else np.array(data)
        self._dtype = PandasArrayExtensionDtype(data.dtype)

    @staticmethod
    def _to_object_array(items) -> np.ndarray:
        """
        Build a 1-D object array with exactly one slot per item.

        Items are assigned one by one: ``out[:] = items`` would make numpy try to broadcast equally-shaped
        ndarray items as one multi-dimensional array and fail.
        """
        items = list(items)
        out = np.empty(len(items), dtype=object)
        for i, item in enumerate(items):
            out[i] = item
        return out

    def __array__(self, dtype=None, copy=None):
        """
        Convert to NumPy Array.
        Note that Pandas expects a 1D array when dtype is set to object.
        But for other dtypes, the returned shape is the same as the one of ``data``.

        `copy` is accepted for NumPy 2 compatibility; `copy=True` forces a fresh array.

        More info about pandas 1D requirement for PandasExtensionArray here:
        https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.api.extensions.ExtensionArray.html#pandas.api.extensions.ExtensionArray

        """
        if dtype == np.dtype(object):
            return self._to_object_array(self._data)
        if dtype is None:
            return self._data.copy() if copy else self._data
        else:
            return self._data.astype(dtype)

    def copy(self, deep: bool = False) -> "PandasArrayExtensionArray":
        return PandasArrayExtensionArray(self._data, copy=True)

    @classmethod
    def _from_sequence(
        cls, scalars, dtype: Optional[PandasArrayExtensionDtype] = None, copy: bool = False
    ) -> "PandasArrayExtensionArray":
        if len(scalars) > 1 and all(
            isinstance(x, np.ndarray) and x.shape == scalars[0].shape and x.dtype == scalars[0].dtype for x in scalars
        ):
            # np.array always builds a new array from a sequence, so `copy` is not forwarded
            # (NumPy 2 raises on copy=False whenever a copy cannot be avoided).
            data = np.array(scalars, dtype=None if dtype is None else dtype.value_type)
        else:
            data = cls._to_object_array(scalars)
        # `data` is freshly allocated in both branches: no further copy is needed.
        return cls(data, copy=False)

    @classmethod
    def _concat_same_type(cls, to_concat: Sequence_["PandasArrayExtensionArray"]) -> "PandasArrayExtensionArray":
        arrays = [va._data for va in to_concat]
        if arrays and all(
            arr.shape[1:] == arrays[0].shape[1:] and arr.dtype == arrays[0].dtype for arr in arrays
        ):
            # Compatible blocks: append their rows along the first axis (total length = sum of lengths).
            data = np.concatenate(arrays, axis=0)
        else:
            # Rows of different shapes/dtypes (or nothing to concatenate): keep one object per row.
            data = cls._to_object_array(row for arr in arrays for row in arr)
        return cls(data, copy=False)

    @property
    def dtype(self) -> PandasArrayExtensionDtype:
        return self._dtype

    @property
    def nbytes(self) -> int:
        return self._data.nbytes

    def isna(self) -> np.ndarray:
        # A row is NA if it contains any missing value. np.any also handles the plain bool pd.isna returns
        # for scalar rows (e.g. NaN placeholders of null rows), which has no `.any()` method.
        return np.array([np.any(pd.isna(row)) for row in self._data], dtype=bool)

    def __setitem__(self, key: Union[int, slice, np.ndarray], value: Any) -> None:
        raise NotImplementedError()

    def __getitem__(self, item: Union[int, slice, np.ndarray]) -> Union[np.ndarray, "PandasArrayExtensionArray"]:
        # Scalar (python or numpy integer) indexing returns the row itself; anything else returns a new array.
        if isinstance(item, (int, np.integer)):
            return self._data[item]
        return PandasArrayExtensionArray(self._data[item], copy=False)

    def take(
        self, indices: Sequence_[int], allow_fill: bool = False, fill_value: bool = None
    ) -> "PandasArrayExtensionArray":
        indices: np.ndarray = np.asarray(indices, dtype=int)
        if allow_fill:
            # With allow_fill, -1 marks positions to fill with `fill_value` (the dtype's na_value by default).
            fill_value = (
                self.dtype.na_value if fill_value is None else np.asarray(fill_value, dtype=self.dtype.value_type)
            )
            mask = indices == -1
            if (indices < -1).any():
                raise ValueError("Invalid value in `indices`, must be all >= -1 for `allow_fill` is True")
            elif len(self) > 0:
                pass
            elif not np.all(mask):
                raise IndexError("Invalid take for empty PandasArrayExtensionArray, must be all -1.")
            else:
                # Empty source and only fill positions requested.
                data = np.array([fill_value] * len(indices), dtype=self.dtype.value_type)
                return PandasArrayExtensionArray(data, copy=False)
        took = self._data.take(indices, axis=0)
        if allow_fill and mask.any():
            took[mask] = [fill_value] * np.sum(mask)
        return PandasArrayExtensionArray(took, copy=False)

    def __len__(self) -> int:
        return len(self._data)

    def __eq__(self, other) -> np.ndarray:
        if not isinstance(other, PandasArrayExtensionArray):
            raise NotImplementedError(f"Invalid type to compare to: {type(other)}")
        return (self._data == other._data).all()
def pandas_types_mapper(dtype):
    """
    Map `ArrayXD` Arrow extension types to `PandasArrayExtensionDtype`; return `None` for any other type.

    NOTE(review): this looks intended as the `types_mapper` callback of pyarrow's `to_pandas`, where `None`
    means "use the default conversion" — confirm against callers.
    """
    if not isinstance(dtype, _ArrayXDExtensionType):
        return None
    return PandasArrayExtensionDtype(dtype.value_type)
@dataclass
class ClassLabel:
"""Feature type for integer class labels.
There are 3 ways to define a `ClassLabel`, which correspond to the 3 arguments:
* `num_classes`: Create 0 to (num_classes-1) labels.
* `names`: List of label strings.
* `names_file`: File containing the list of labels.
Under the hood the labels are stored as integers.
You can use negative integers to represent unknown/missing labels.
Args:
num_classes (`int`, *optional*):
Number of classes. All labels must be < `num_classes`.
names (`list` of `str`, *optional*):
String names for the integer classes.
The order in which the names are provided is kept.
names_file (`str`, *optional*):
Path to a file with names for the integer classes, one per line.
Example:
```py
>>> from datasets import Features
>>> features = Features({'label': ClassLabel(num_classes=3, names=['bad', 'ok', 'good'])})
>>> features
{'label': ClassLabel(num_classes=3, names=['bad', 'ok', 'good'], id=None)}
```
"""
num_classes: InitVar[Optional[int]] = None # Pseudo-field: ignored by asdict/fields when converting to/from dict
names: List[str] = None
names_file: InitVar[Optional[str]] = None # Pseudo-field: ignored by asdict/fields when converting to/from dict
id: Optional[str] = None
# Automatically constructed
dtype: ClassVar[str] = "int64"
pa_type: ClassVar[Any] = pa.int64()
_str2int: ClassVar[Dict[str, int]] = None
_int2str: ClassVar[Dict[int, int]] = None
_type: str = field(default="ClassLabel", init=False, repr=False)
def __post_init__(self, num_classes, names_file):
self.num_classes = num_classes
self.names_file = names_file
if self.names_file is not None and self.names is not None:
raise ValueError("Please provide either names or names_file but not both.")
# Set self.names
if self.names is None:
if self.names_file is not None:
self.names = self._load_names_from_file(self.names_file)
elif self.num_classes is not None:
self.names = [str(i) for i in range(self.num_classes)]
else:
raise ValueError("Please provide either num_classes, names or names_file.")
elif not isinstance(self.names, SequenceABC):
raise TypeError(f"Please provide names as a list, is {type(self.names)}")