@@ -25,34 +25,80 @@ import numpy as np
25
25
26
26
from libc.stdlib cimport malloc, free
27
27
28
- cimport cpython.pycapsule
29
-
30
28
cnp.import_umath()
31
29
30
+
32
31
# One record per (ufunc, type signature) pair handled by the patching code.
ctypedef struct function_info:
    # Inner loop to restore when unpatching.
    cnp.PyUFuncGenericFunction original_function
    # Inner loop to install when patching.
    cnp.PyUFuncGenericFunction patch_function
    # malloc'ed array of type numbers identifying the loop's signature.
    int* signature
36
35
37
- ctypedef struct functions_struct:
38
- int count
39
- function_info* functions
40
-
41
-
42
- cdef const char * capsule_name = " functions_cache"
43
36
37
cdef class patch:
    """Table of ufunc inner loops used to swap ``nu`` loops for ``mu`` ones.

    On construction every ``np.ufunc`` attribute of ``mu`` is scanned and, for
    each of its type signatures, three things are recorded: the loop found on
    the same-named ufunc of ``nu``, the loop taken from ``mu``, and the
    type-number signature.  ``do_patch`` installs the ``mu`` loops into the
    ``nu`` ufuncs; ``do_unpatch`` puts the recorded ``nu`` loops back.
    """
    # Number of fully set-up slots in ``functions``; only published after the
    # table has been allocated and its pointers cleared.
    cdef int functions_count
    # malloc'ed array of ``functions_count`` records (NULL until allocated).
    cdef function_info* functions

    # Maps (ufunc name, type string) -> index into ``functions``.
    # NOTE(review): this is a class-level dict, so it is shared by every
    # instance; kept as-is for interface compatibility.
    functions_dict = {}

    def __cinit__(self):
        cdef cnp.ufunc c_mkl_umath
        cdef cnp.ufunc c_np_umath
        cdef int total = 0
        cdef int func_number = 0
        cdef int nargs, np_index, mkl_index, i

        umaths = [name for name in dir(mu) if isinstance(getattr(mu, name), np.ufunc)]
        for umath in umaths:
            total += getattr(mu, umath).ntypes

        self.functions = <function_info *> malloc(total * sizeof(function_info))
        # malloc(0) may legitimately return NULL, so only treat NULL as an
        # error when something was actually requested.
        if total > 0 and self.functions == NULL:
            raise MemoryError()

        # Clear every signature pointer up front so that __dealloc__ is safe
        # even if an exception interrupts the filling loop below.
        for i in range(total):
            self.functions[i].signature = NULL
        self.functions_count = total

        for umath in umaths:
            mkl_umath = getattr(mu, umath)
            np_umath = getattr(nu, umath)
            c_mkl_umath = <cnp.ufunc> mkl_umath
            c_np_umath = <cnp.ufunc> np_umath
            nargs = c_mkl_umath.nargs
            for type_sig in mkl_umath.types:
                # Raises ValueError if ``nu`` has no loop for this signature.
                np_index = np_umath.types.index(type_sig)
                mkl_index = mkl_umath.types.index(type_sig)

                self.functions[func_number].original_function = c_np_umath.functions[np_index]
                self.functions[func_number].patch_function = c_mkl_umath.functions[mkl_index]

                self.functions[func_number].signature = <int *> malloc(nargs * sizeof(int))
                if self.functions[func_number].signature == NULL:
                    raise MemoryError()
                # The ufunc stores its type numbers flat, nargs per loop.
                for i in range(nargs):
                    self.functions[func_number].signature[i] = c_mkl_umath.types[mkl_index * nargs + i]

                self.functions_dict[(umath, type_sig)] = func_number
                func_number += 1

    def __dealloc__(self):
        cdef int i

        # The table may be missing if __cinit__ failed before allocating it.
        if self.functions == NULL:
            return
        for i in range(self.functions_count):
            free(self.functions[i].signature)  # free(NULL) is a no-op
        free(self.functions)
        self.functions = NULL

    cdef int _replace_loops(self, bint use_patch) except -1:
        """Install either the patch loops or the original loops into ``nu``."""
        cdef cnp.PyUFuncGenericFunction previous
        cdef cnp.PyUFuncGenericFunction loop
        cdef Py_ssize_t index

        for key in self.functions_dict:
            np_umath = getattr(nu, key[0])
            index = self.functions_dict[key]
            if use_patch:
                loop = self.functions[index].patch_function
            else:
                loop = self.functions[index].original_function
            # NOTE(review): the status returned by the call is ignored, as in
            # the original code, so a signature that cannot be replaced is
            # silently skipped -- confirm this best-effort behaviour is wanted.
            cnp.PyUFunc_ReplaceLoopBySignature(
                np_umath, loop, self.functions[index].signature, &previous)
        return 0

    def do_patch(self):
        """Replace the ``nu`` ufunc loops with the recorded ``mu`` loops."""
        self._replace_loops(True)

    def do_unpatch(self):
        """Put the recorded original loops back into the ``nu`` ufuncs."""
        self._replace_loops(False)
56
102
57
103
58
104
from threading import local as threading_local
@@ -64,103 +110,17 @@ def _is_tls_initialized():
64
110
65
111
66
112
def _initialize_tls():
    """Create this thread's ``patch`` object and mark ``_tls`` as initialized."""
    _tls.patch = patch()
    _tls.initialized = True
105
115
106
-
107
- def _get_func_dict ():
116
+
117
def do_patch():
    """Apply the patch held in ``_tls``, initializing ``_tls`` first if needed."""
    if not _is_tls_initialized():
        _initialize_tls()
    _tls.patch.do_patch()
111
121
112
122
113
- cdef function_info* _get_functions():
114
- cdef function_info* functions
115
- cdef functions_struct* fs
116
-
123
def do_unpatch():
    """Undo the patch held in ``_tls``, initializing ``_tls`` first if needed."""
    if not _is_tls_initialized():
        _initialize_tls()
    _tls.patch.do_unpatch()
0 commit comments