Skip to content

Commit 3eda980

Browse files
authored
Adding pl.Struct support for pl.DataFrame (pola-rs#306)
Adding missing pl.Struct support for pl.DataFrame from rows to close pola-rs#298
1 parent 1b7a42f commit 3eda980

File tree

2 files changed

+214
-61
lines changed

2 files changed

+214
-61
lines changed

__tests__/dataframe.test.ts

+95-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ describe("dataframe", () => {
77
pl.Series("foo", [1, 2, 9], pl.Int16),
88
pl.Series("bar", [6, 2, 8], pl.Int16),
99
]);
10-
1110
test("dtypes", () => {
1211
const expected = [pl.Float64, pl.String];
1312
const actual = pl.DataFrame({ a: [1, 2, 3], b: ["a", "b", "c"] }).dtypes;
@@ -1318,6 +1317,101 @@ describe("dataframe", () => {
13181317
]);
13191318
expect(actual).toFrameEqual(expected);
13201319
});
1320+
test("df from JSON with multiple struct", () => {
1321+
const rows = [
1322+
{
1323+
id: 1,
1324+
name: "one",
1325+
attributes: {
1326+
b: false,
1327+
bb: true,
1328+
s: "one",
1329+
x: 1,
1330+
att2: { s: "two", y: 2, att3: { s: "three", y: 3 } },
1331+
},
1332+
},
1333+
];
1334+
1335+
const actual = pl.DataFrame(rows);
1336+
const expected = `shape: (1,)
1337+
Series: 'attributes' [struct[5]]
1338+
[
1339+
{false,true,"one",1.0,{"two",2.0,{"three",3.0}}}
1340+
]`;
1341+
expect(actual.select("attributes").toSeries().toString()).toEqual(expected);
1342+
});
1343+
test("df from JSON with struct", () => {
1344+
const rows = [
1345+
{
1346+
id: 1,
1347+
name: "one",
1348+
attributes: { b: false, bb: true, s: "one", x: 1 },
1349+
},
1350+
{
1351+
id: 2,
1352+
name: "two",
1353+
attributes: { b: false, bb: true, s: "two", x: 2 },
1354+
},
1355+
{
1356+
id: 3,
1357+
name: "three",
1358+
attributes: { b: false, bb: true, s: "three", x: 3 },
1359+
},
1360+
];
1361+
1362+
let actual = pl.DataFrame(rows);
1363+
expect(actual.schema).toStrictEqual({
1364+
id: pl.Float64,
1365+
name: pl.String,
1366+
attributes: pl.Struct([
1367+
new pl.Field("b", pl.Bool),
1368+
new pl.Field("bb", pl.Bool),
1369+
new pl.Field("s", pl.String),
1370+
new pl.Field("x", pl.Float64),
1371+
]),
1372+
});
1373+
1374+
let expected = `shape: (3, 3)
1375+
┌─────┬───────┬──────────────────────────┐
1376+
│ id ┆ name ┆ attributes │
1377+
│ --- ┆ --- ┆ --- │
1378+
│ f64 ┆ str ┆ struct[4] │
1379+
╞═════╪═══════╪══════════════════════════╡
1380+
│ 1.0 ┆ one ┆ {false,true,"one",1.0} │
1381+
│ 2.0 ┆ two ┆ {false,true,"two",2.0} │
1382+
│ 3.0 ┆ three ┆ {false,true,"three",3.0} │
1383+
└─────┴───────┴──────────────────────────┘`;
1384+
expect(actual.toString()).toStrictEqual(expected);
1385+
1386+
const schema = {
1387+
id: pl.Int32,
1388+
name: pl.String,
1389+
attributes: pl.Struct([
1390+
new pl.Field("b", pl.Bool),
1391+
new pl.Field("bb", pl.Bool),
1392+
new pl.Field("s", pl.String),
1393+
new pl.Field("x", pl.Int16),
1394+
]),
1395+
};
1396+
actual = pl.DataFrame(rows, { schema: schema });
1397+
expected = `shape: (3, 3)
1398+
┌─────┬───────┬────────────────────────┐
1399+
│ id ┆ name ┆ attributes │
1400+
│ --- ┆ --- ┆ --- │
1401+
│ i32 ┆ str ┆ struct[4] │
1402+
╞═════╪═══════╪════════════════════════╡
1403+
│ 1 ┆ one ┆ {false,true,"one",1} │
1404+
│ 2 ┆ two ┆ {false,true,"two",2} │
1405+
│ 3 ┆ three ┆ {false,true,"three",3} │
1406+
└─────┴───────┴────────────────────────┘`;
1407+
expect(actual.toString()).toStrictEqual(expected);
1408+
expect(actual.getColumn("name").toArray()).toEqual(
1409+
rows.map((e) => e["name"]),
1410+
);
1411+
expect(actual.getColumn("attributes").toArray()).toMatchObject(
1412+
rows.map((e) => e["attributes"]),
1413+
);
1414+
});
13211415
test("pivot", () => {
13221416
{
13231417
const df = pl.DataFrame({

src/dataframe.rs

+119-60
Original file line numberDiff line numberDiff line change
@@ -442,27 +442,26 @@ pub fn from_rows(
442442
infer_schema(pairs, infer_schema_length)
443443
}
444444
};
445-
let len = rows.len();
446-
let it: Vec<Row> = (0..len)
445+
let it: Vec<Row> = (0..rows.len())
447446
.into_iter()
448447
.map(|idx| {
449448
let obj = rows
450449
.get::<Object>(idx as u32)
451450
.unwrap_or(None)
452451
.unwrap_or_else(|| env.create_object().unwrap());
452+
453453
Row(schema
454454
.iter_fields()
455455
.map(|fld| {
456-
let dtype = fld.dtype().clone();
457-
let key = fld.name();
458-
if let Ok(unknown) = obj.get(key) {
459-
let av = match unknown {
460-
Some(unknown) => unsafe {
461-
coerce_js_anyvalue(unknown, dtype).unwrap_or(AnyValue::Null)
462-
},
463-
None => AnyValue::Null,
464-
};
465-
av
456+
let dtype: &DataType = fld.dtype();
457+
let key: &PlSmallStr = fld.name();
458+
if let Ok(unknown) = obj.get::<&polars::prelude::PlSmallStr, JsUnknown>(key) {
459+
match unknown {
460+
Some(unknown) => {
461+
coerce_js_anyvalue(unknown, dtype.clone()).unwrap_or(AnyValue::Null)
462+
}
463+
_ => AnyValue::Null,
464+
}
466465
} else {
467466
AnyValue::Null
468467
}
@@ -1620,61 +1619,79 @@ fn obj_to_pairs(rows: &Array, len: usize) -> impl '_ + Iterator<Item = Vec<(Stri
16201619
let len = std::cmp::min(len, rows.len() as usize);
16211620
(0..len).map(move |idx| {
16221621
let obj = rows.get::<Object>(idx as u32).unwrap().unwrap();
1623-
16241622
let keys = Object::keys(&obj).unwrap();
16251623
keys.iter()
16261624
.map(|key| {
16271625
let value = obj.get::<_, napi::JsUnknown>(&key).unwrap_or(None);
1628-
let dtype = match value {
1629-
Some(val) => {
1630-
let ty = val.get_type().unwrap();
1631-
match ty {
1632-
ValueType::Boolean => DataType::Boolean,
1633-
ValueType::Number => DataType::Float64,
1634-
ValueType::String => DataType::String,
1635-
ValueType::Object => {
1636-
if val.is_array().unwrap() {
1637-
let arr: napi::JsObject = unsafe { val.cast() };
1638-
let len = arr.get_array_length().unwrap();
1639-
1640-
if len == 0 {
1641-
DataType::List(DataType::Null.into())
1642-
} else {
1643-
// dont compare too many items, as it could be expensive
1644-
let max_take = std::cmp::min(len as usize, 10);
1645-
let mut dtypes: Vec<DataType> =
1646-
Vec::with_capacity(len as usize);
1647-
1648-
for idx in 0..max_take {
1649-
let item: napi::JsUnknown =
1650-
arr.get_element(idx as u32).unwrap();
1651-
let ty = item.get_type().unwrap();
1652-
let dt: Wrap<DataType> = ty.into();
1653-
dtypes.push(dt.0)
1654-
}
1655-
let dtype = coerce_data_type(&dtypes);
1656-
1657-
DataType::List(dtype.into())
1658-
}
1659-
} else if val.is_date().unwrap() {
1660-
DataType::Datetime(TimeUnit::Milliseconds, None)
1661-
} else {
1662-
DataType::Struct(vec![])
1663-
}
1664-
}
1665-
ValueType::BigInt => DataType::UInt64,
1666-
_ => DataType::Null,
1667-
}
1668-
}
1669-
None => DataType::Null,
1670-
};
1671-
(key.to_owned(), dtype)
1626+
(key.to_owned(), obj_to_type(value))
16721627
})
16731628
.collect()
16741629
})
16751630
}
16761631

1677-
unsafe fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult<AnyValue<'a>> {
1632+
fn obj_to_type(value: Option<JsUnknown>) -> DataType {
1633+
match value {
1634+
Some(val) => {
1635+
let ty = val.get_type().unwrap();
1636+
match ty {
1637+
ValueType::Boolean => DataType::Boolean,
1638+
ValueType::Number => DataType::Float64,
1639+
ValueType::BigInt => DataType::UInt64,
1640+
ValueType::String => DataType::String,
1641+
ValueType::Object => {
1642+
if val.is_array().unwrap() {
1643+
let arr: napi::JsObject = unsafe { val.cast() };
1644+
let len = arr.get_array_length().unwrap();
1645+
if len == 0 {
1646+
DataType::List(DataType::Null.into())
1647+
} else {
1648+
// dont compare too many items, as it could be expensive
1649+
let max_take = std::cmp::min(len as usize, 10);
1650+
let mut dtypes: Vec<DataType> = Vec::with_capacity(len as usize);
1651+
1652+
for idx in 0..max_take {
1653+
let item: napi::JsUnknown = arr.get_element(idx as u32).unwrap();
1654+
let ty = item.get_type().unwrap();
1655+
let dt: Wrap<DataType> = ty.into();
1656+
dtypes.push(dt.0)
1657+
}
1658+
let dtype = coerce_data_type(&dtypes);
1659+
1660+
DataType::List(dtype.into())
1661+
}
1662+
} else if val.is_date().unwrap() {
1663+
DataType::Datetime(TimeUnit::Milliseconds, None)
1664+
} else {
1665+
let inner_val: napi::JsObject = unsafe { val.cast() };
1666+
let inner_keys = Object::keys(&inner_val).unwrap();
1667+
let mut fldvec: Vec<Field> = Vec::with_capacity(inner_keys.len() as usize);
1668+
1669+
inner_keys.iter().for_each(|key| {
1670+
let inner_val = inner_val.get::<_, napi::JsUnknown>(&key).unwrap();
1671+
let dtype = match inner_val.as_ref().unwrap().get_type().unwrap() {
1672+
ValueType::Boolean => DataType::Boolean,
1673+
ValueType::Number => DataType::Float64,
1674+
ValueType::BigInt => DataType::UInt64,
1675+
ValueType::String => DataType::String,
1676+
// determine struct type using a recursive func
1677+
ValueType::Object => obj_to_type(inner_val),
1678+
_ => DataType::Null,
1679+
};
1680+
1681+
let fld = Field::new(key.into(), dtype);
1682+
fldvec.push(fld);
1683+
});
1684+
DataType::Struct(fldvec)
1685+
}
1686+
}
1687+
_ => DataType::Null,
1688+
}
1689+
}
1690+
None => DataType::Null,
1691+
}
1692+
}
1693+
1694+
fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult<AnyValue<'a>> {
16781695
use DataType::*;
16791696
let vtype = val.get_type().unwrap();
16801697
match (vtype, dtype) {
@@ -1749,17 +1766,59 @@ unsafe fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult<An
17491766
}
17501767
(ValueType::Object, DataType::Datetime(_, _)) => {
17511768
if val.is_date()? {
1752-
let d: napi::JsDate = val.cast();
1769+
let d: napi::JsDate = unsafe { val.cast() };
17531770
let d = d.value_of()?;
17541771
Ok(AnyValue::Datetime(d as i64, TimeUnit::Milliseconds, None))
17551772
} else {
17561773
Ok(AnyValue::Null)
17571774
}
17581775
}
17591776
(ValueType::Object, DataType::List(_)) => {
1760-
let s = val.to_series();
1777+
let s = unsafe { val.to_series() };
17611778
Ok(AnyValue::List(s))
17621779
}
1780+
(ValueType::Object, DataType::Struct(fields)) => {
1781+
let number_of_fields: i8 = fields.len().try_into().map_err(|e| {
1782+
napi::Error::from_reason(format!(
1783+
"the number of `fields` cannot be larger than i8::MAX {e:?}"
1784+
))
1785+
})?;
1786+
1787+
let inner_val: napi::JsObject = unsafe { val.cast() };
1788+
let mut val_vec: Vec<polars::prelude::AnyValue<'_>> =
1789+
Vec::with_capacity(number_of_fields as usize);
1790+
fields.iter().for_each(|fld| {
1791+
let single_val = inner_val
1792+
.get::<_, napi::JsUnknown>(&fld.name)
1793+
.unwrap()
1794+
.unwrap();
1795+
let vv = match &fld.dtype {
1796+
DataType::Boolean => {
1797+
AnyValue::Boolean(single_val.coerce_to_bool().unwrap().get_value().unwrap())
1798+
}
1799+
DataType::String => AnyValue::from_js(single_val).expect("Expecting string"),
1800+
DataType::Int16 => AnyValue::Int16(
1801+
single_val.coerce_to_number().unwrap().get_int32().unwrap() as i16,
1802+
),
1803+
DataType::Int32 => {
1804+
AnyValue::Int32(single_val.coerce_to_number().unwrap().get_int32().unwrap())
1805+
}
1806+
DataType::Int64 => {
1807+
AnyValue::Int64(single_val.coerce_to_number().unwrap().get_int64().unwrap())
1808+
}
1809+
DataType::Float64 => AnyValue::Float64(
1810+
single_val.coerce_to_number().unwrap().get_double().unwrap(),
1811+
),
1812+
DataType::Struct(_) => {
1813+
coerce_js_anyvalue(single_val, fld.dtype.clone()).unwrap()
1814+
}
1815+
_ => AnyValue::Null,
1816+
};
1817+
val_vec.push(vv);
1818+
});
1819+
1820+
Ok(AnyValue::StructOwned(Box::new((val_vec, fields))))
1821+
}
17631822
_ => Ok(AnyValue::Null),
17641823
}
17651824
}

0 commit comments

Comments
 (0)