@@ -200,19 +200,62 @@ def test_math_functions():
200
200
201
201
202
202
def test_array_functions ():
203
- data = [[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 ], [6.0 ]]
203
+ data = [[1.0 , 2.0 , 3.0 , 3.0 ], [4.0 , 5.0 , 3 .0 ], [6.0 ]]
204
204
ctx = SessionContext ()
205
205
batch = pa .RecordBatch .from_arrays (
206
206
[np .array (data , dtype = object )], names = ["arr" ]
207
207
)
208
208
df = ctx .create_dataframe ([[batch ]])
209
209
210
def py_indexof(arr, v):
    """Return the 1-based position of the first ``v`` in ``arr``.

    ``np.nan`` is returned when ``arr.index`` raises ``ValueError``,
    i.e. when ``v`` is not present.
    """
    try:
        zero_based = arr.index(v)
    except ValueError:
        return np.nan
    # Shift by one: the reference result is expressed with 1-based positions.
    return zero_based + 1
215
+
216
def py_arr_remove(arr, v, n=None):
    """Return a copy of ``arr`` with occurrences of ``v`` removed front-first.

    Removal stops once ``n`` occurrences have been dropped or ``v`` no longer
    appears. With the default ``n=None`` the counter never equals ``n``, so
    every occurrence is removed. ``arr`` itself is left untouched.
    """
    remaining = arr[:]
    removed = 0
    while removed != n and v in remaining:
        remaining.remove(v)
        removed += 1
    return remaining
227
+
228
def py_arr_replace(arr, from_, to, n=None):
    """Return a copy of ``arr`` with occurrences of ``from_`` replaced by ``to``.

    Occurrences are replaced in order from the front. Replacement stops once
    ``n`` of them have been rewritten or no further occurrence exists. With
    the default ``n=None`` the counter never equals ``n``, so every occurrence
    is replaced. ``arr`` itself is left untouched.
    """
    new_arr = arr[:]
    found = 0
    # Resume each search just past the previous hit. Restarting from index 0
    # would find the freshly written slot again whenever ``to == from_`` and,
    # with ``n=None``, never terminate.
    start = 0
    while found != n:
        try:
            idx = new_arr.index(from_, start)
        except ValueError:
            break
        new_arr[idx] = to
        found += 1
        start = idx + 1

    return new_arr
240
+
210
241
col = column ("arr" )
211
242
test_items = [
212
243
[
213
244
f .array_append (col , literal (99.0 )),
214
245
lambda : [np .append (arr , 99.0 ) for arr in data ],
215
246
],
247
+ [
248
+ f .array_push_back (col , literal (99.0 )),
249
+ lambda : [np .append (arr , 99.0 ) for arr in data ],
250
+ ],
251
+ [
252
+ f .list_append (col , literal (99.0 )),
253
+ lambda : [np .append (arr , 99.0 ) for arr in data ],
254
+ ],
255
+ [
256
+ f .list_push_back (col , literal (99.0 )),
257
+ lambda : [np .append (arr , 99.0 ) for arr in data ],
258
+ ],
216
259
[
217
260
f .array_concat (col , col ),
218
261
lambda : [np .concatenate ([arr , arr ]) for arr in data ],
@@ -253,12 +296,174 @@ def test_array_functions():
253
296
f .list_length (col ),
254
297
lambda : [len (r ) for r in data ],
255
298
],
299
+ [
300
+ f .array_has (col , literal (1.0 )),
301
+ lambda : [1.0 in r for r in data ],
302
+ ],
303
+ [
304
+ f .array_has_all (
305
+ col , f .make_array (* [literal (v ) for v in [1.0 , 3.0 , 5.0 ]])
306
+ ),
307
+ lambda : [np .all ([v in r for v in [1.0 , 3.0 , 5.0 ]]) for r in data ],
308
+ ],
309
+ [
310
+ f .array_has_any (
311
+ col , f .make_array (* [literal (v ) for v in [1.0 , 3.0 , 5.0 ]])
312
+ ),
313
+ lambda : [np .any ([v in r for v in [1.0 , 3.0 , 5.0 ]]) for r in data ],
314
+ ],
315
+ [
316
+ f .array_position (col , literal (1.0 )),
317
+ lambda : [py_indexof (r , 1.0 ) for r in data ],
318
+ ],
319
+ [
320
+ f .array_indexof (col , literal (1.0 )),
321
+ lambda : [py_indexof (r , 1.0 ) for r in data ],
322
+ ],
323
+ [
324
+ f .list_position (col , literal (1.0 )),
325
+ lambda : [py_indexof (r , 1.0 ) for r in data ],
326
+ ],
327
+ [
328
+ f .list_indexof (col , literal (1.0 )),
329
+ lambda : [py_indexof (r , 1.0 ) for r in data ],
330
+ ],
331
+ [
332
+ f .array_positions (col , literal (1.0 )),
333
+ lambda : [
334
+ [i + 1 for i , _v in enumerate (r ) if _v == 1.0 ] for r in data
335
+ ],
336
+ ],
337
+ [
338
+ f .list_positions (col , literal (1.0 )),
339
+ lambda : [
340
+ [i + 1 for i , _v in enumerate (r ) if _v == 1.0 ] for r in data
341
+ ],
342
+ ],
343
+ [
344
+ f .array_ndims (col ),
345
+ lambda : [np .array (r ).ndim for r in data ],
346
+ ],
347
+ [
348
+ f .list_ndims (col ),
349
+ lambda : [np .array (r ).ndim for r in data ],
350
+ ],
351
+ [
352
+ f .array_prepend (literal (99.0 ), col ),
353
+ lambda : [np .insert (arr , 0 , 99.0 ) for arr in data ],
354
+ ],
355
+ [
356
+ f .array_push_front (literal (99.0 ), col ),
357
+ lambda : [np .insert (arr , 0 , 99.0 ) for arr in data ],
358
+ ],
359
+ [
360
+ f .list_prepend (literal (99.0 ), col ),
361
+ lambda : [np .insert (arr , 0 , 99.0 ) for arr in data ],
362
+ ],
363
+ [
364
+ f .list_push_front (literal (99.0 ), col ),
365
+ lambda : [np .insert (arr , 0 , 99.0 ) for arr in data ],
366
+ ],
367
+ [
368
+ f .array_pop_back (col ),
369
+ lambda : [arr [:- 1 ] for arr in data ],
370
+ ],
371
+ [
372
+ f .array_pop_front (col ),
373
+ lambda : [arr [1 :] for arr in data ],
374
+ ],
375
+ [
376
+ f .array_remove (col , literal (3.0 )),
377
+ lambda : [py_arr_remove (arr , 3.0 , 1 ) for arr in data ],
378
+ ],
379
+ [
380
+ f .list_remove (col , literal (3.0 )),
381
+ lambda : [py_arr_remove (arr , 3.0 , 1 ) for arr in data ],
382
+ ],
383
+ [
384
+ f .array_remove_n (col , literal (3.0 ), literal (2 )),
385
+ lambda : [py_arr_remove (arr , 3.0 , 2 ) for arr in data ],
386
+ ],
387
+ [
388
+ f .list_remove_n (col , literal (3.0 ), literal (2 )),
389
+ lambda : [py_arr_remove (arr , 3.0 , 2 ) for arr in data ],
390
+ ],
391
+ [
392
+ f .array_remove_all (col , literal (3.0 )),
393
+ lambda : [py_arr_remove (arr , 3.0 ) for arr in data ],
394
+ ],
395
+ [
396
+ f .list_remove_all (col , literal (3.0 )),
397
+ lambda : [py_arr_remove (arr , 3.0 ) for arr in data ],
398
+ ],
399
+ [
400
+ f .array_repeat (col , literal (2 )),
401
+ lambda : [[arr ] * 2 for arr in data ],
402
+ ],
403
+ [
404
+ f .array_replace (col , literal (3.0 ), literal (4.0 )),
405
+ lambda : [py_arr_replace (arr , 3.0 , 4.0 , 1 ) for arr in data ],
406
+ ],
407
+ [
408
+ f .list_replace (col , literal (3.0 ), literal (4.0 )),
409
+ lambda : [py_arr_replace (arr , 3.0 , 4.0 , 1 ) for arr in data ],
410
+ ],
411
+ [
412
+ f .array_replace_n (col , literal (3.0 ), literal (4.0 ), literal (1 )),
413
+ lambda : [py_arr_replace (arr , 3.0 , 4.0 , 1 ) for arr in data ],
414
+ ],
415
+ [
416
+ f .list_replace_n (col , literal (3.0 ), literal (4.0 ), literal (2 )),
417
+ lambda : [py_arr_replace (arr , 3.0 , 4.0 , 2 ) for arr in data ],
418
+ ],
419
+ [
420
+ f .array_replace_all (col , literal (3.0 ), literal (4.0 )),
421
+ lambda : [py_arr_replace (arr , 3.0 , 4.0 ) for arr in data ],
422
+ ],
423
+ [
424
+ f .list_replace_all (col , literal (3.0 ), literal (4.0 )),
425
+ lambda : [py_arr_replace (arr , 3.0 , 4.0 ) for arr in data ],
426
+ ],
427
+ [
428
+ f .array_slice (col , literal (2 ), literal (4 )),
429
+ lambda : [arr [1 :4 ] for arr in data ],
430
+ ],
431
+ [
432
+ f .list_slice (col , literal (- 1 ), literal (2 )),
433
+ lambda : [arr [- 1 :2 ] for arr in data ],
434
+ ],
256
435
]
257
436
258
437
for stmt , py_expr in test_items :
259
- query_result = df .select (stmt ).collect ()[0 ].column (0 ).tolist ()
438
+ query_result = df .select (stmt ).collect ()[0 ].column (0 )
439
+ for a , b in zip (query_result , py_expr ()):
440
+ np .testing .assert_array_almost_equal (
441
+ np .array (a .as_py (), dtype = float ), np .array (b , dtype = float )
442
+ )
443
+
444
+ obj_test_items = [
445
+ [
446
+ f .array_to_string (col , literal ("," )),
447
+ lambda : ["," .join ([str (int (v )) for v in r ]) for r in data ],
448
+ ],
449
+ [
450
+ f .array_join (col , literal ("," )),
451
+ lambda : ["," .join ([str (int (v )) for v in r ]) for r in data ],
452
+ ],
453
+ [
454
+ f .list_to_string (col , literal ("," )),
455
+ lambda : ["," .join ([str (int (v )) for v in r ]) for r in data ],
456
+ ],
457
+ [
458
+ f .list_join (col , literal ("," )),
459
+ lambda : ["," .join ([str (int (v )) for v in r ]) for r in data ],
460
+ ],
461
+ ]
462
+
463
+ for stmt , py_expr in obj_test_items :
464
+ query_result = np .array (df .select (stmt ).collect ()[0 ].column (0 ))
260
465
for a , b in zip (query_result , py_expr ()):
261
- np . testing . assert_array_almost_equal ( a , b )
466
+ assert a == b
262
467
263
468
264
469
def test_string_functions (df ):
0 commit comments