@@ -41,6 +41,8 @@ cdef class patch:
41
41
functions_dict = {}
42
42
43
43
def __cinit__ (self ):
44
+ cdef int pi, oi
45
+
44
46
umaths = [i for i in dir (mu) if isinstance (getattr (mu, i), np.ufunc)]
45
47
self .functions_count = 0
46
48
for umath in umaths:
@@ -51,23 +53,31 @@ cdef class patch:
51
53
52
54
func_number = 0
53
55
for umath in umaths:
54
- mkl_umath = getattr (mu, umath)
55
- np_umath = getattr (nu, umath)
56
- c_mkl_umath = < cnp.ufunc> mkl_umath
57
- c_np_umath = < cnp.ufunc> np_umath
58
- for type in mkl_umath.types:
59
- np_index = np_umath.types.index(type )
60
- self .functions[func_number].original_function = c_np_umath.functions[np_index]
61
- mkl_index = mkl_umath.types.index(type )
62
- self .functions[func_number].patch_function = c_mkl_umath.functions[mkl_index]
63
-
64
- nargs = c_mkl_umath.nargs
65
- self .functions[func_number].signature = < int * > malloc(nargs * sizeof(int ))
66
- for i in range (nargs):
67
- self .functions[func_number].signature[i] = c_mkl_umath.types[mkl_index* nargs + i]
68
-
69
- self .functions_dict[(umath, type )] = func_number
70
- func_number = func_number + 1
56
+ patch_umath = getattr (mu, umath)
57
+ c_patch_umath = < cnp.ufunc> patch_umath
58
+ c_orig_umath = < cnp.ufunc> getattr (nu, umath)
59
+ nargs = c_patch_umath.nargs
60
+ for pi in range (c_patch_umath.ntypes):
61
+ oi = 0
62
+ while oi < c_orig_umath.ntypes:
63
+ found = True
64
+ for i in range (c_patch_umath.nargs):
65
+ if c_patch_umath.types[pi * nargs + i] != c_orig_umath.types[oi * nargs + i]:
66
+ found = False
67
+ break
68
+ if found == True :
69
+ break
70
+ oi = oi + 1
71
+ if oi < c_orig_umath.ntypes:
72
+ self .functions[func_number].original_function = c_orig_umath.functions[oi]
73
+ self .functions[func_number].patch_function = c_patch_umath.functions[pi]
74
+ self .functions[func_number].signature = < int * > malloc(nargs * sizeof(int ))
75
+ for i in range (nargs):
76
+ self .functions[func_number].signature[i] = c_patch_umath.types[pi * nargs + i]
77
+ self .functions_dict[(umath, patch_umath.types[pi])] = func_number
78
+ func_number = func_number + 1
79
+ else :
80
+ raise RuntimeError (" Unable to find original function for: " + umath + " " + patch_umath.types[pi])
71
81
72
82
def __dealloc__ (self ):
73
83
for i in range (self .functions_count):
0 commit comments