Skip to content

Commit 32acc61

Browse files
search ufunc by signature
1 parent 54a2ace commit 32acc61

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

mkl_umath/src/patch.pyx

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ cdef class patch:
4141
functions_dict = {}
4242

4343
def __cinit__(self):
44+
cdef int pi, oi
45+
4446
umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)]
4547
self.functions_count = 0
4648
for umath in umaths:
@@ -51,23 +53,31 @@ cdef class patch:
5153

5254
func_number = 0
5355
for umath in umaths:
54-
mkl_umath = getattr(mu, umath)
55-
np_umath = getattr(nu, umath)
56-
c_mkl_umath = <cnp.ufunc>mkl_umath
57-
c_np_umath = <cnp.ufunc>np_umath
58-
for type in mkl_umath.types:
59-
np_index = np_umath.types.index(type)
60-
self.functions[func_number].original_function = c_np_umath.functions[np_index]
61-
mkl_index = mkl_umath.types.index(type)
62-
self.functions[func_number].patch_function = c_mkl_umath.functions[mkl_index]
63-
64-
nargs = c_mkl_umath.nargs
65-
self.functions[func_number].signature = <int *> malloc(nargs * sizeof(int))
66-
for i in range(nargs):
67-
self.functions[func_number].signature[i] = c_mkl_umath.types[mkl_index*nargs + i]
68-
69-
self.functions_dict[(umath, type)] = func_number
70-
func_number = func_number + 1
56+
patch_umath = getattr(mu, umath)
57+
c_patch_umath = <cnp.ufunc>patch_umath
58+
c_orig_umath = <cnp.ufunc>getattr(nu, umath)
59+
nargs = c_patch_umath.nargs
60+
for pi in range(c_patch_umath.ntypes):
61+
oi = 0
62+
while oi < c_orig_umath.ntypes:
63+
found = True
64+
for i in range(c_patch_umath.nargs):
65+
if c_patch_umath.types[pi * nargs + i] != c_orig_umath.types[oi * nargs + i]:
66+
found = False
67+
break
68+
if found == True:
69+
break
70+
oi = oi + 1
71+
if oi < c_orig_umath.ntypes:
72+
self.functions[func_number].original_function = c_orig_umath.functions[oi]
73+
self.functions[func_number].patch_function = c_patch_umath.functions[pi]
74+
self.functions[func_number].signature = <int *> malloc(nargs * sizeof(int))
75+
for i in range(nargs):
76+
self.functions[func_number].signature[i] = c_patch_umath.types[pi * nargs + i]
77+
self.functions_dict[(umath, patch_umath.types[pi])] = func_number
78+
func_number = func_number + 1
79+
else:
80+
raise RuntimeError("Unable to find original function for: " + umath + " " + patch_umath.types[pi])
7181

7282
def __dealloc__(self):
7383
for i in range(self.functions_count):

0 commit comments

Comments
 (0)