@@ -37,12 +37,15 @@ ctypedef struct function_info:
37
37
cdef class patch:
38
38
cdef int functions_count
39
39
cdef function_info* functions
40
+ cdef bint _is_patched
40
41
41
- functions_dict = {}
42
+ functions_dict = dict ()
42
43
43
44
def __cinit__ (self ):
44
45
cdef int pi, oi
45
46
47
+ self ._is_patched = False
48
+
46
49
umaths = [i for i in dir (mu) if isinstance (getattr (mu, i), np.ufunc)]
47
50
self .functions_count = 0
48
51
for umath in umaths:
@@ -95,7 +98,9 @@ cdef class patch:
95
98
index = self .functions_dict[func]
96
99
function = self .functions[index].patch_function
97
100
signature = self .functions[index].signature
98
- res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, & temp)
101
+ res = cnp.PyUFunc_ReplaceLoopBySignature(< cnp.ufunc> np_umath, function, signature, & temp)
102
+
103
+ self ._is_patched = True
99
104
100
105
def do_unpatch (self ):
101
106
cdef int res
@@ -110,6 +115,10 @@ cdef class patch:
110
115
signature = self .functions[index].signature
111
116
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, & temp)
112
117
118
+ self ._is_patched = False
119
+
120
+ def is_patched (self ):
121
+ return self ._is_patched
113
122
114
123
from threading import local as threading_local
115
124
_tls = threading_local()
@@ -123,14 +132,40 @@ def _initialize_tls():
123
132
_tls.patch = patch()
124
133
_tls.initialized = True
125
134
126
-
127
- def do_patch ():
135
+
136
+ def use_in_numpy ():
137
+ '''
138
+ Enables using of mkl_umath in Numpy.
139
+ '''
128
140
if not _is_tls_initialized():
129
141
_initialize_tls()
130
142
_tls.patch.do_patch()
131
143
132
144
133
- def do_unpatch ():
145
+ def restore ():
146
+ '''
147
+ Disables using of mkl_umath in Numpy.
148
+ '''
134
149
if not _is_tls_initialized():
135
150
_initialize_tls()
136
151
_tls.patch.do_unpatch()
152
+
153
+
154
+ def is_patched ():
155
+ '''
156
+ Returns whether Numpy has been patched with mkl_umath.
157
+ '''
158
+ if not _is_tls_initialized():
159
+ _initialize_tls()
160
+ _tls.patch.is_patched()
161
+
162
+ from contextlib import ContextDecorator
163
+
164
+ class mkl_umath (ContextDecorator ):
165
+ def __enter__ (self ):
166
+ use_in_numpy()
167
+ return self
168
+
169
+ def __exit__ (self , *exc ):
170
+ restore()
171
+ return False
0 commit comments