Skip to content

dt64 scalar lookups match dt64 unit #192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 15, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 14 additions & 22 deletions src/auto_map.c
Original file line number Diff line number Diff line change
@@ -74,20 +74,6 @@ typedef enum KeysArrayType{
KAT_DTas,
} KeysArrayType;

NPY_DATETIMEUNIT
dt_unit_from_array(PyArrayObject* a) {
// This is based on get_datetime_metadata_from_dtype in the NumPy source, but that function is private. This does not check that the dytpe is of the appropriate type.
PyArray_Descr* dt = PyArray_DESCR(a); // borrowed ref
PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyDataType_C_METADATA(dt))->meta);
return dma->base;
}

NPY_DATETIMEUNIT
dt_unit_from_scalar(PyDatetimeScalarObject* dts) {
// Based on convert_pyobject_to_datetime and related usage in datetime.c
PyArray_DatetimeMetaData* dma = &(dts->obmeta);
return dma->base;
}

KeysArrayType
at_to_kat(int array_t, PyArrayObject* a) {
@@ -123,7 +109,7 @@ at_to_kat(int array_t, PyArrayObject* a) {
return KAT_STRING;

case NPY_DATETIME: {
NPY_DATETIMEUNIT dtu = dt_unit_from_array(a);
NPY_DATETIMEUNIT dtu = AK_dt_unit_from_array(a);
switch (dtu) {
case NPY_FR_Y:
return KAT_DTY;
@@ -685,9 +671,6 @@ lookup_hash_obj(FAMObject *self, PyObject *key, Py_hash_t hash)
int result = -1;
Py_hash_t h = 0;

// AK_DEBUG_MSG_OBJ("lookup_hash_obj", key);
// TODO: if key is a dt64, we need to get the units and compare to units before doing PyObject_RichCompareBool

while (1) {
for (Py_ssize_t i = 0; i < SCAN; i++) {
h = table[table_pos].hash;
@@ -702,6 +685,16 @@ lookup_hash_obj(FAMObject *self, PyObject *key, Py_hash_t hash)
if (guess == key) { // Hit. Object ID comparison
return table_pos;
}

// if key is a dt64, only do PyObject_RichCompareBool if units match
if (PyArray_IsScalar(key, Datetime) && PyArray_IsScalar(guess, Datetime)) {
if (AK_dt_unit_from_scalar((PyDatetimeScalarObject *)key)
!= AK_dt_unit_from_scalar((PyDatetimeScalarObject *)guess)) {
table_pos++;
continue;
}
}

result = PyObject_RichCompareBool(guess, key, Py_EQ);
if (result < 0) { // Error.
return -1;
@@ -1030,10 +1023,9 @@ lookup_datetime(FAMObject *self, PyObject* key) {
if (PyArray_IsScalar(key, Datetime)) {
v = (npy_int64)PyArrayScalar_VAL(key, Datetime);
// if we observe a NAT, we skip unit checks
// AK_DEBUG_MSG_OBJ("dt64 value", PyLong_FromLongLong(v));

if (v != NPY_DATETIME_NAT) {
NPY_DATETIMEUNIT key_unit = dt_unit_from_scalar(
NPY_DATETIMEUNIT key_unit = AK_dt_unit_from_scalar(
(PyDatetimeScalarObject *)key);
if (!kat_is_datetime_unit(self->keys_array_type, key_unit)) {
return -1;
@@ -1872,7 +1864,7 @@ fam_get_all(FAMObject *self, PyObject *key) {
GET_ALL_FLEXIBLE(char, char_get_end_p, lookup_hash_string, string_to_hash, PyBytes_FromStringAndSize);
break;
case NPY_DATETIME: {
NPY_DATETIMEUNIT key_unit = dt_unit_from_array(key_array);
NPY_DATETIMEUNIT key_unit = AK_dt_unit_from_array(key_array);
if (!kat_is_datetime_unit(self->keys_array_type, key_unit)) {
PyErr_SetString(PyExc_KeyError, "datetime64 units do not match");
Py_DECREF(array);
@@ -2070,7 +2062,7 @@ fam_get_any(FAMObject *self, PyObject *key) {
GET_ANY_FLEXIBLE(char, char_get_end_p, lookup_hash_string, string_to_hash);
break;
case NPY_DATETIME: {
NPY_DATETIMEUNIT key_unit = dt_unit_from_array(key_array);
NPY_DATETIMEUNIT key_unit = AK_dt_unit_from_array(key_array);
if (!kat_is_datetime_unit(self->keys_array_type, key_unit)) {
return values;
}
9 changes: 0 additions & 9 deletions src/tri_map.c
Original file line number Diff line number Diff line change
@@ -11,15 +11,6 @@
# include "tri_map.h"
# include "utilities.h"

static inline NPY_DATETIMEUNIT
AK_dt_unit_from_array(PyArrayObject* a) {
// This is based on get_datetime_metadata_from_dtype in the NumPy source, but that function is private. This does not check that the dtype is of the appropriate type.
PyArray_Descr* dt = PyArray_DESCR(a); // borrowed ref
PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyDataType_C_METADATA(dt))->meta);
// PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyArray_DESCR(a)->c_metadata)->meta);
return dma->base;
}

typedef struct TriMapOne {
Py_ssize_t from; // signed
Py_ssize_t to;
17 changes: 17 additions & 0 deletions src/utilities.h
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
# define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION

# include "numpy/arrayobject.h"
# include "numpy/arrayscalars.h"

static const size_t UCS4_SIZE = sizeof(Py_UCS4);

@@ -318,4 +319,20 @@ AK_nonzero_1d(PyArrayObject* array) {
return final;
}

static inline NPY_DATETIMEUNIT
AK_dt_unit_from_array(PyArrayObject* a) {
// This is based on get_datetime_metadata_from_dtype in the NumPy source, but that function is private. This does not check that the dtype is of the appropriate type.
PyArray_Descr* dt = PyArray_DESCR(a); // borrowed ref
PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyDataType_C_METADATA(dt))->meta);
// PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyArray_DESCR(a)->c_metadata)->meta);
return dma->base;
}

static inline NPY_DATETIMEUNIT
AK_dt_unit_from_scalar(PyDatetimeScalarObject* dts) {
// Based on convert_pyobject_to_datetime and related usage in datetime.c
PyArray_DatetimeMetaData* dma = &(dts->obmeta);
return dma->base;
}

#endif /* ARRAYKIT_SRC_UTILITIES_H_ */