@@ -25,35 +25,100 @@ import numpy as np
25
25
26
26
from libc.stdlib cimport malloc, free
27
27
28
- cimport cpython.pycapsule
29
-
30
28
cnp.import_umath()
31
29
30
+
32
31
ctypedef struct function_info:
33
- cnp.PyUFuncGenericFunction np_function
34
- cnp.PyUFuncGenericFunction mkl_function
32
+ cnp.PyUFuncGenericFunction original_function
33
+ cnp.PyUFuncGenericFunction patch_function
35
34
int * signature
36
35
37
- ctypedef struct functions_struct:
38
- int count
39
- function_info* functions
40
-
41
-
42
- cdef const char * capsule_name = " functions_cache"
43
-
44
-
45
- cdef void _capsule_destructor(object caps):
46
- cdef functions_struct* fs
47
-
48
- if (caps is None ):
49
- print (" Nothing to destroy" )
50
- return
51
- fs = < functions_struct * > cpython.pycapsule.PyCapsule_GetPointer(caps, capsule_name)
52
- for i in range (fs[0 ].count):
53
- free(fs[0 ].functions[i].signature)
54
- free(fs[0 ].functions)
55
- free(fs)
56
36
37
+ cdef class patch:
38
+ cdef int functions_count
39
+ cdef function_info* functions
40
+ cdef bint _is_patched
41
+
42
+ functions_dict = dict ()
43
+
44
+ def __cinit__ (self ):
45
+ cdef int pi, oi
46
+
47
+ self ._is_patched = False
48
+
49
+ umaths = [i for i in dir (mu) if isinstance (getattr (mu, i), np.ufunc)]
50
+ self .functions_count = 0
51
+ for umath in umaths:
52
+ mkl_umath = getattr (mu, umath)
53
+ self .functions_count = self .functions_count + mkl_umath.ntypes
54
+
55
+ self .functions = < function_info * > malloc(self .functions_count * sizeof(function_info))
56
+
57
+ func_number = 0
58
+ for umath in umaths:
59
+ patch_umath = getattr (mu, umath)
60
+ c_patch_umath = < cnp.ufunc> patch_umath
61
+ c_orig_umath = < cnp.ufunc> getattr (nu, umath)
62
+ nargs = c_patch_umath.nargs
63
+ for pi in range (c_patch_umath.ntypes):
64
+ oi = 0
65
+ while oi < c_orig_umath.ntypes:
66
+ found = True
67
+ for i in range (c_patch_umath.nargs):
68
+ if c_patch_umath.types[pi * nargs + i] != c_orig_umath.types[oi * nargs + i]:
69
+ found = False
70
+ break
71
+ if found == True :
72
+ break
73
+ oi = oi + 1
74
+ if oi < c_orig_umath.ntypes:
75
+ self .functions[func_number].original_function = c_orig_umath.functions[oi]
76
+ self .functions[func_number].patch_function = c_patch_umath.functions[pi]
77
+ self .functions[func_number].signature = < int * > malloc(nargs * sizeof(int ))
78
+ for i in range (nargs):
79
+ self .functions[func_number].signature[i] = c_patch_umath.types[pi * nargs + i]
80
+ self .functions_dict[(umath, patch_umath.types[pi])] = func_number
81
+ func_number = func_number + 1
82
+ else :
83
+ raise RuntimeError (" Unable to find original function for: " + umath + " " + patch_umath.types[pi])
84
+
85
+ def __dealloc__ (self ):
86
+ for i in range (self .functions_count):
87
+ free(self .functions[i].signature)
88
+ free(self .functions)
89
+
90
+ def do_patch (self ):
91
+ cdef int res
92
+ cdef cnp.PyUFuncGenericFunction temp
93
+ cdef cnp.PyUFuncGenericFunction function
94
+ cdef int * signature
95
+
96
+ for func in self .functions_dict:
97
+ np_umath = getattr (nu, func[0 ])
98
+ index = self .functions_dict[func]
99
+ function = self .functions[index].patch_function
100
+ signature = self .functions[index].signature
101
+ res = cnp.PyUFunc_ReplaceLoopBySignature(< cnp.ufunc> np_umath, function, signature, & temp)
102
+
103
+ self ._is_patched = True
104
+
105
+ def do_unpatch (self ):
106
+ cdef int res
107
+ cdef cnp.PyUFuncGenericFunction temp
108
+ cdef cnp.PyUFuncGenericFunction function
109
+ cdef int * signature
110
+
111
+ for func in self .functions_dict:
112
+ np_umath = getattr (nu, func[0 ])
113
+ index = self .functions_dict[func]
114
+ function = self .functions[index].original_function
115
+ signature = self .functions[index].signature
116
+ res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, & temp)
117
+
118
+ self ._is_patched = False
119
+
120
+ def is_patched (self ):
121
+ return self ._is_patched
57
122
58
123
from threading import local as threading_local
59
124
_tls = threading_local()
@@ -64,103 +129,43 @@ def _is_tls_initialized():
64
129
65
130
66
131
def _initialize_tls ():
67
- cdef functions_struct* fs
68
- cdef int funcs_count
69
-
70
- _tls.functions_dict = {}
71
-
72
- umaths = [i for i in dir (mu) if isinstance (getattr (mu, i), np.ufunc)]
73
- funcs_count = 0
74
- for umath in umaths:
75
- mkl_umath = getattr (mu, umath)
76
- funcs_count = funcs_count + mkl_umath.ntypes
77
-
78
- fs = < functions_struct * > malloc(sizeof(functions_struct))
79
- fs[0 ].count = funcs_count
80
- fs[0 ].functions = < function_info * > malloc(funcs_count * sizeof(function_info))
81
-
82
- func_number = 0
83
- for umath in umaths:
84
- mkl_umath = getattr (mu, umath)
85
- np_umath = getattr (nu, umath)
86
- c_mkl_umath = < cnp.ufunc> mkl_umath
87
- c_np_umath = < cnp.ufunc> np_umath
88
- for type in mkl_umath.types:
89
- np_index = np_umath.types.index(type )
90
- fs[0 ].functions[func_number].np_function = c_np_umath.functions[np_index]
91
- mkl_index = mkl_umath.types.index(type )
92
- fs[0 ].functions[func_number].mkl_function = c_mkl_umath.functions[mkl_index]
93
-
94
- nargs = c_mkl_umath.nargs
95
- fs[0 ].functions[func_number].signature = < int * > malloc(nargs * sizeof(int ))
96
- for i in range (nargs):
97
- fs[0 ].functions[func_number].signature[i] = c_mkl_umath.types[mkl_index* nargs + i]
98
-
99
- _tls.functions_dict[(umath, type )] = func_number
100
- func_number = func_number + 1
101
-
102
- _tls.functions_capsule = cpython.pycapsule.PyCapsule_New(< void * > fs, capsule_name, & _capsule_destructor)
103
-
132
+ _tls.patch = patch()
104
133
_tls.initialized = True
105
134
106
135
107
- def _get_func_dict ():
136
+ def use_in_numpy ():
137
+ '''
138
+ Enables using of mkl_umath in Numpy.
139
+ '''
108
140
if not _is_tls_initialized():
109
141
_initialize_tls()
110
- return _tls.functions_dict
142
+ _tls.patch.do_patch()
111
143
112
144
113
- cdef function_info * _get_functions ():
114
- cdef function_info * functions
115
- cdef functions_struct * fs
116
-
145
+ def restore ():
146
+ '''
147
+ Disables using of mkl_umath in Numpy.
148
+ '''
117
149
if not _is_tls_initialized():
118
150
_initialize_tls()
151
+ _tls.patch.do_unpatch()
119
152
120
- capsule = _tls.functions_capsule
121
- if (not cpython.pycapsule.PyCapsule_IsValid(capsule, capsule_name)):
122
- raise ValueError (" Internal Error: invalid capsule stored in TLS" )
123
- fs = < functions_struct * > cpython.pycapsule.PyCapsule_GetPointer(capsule, capsule_name)
124
- return fs[0 ].functions
125
-
126
-
127
- cdef void c_do_patch():
128
- cdef int res
129
- cdef cnp.PyUFuncGenericFunction temp
130
- cdef cnp.PyUFuncGenericFunction function
131
- cdef int * signature
132
-
133
- funcs_dict = _get_func_dict()
134
- functions = _get_functions()
135
-
136
- for func in funcs_dict:
137
- np_umath = getattr (nu, func[0 ])
138
- index = funcs_dict[func]
139
- function = functions[index].mkl_function
140
- signature = functions[index].signature
141
- res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, & temp)
142
-
143
-
144
- cdef void c_do_unpatch():
145
- cdef int res
146
- cdef cnp.PyUFuncGenericFunction temp
147
- cdef cnp.PyUFuncGenericFunction function
148
- cdef int * signature
149
-
150
- funcs_dict = _get_func_dict()
151
- functions = _get_functions()
152
-
153
- for func in funcs_dict:
154
- np_umath = getattr (nu, func[0 ])
155
- index = funcs_dict[func]
156
- function = functions[index].np_function
157
- signature = functions[index].signature
158
- res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, & temp)
159
153
154
+ def is_patched ():
155
+ '''
156
+ Returns whether Numpy has been patched with mkl_umath.
157
+ '''
158
+ if not _is_tls_initialized():
159
+ _initialize_tls()
160
+ _tls.patch.is_patched()
160
161
161
- def do_patch ():
162
- c_do_patch()
162
+ from contextlib import ContextDecorator
163
163
164
+ class mkl_umath (ContextDecorator ):
165
+ def __enter__ (self ):
166
+ use_in_numpy()
167
+ return self
164
168
165
- def do_unpatch ():
166
- c_do_unpatch()
169
+ def __exit__ (self , *exc ):
170
+ restore()
171
+ return False
0 commit comments