Skip to content

Commit 5d73c43

Browse files
Alexander-Makaryev authored and GitHub Enterprise committed
Merge pull request #3 from SAT/patching-update
Patching update
2 parents 34c7a8f + e25738b commit 5d73c43

File tree

3 files changed

+127
-113
lines changed

3 files changed

+127
-113
lines changed

mkl_umath/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
'''
2+
Implementation of Numpy universal math functions using Intel(R) MKL and Intel(R) C compiler runtime.
3+
'''
4+
5+
from ._version import __version__
6+
7+
from ._ufuncs import *
8+
9+
from ._patch import mkl_umath, use_in_numpy, restore, is_patched

mkl_umath/setup.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ def generate_umath_c(ext, build_dir):
151151
libraries = mkl_libraries + ['loops_intel'],
152152
library_dirs = mkl_library_dirs,
153153
extra_compile_args = [
154-
# '-DNDEBUG',
155-
'-ggdb', '-O0', '-Wall', '-Wextra', '-DDEBUG',
154+
'-DNDEBUG',
155+
# '-ggdb', '-O0', '-Wall', '-Wextra', '-DDEBUG',
156156
]
157157
)
158158

@@ -168,8 +168,8 @@ def generate_umath_c(ext, build_dir):
168168
libraries = mkl_libraries + ['loops_intel'],
169169
library_dirs = mkl_library_dirs,
170170
extra_compile_args = [
171-
# '-DNDEBUG',
172-
'-ggdb', '-O0', '-Wall', '-Wextra', '-DDEBUG',
171+
'-DNDEBUG',
172+
#'-ggdb', '-O0', '-Wall', '-Wextra', '-DDEBUG',
173173
]
174174
)
175175

mkl_umath/src/patch.pyx

Lines changed: 114 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -25,35 +25,100 @@ import numpy as np
2525

2626
from libc.stdlib cimport malloc, free
2727

28-
cimport cpython.pycapsule
29-
3028
cnp.import_umath()
3129

30+
3231
ctypedef struct function_info:
33-
cnp.PyUFuncGenericFunction np_function
34-
cnp.PyUFuncGenericFunction mkl_function
32+
cnp.PyUFuncGenericFunction original_function
33+
cnp.PyUFuncGenericFunction patch_function
3534
int* signature
3635

37-
ctypedef struct functions_struct:
38-
int count
39-
function_info* functions
40-
41-
42-
cdef const char *capsule_name = "functions_cache"
43-
44-
45-
cdef void _capsule_destructor(object caps):
46-
cdef functions_struct* fs
47-
48-
if (caps is None):
49-
print("Nothing to destroy")
50-
return
51-
fs = <functions_struct *>cpython.pycapsule.PyCapsule_GetPointer(caps, capsule_name)
52-
for i in range(fs[0].count):
53-
free(fs[0].functions[i].signature)
54-
free(fs[0].functions)
55-
free(fs)
5636

37+
cdef class patch:
    '''
    Registry of the ufunc inner loops that can be swapped between ``nu``
    and ``mu``.

    On construction every ufunc found in ``mu`` is matched, type signature
    by type signature, against the ufunc with the same name in ``nu``.  For
    each match the loop currently installed in the ``nu`` ufunc, the loop
    from the ``mu`` ufunc and the integer type signature are stored, so that
    do_patch() / do_unpatch() can exchange them through
    PyUFunc_ReplaceLoopBySignature.

    NOTE: the "original" loops are read from ``nu`` at construction time.
    An instance created while the ``nu`` loops are already replaced will
    therefore record the replacement loops as the originals.
    '''
    # Number of fully initialised entries in `functions`.
    cdef int functions_count
    # malloc'ed array of loop records; every entry owns its `signature`.
    cdef function_info* functions
    # True after do_patch(), False after do_unpatch().
    cdef bint _is_patched

    # Maps (ufunc name, type string from ufunc.types) to an index into
    # `functions`.  This is a class-level attribute, i.e. one dict shared by
    # all instances.
    functions_dict = dict()

    def __cinit__(self):
        cdef int pi, oi, i, nargs, total, index
        cdef bint found
        cdef int* signature
        cdef cnp.ufunc c_patch_umath
        cdef cnp.ufunc c_orig_umath

        self._is_patched = False
        # Cython runs __dealloc__ even when __cinit__ raises, so the object
        # must be in a state that is safe to free at every point below:
        # `functions_count` only ever counts fully initialised entries.
        self.functions = NULL
        self.functions_count = 0

        umaths = [name for name in dir(mu) if isinstance(getattr(mu, name), np.ufunc)]

        # Upper bound on the number of records: one per loop of every ufunc.
        total = 0
        for umath in umaths:
            total += getattr(mu, umath).ntypes

        self.functions = <function_info *> malloc(total * sizeof(function_info))
        if self.functions == NULL and total > 0:
            raise MemoryError()

        for umath in umaths:
            patch_umath = getattr(mu, umath)
            c_patch_umath = <cnp.ufunc>patch_umath
            c_orig_umath = <cnp.ufunc>getattr(nu, umath)
            nargs = c_patch_umath.nargs
            for pi in range(c_patch_umath.ntypes):
                # Linear search for the loop of the original ufunc whose
                # type codes equal those of replacement loop `pi`.
                oi = 0
                while oi < c_orig_umath.ntypes:
                    found = True
                    for i in range(nargs):
                        if c_patch_umath.types[pi * nargs + i] != c_orig_umath.types[oi * nargs + i]:
                            found = False
                            break
                    if found:
                        break
                    oi += 1
                if oi >= c_orig_umath.ntypes:
                    raise RuntimeError("Unable to find original function for: " + umath + " " + patch_umath.types[pi])

                signature = <int *> malloc(nargs * sizeof(int))
                if signature == NULL:
                    raise MemoryError()
                for i in range(nargs):
                    signature[i] = c_patch_umath.types[pi * nargs + i]

                index = self.functions_count
                self.functions[index].original_function = c_orig_umath.functions[oi]
                self.functions[index].patch_function = c_patch_umath.functions[pi]
                self.functions[index].signature = signature
                self.functions_dict[(umath, patch_umath.types[pi])] = index
                # Publish the entry only once it is complete.
                self.functions_count = index + 1

    def __dealloc__(self):
        cdef int i

        if self.functions != NULL:
            for i in range(self.functions_count):
                free(self.functions[i].signature)
            free(self.functions)
            self.functions = NULL

    def _swap_loops(self, bint use_patch):
        '''
        Install, for every registered signature, either the replacement loop
        (use_patch is true) or the recorded original loop (use_patch is false).

        Returns a list of "<ufunc name> <type string>" entries for which
        PyUFunc_ReplaceLoopBySignature reported a failure.
        '''
        cdef int res
        cdef int index
        cdef cnp.PyUFuncGenericFunction previous
        cdef cnp.PyUFuncGenericFunction replacement

        failed = []
        for key in self.functions_dict:
            np_umath = getattr(nu, key[0])
            index = self.functions_dict[key]
            if use_patch:
                replacement = self.functions[index].patch_function
            else:
                replacement = self.functions[index].original_function
            res = cnp.PyUFunc_ReplaceLoopBySignature(
                <cnp.ufunc>np_umath, replacement,
                self.functions[index].signature, &previous)
            if res != 0:
                failed.append(key[0] + " " + key[1])
        return failed

    def do_patch(self):
        '''
        Install the replacement loop for every registered signature.

        Raises RuntimeError if any loop could not be replaced; the remaining
        loops have been replaced by then and the patched flag is set.
        '''
        failed = self._swap_loops(True)
        self._is_patched = True
        if failed:
            raise RuntimeError("Unable to replace loops for: " + ", ".join(failed))

    def do_unpatch(self):
        '''
        Re-install the recorded original loop for every registered signature.

        Raises RuntimeError if any loop could not be restored; the remaining
        loops have been restored by then and the patched flag is cleared.
        '''
        failed = self._swap_loops(False)
        self._is_patched = False
        if failed:
            raise RuntimeError("Unable to restore loops for: " + ", ".join(failed))

    def is_patched(self):
        '''
        Return True if do_patch() was the last of do_patch()/do_unpatch()
        called on this instance.
        '''
        return self._is_patched
57122

58123
from threading import local as threading_local
59124
_tls = threading_local()
@@ -64,103 +129,43 @@ def _is_tls_initialized():
64129

65130

66131
def _initialize_tls():
67-
cdef functions_struct* fs
68-
cdef int funcs_count
69-
70-
_tls.functions_dict = {}
71-
72-
umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)]
73-
funcs_count = 0
74-
for umath in umaths:
75-
mkl_umath = getattr(mu, umath)
76-
funcs_count = funcs_count + mkl_umath.ntypes
77-
78-
fs = <functions_struct *> malloc(sizeof(functions_struct))
79-
fs[0].count = funcs_count
80-
fs[0].functions = <function_info *> malloc(funcs_count * sizeof(function_info))
81-
82-
func_number = 0
83-
for umath in umaths:
84-
mkl_umath = getattr(mu, umath)
85-
np_umath = getattr(nu, umath)
86-
c_mkl_umath = <cnp.ufunc>mkl_umath
87-
c_np_umath = <cnp.ufunc>np_umath
88-
for type in mkl_umath.types:
89-
np_index = np_umath.types.index(type)
90-
fs[0].functions[func_number].np_function = c_np_umath.functions[np_index]
91-
mkl_index = mkl_umath.types.index(type)
92-
fs[0].functions[func_number].mkl_function = c_mkl_umath.functions[mkl_index]
93-
94-
nargs = c_mkl_umath.nargs
95-
fs[0].functions[func_number].signature = <int *> malloc(nargs * sizeof(int))
96-
for i in range(nargs):
97-
fs[0].functions[func_number].signature[i] = c_mkl_umath.types[mkl_index*nargs + i]
98-
99-
_tls.functions_dict[(umath, type)] = func_number
100-
func_number = func_number + 1
101-
102-
_tls.functions_capsule = cpython.pycapsule.PyCapsule_New(<void *>fs, capsule_name, &_capsule_destructor)
103-
132+
_tls.patch = patch()
104133
_tls.initialized = True
105134

106135

107-
def _get_func_dict():
136+
def use_in_numpy():
137+
'''
138+
Enables using of mkl_umath in Numpy.
139+
'''
108140
if not _is_tls_initialized():
109141
_initialize_tls()
110-
return _tls.functions_dict
142+
_tls.patch.do_patch()
111143

112144

113-
cdef function_info* _get_functions():
114-
cdef function_info* functions
115-
cdef functions_struct* fs
116-
145+
def restore():
    '''
    Disables using of mkl_umath in Numpy.

    Creates the thread-local patch object first if this thread has not
    initialised it yet, then asks it to put the recorded original loops back.
    '''
    tls_ready = _is_tls_initialized()
    if not tls_ready:
        _initialize_tls()
    patcher = _tls.patch
    patcher.do_unpatch()
119152

120-
capsule = _tls.functions_capsule
121-
if (not cpython.pycapsule.PyCapsule_IsValid(capsule, capsule_name)):
122-
raise ValueError("Internal Error: invalid capsule stored in TLS")
123-
fs = <functions_struct *>cpython.pycapsule.PyCapsule_GetPointer(capsule, capsule_name)
124-
return fs[0].functions
125-
126-
127-
cdef void c_do_patch():
128-
cdef int res
129-
cdef cnp.PyUFuncGenericFunction temp
130-
cdef cnp.PyUFuncGenericFunction function
131-
cdef int* signature
132-
133-
funcs_dict = _get_func_dict()
134-
functions = _get_functions()
135-
136-
for func in funcs_dict:
137-
np_umath = getattr(nu, func[0])
138-
index = funcs_dict[func]
139-
function = functions[index].mkl_function
140-
signature = functions[index].signature
141-
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, &temp)
142-
143-
144-
cdef void c_do_unpatch():
145-
cdef int res
146-
cdef cnp.PyUFuncGenericFunction temp
147-
cdef cnp.PyUFuncGenericFunction function
148-
cdef int* signature
149-
150-
funcs_dict = _get_func_dict()
151-
functions = _get_functions()
152-
153-
for func in funcs_dict:
154-
np_umath = getattr(nu, func[0])
155-
index = funcs_dict[func]
156-
function = functions[index].np_function
157-
signature = functions[index].signature
158-
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, &temp)
159153

154+
def is_patched():
    '''
    Returns whether Numpy has been patched with mkl_umath.

    The answer comes from the thread-local patch object, which is created
    on first use.
    '''
    if not _is_tls_initialized():
        _initialize_tls()
    # The result must be returned; without `return` this function always
    # yields None, which callers would read as "not patched".
    return _tls.patch.is_patched()
160161

161-
def do_patch():
162-
c_do_patch()
162+
from contextlib import ContextDecorator
163163

164+
class mkl_umath(ContextDecorator):
165+
def __enter__(self):
    '''
    Switch NumPy over to mkl_umath on entry to the ``with`` block.

    Returns the context manager itself so that ``with mkl_umath() as m``
    binds this object.
    '''
    use_in_numpy()
    return self
164168

165-
def do_unpatch():
166-
c_do_unpatch()
169+
def __exit__(self, *exc):
    '''
    Call restore() when the ``with`` block ends, whether or not it raised.
    '''
    restore()
    # A false return value tells Python not to suppress an exception that
    # is propagating out of the block.
    return False

0 commit comments

Comments
 (0)