Skip to content

Commit 54a2ace

Browse files
implemented patch class
1 parent 34c7a8f commit 54a2ace

File tree

1 file changed

+73
-113
lines changed

1 file changed

+73
-113
lines changed

mkl_umath/src/patch.pyx

Lines changed: 73 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -25,34 +25,80 @@ import numpy as np
2525

2626
from libc.stdlib cimport malloc, free
2727

28-
cimport cpython.pycapsule
29-
3028
cnp.import_umath()
3129

30+
3231
ctypedef struct function_info:
33-
cnp.PyUFuncGenericFunction np_function
34-
cnp.PyUFuncGenericFunction mkl_function
32+
cnp.PyUFuncGenericFunction original_function
33+
cnp.PyUFuncGenericFunction patch_function
3534
int* signature
3635

37-
ctypedef struct functions_struct:
38-
int count
39-
function_info* functions
40-
41-
42-
cdef const char *capsule_name = "functions_cache"
4336

37+
cdef class patch:
38+
cdef int functions_count
39+
cdef function_info* functions
4440

45-
cdef void _capsule_destructor(object caps):
46-
cdef functions_struct* fs
47-
48-
if (caps is None):
49-
print("Nothing to destroy")
50-
return
51-
fs = <functions_struct *>cpython.pycapsule.PyCapsule_GetPointer(caps, capsule_name)
52-
for i in range(fs[0].count):
53-
free(fs[0].functions[i].signature)
54-
free(fs[0].functions)
55-
free(fs)
41+
functions_dict = {}
42+
43+
def __cinit__(self):
44+
umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)]
45+
self.functions_count = 0
46+
for umath in umaths:
47+
mkl_umath = getattr(mu, umath)
48+
self.functions_count = self.functions_count + mkl_umath.ntypes
49+
50+
self.functions = <function_info *> malloc(self.functions_count * sizeof(function_info))
51+
52+
func_number = 0
53+
for umath in umaths:
54+
mkl_umath = getattr(mu, umath)
55+
np_umath = getattr(nu, umath)
56+
c_mkl_umath = <cnp.ufunc>mkl_umath
57+
c_np_umath = <cnp.ufunc>np_umath
58+
for type in mkl_umath.types:
59+
np_index = np_umath.types.index(type)
60+
self.functions[func_number].original_function = c_np_umath.functions[np_index]
61+
mkl_index = mkl_umath.types.index(type)
62+
self.functions[func_number].patch_function = c_mkl_umath.functions[mkl_index]
63+
64+
nargs = c_mkl_umath.nargs
65+
self.functions[func_number].signature = <int *> malloc(nargs * sizeof(int))
66+
for i in range(nargs):
67+
self.functions[func_number].signature[i] = c_mkl_umath.types[mkl_index*nargs + i]
68+
69+
self.functions_dict[(umath, type)] = func_number
70+
func_number = func_number + 1
71+
72+
def __dealloc__(self):
73+
for i in range(self.functions_count):
74+
free(self.functions[i].signature)
75+
free(self.functions)
76+
77+
def do_patch(self):
78+
cdef int res
79+
cdef cnp.PyUFuncGenericFunction temp
80+
cdef cnp.PyUFuncGenericFunction function
81+
cdef int* signature
82+
83+
for func in self.functions_dict:
84+
np_umath = getattr(nu, func[0])
85+
index = self.functions_dict[func]
86+
function = self.functions[index].patch_function
87+
signature = self.functions[index].signature
88+
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, &temp)
89+
90+
def do_unpatch(self):
91+
cdef int res
92+
cdef cnp.PyUFuncGenericFunction temp
93+
cdef cnp.PyUFuncGenericFunction function
94+
cdef int* signature
95+
96+
for func in self.functions_dict:
97+
np_umath = getattr(nu, func[0])
98+
index = self.functions_dict[func]
99+
function = self.functions[index].original_function
100+
signature = self.functions[index].signature
101+
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, &temp)
56102

57103

58104
from threading import local as threading_local
@@ -64,103 +110,17 @@ def _is_tls_initialized():
64110

65111

66112
def _initialize_tls():
67-
cdef functions_struct* fs
68-
cdef int funcs_count
69-
70-
_tls.functions_dict = {}
71-
72-
umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)]
73-
funcs_count = 0
74-
for umath in umaths:
75-
mkl_umath = getattr(mu, umath)
76-
funcs_count = funcs_count + mkl_umath.ntypes
77-
78-
fs = <functions_struct *> malloc(sizeof(functions_struct))
79-
fs[0].count = funcs_count
80-
fs[0].functions = <function_info *> malloc(funcs_count * sizeof(function_info))
81-
82-
func_number = 0
83-
for umath in umaths:
84-
mkl_umath = getattr(mu, umath)
85-
np_umath = getattr(nu, umath)
86-
c_mkl_umath = <cnp.ufunc>mkl_umath
87-
c_np_umath = <cnp.ufunc>np_umath
88-
for type in mkl_umath.types:
89-
np_index = np_umath.types.index(type)
90-
fs[0].functions[func_number].np_function = c_np_umath.functions[np_index]
91-
mkl_index = mkl_umath.types.index(type)
92-
fs[0].functions[func_number].mkl_function = c_mkl_umath.functions[mkl_index]
93-
94-
nargs = c_mkl_umath.nargs
95-
fs[0].functions[func_number].signature = <int *> malloc(nargs * sizeof(int))
96-
for i in range(nargs):
97-
fs[0].functions[func_number].signature[i] = c_mkl_umath.types[mkl_index*nargs + i]
98-
99-
_tls.functions_dict[(umath, type)] = func_number
100-
func_number = func_number + 1
101-
102-
_tls.functions_capsule = cpython.pycapsule.PyCapsule_New(<void *>fs, capsule_name, &_capsule_destructor)
103-
113+
_tls.patch = patch()
104114
_tls.initialized = True
105115

106-
107-
def _get_func_dict():
116+
117+
def do_patch():
108118
if not _is_tls_initialized():
109119
_initialize_tls()
110-
return _tls.functions_dict
120+
_tls.patch.do_patch()
111121

112122

113-
cdef function_info* _get_functions():
114-
cdef function_info* functions
115-
cdef functions_struct* fs
116-
123+
def do_unpatch():
117124
if not _is_tls_initialized():
118125
_initialize_tls()
119-
120-
capsule = _tls.functions_capsule
121-
if (not cpython.pycapsule.PyCapsule_IsValid(capsule, capsule_name)):
122-
raise ValueError("Internal Error: invalid capsule stored in TLS")
123-
fs = <functions_struct *>cpython.pycapsule.PyCapsule_GetPointer(capsule, capsule_name)
124-
return fs[0].functions
125-
126-
127-
cdef void c_do_patch():
128-
cdef int res
129-
cdef cnp.PyUFuncGenericFunction temp
130-
cdef cnp.PyUFuncGenericFunction function
131-
cdef int* signature
132-
133-
funcs_dict = _get_func_dict()
134-
functions = _get_functions()
135-
136-
for func in funcs_dict:
137-
np_umath = getattr(nu, func[0])
138-
index = funcs_dict[func]
139-
function = functions[index].mkl_function
140-
signature = functions[index].signature
141-
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, &temp)
142-
143-
144-
cdef void c_do_unpatch():
145-
cdef int res
146-
cdef cnp.PyUFuncGenericFunction temp
147-
cdef cnp.PyUFuncGenericFunction function
148-
cdef int* signature
149-
150-
funcs_dict = _get_func_dict()
151-
functions = _get_functions()
152-
153-
for func in funcs_dict:
154-
np_umath = getattr(nu, func[0])
155-
index = funcs_dict[func]
156-
function = functions[index].np_function
157-
signature = functions[index].signature
158-
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, &temp)
159-
160-
161-
def do_patch():
162-
c_do_patch()
163-
164-
165-
def do_unpatch():
166-
c_do_unpatch()
126+
_tls.patch.do_unpatch()

0 commit comments

Comments
 (0)