-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathorsa.py
50 lines (41 loc) · 1.19 KB
/
orsa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import cffi
import numpy as np
# Resolve build artefacts relative to this file, not the caller's working directory.
root = os.path.dirname(os.path.realpath(__file__))

# Give cffi the library's C declarations by parsing the generated header.
ffi = cffi.FFI()
_header_path = os.path.join(root, 'build/demo/all.h')
with open(_header_path) as header:
    ffi.cdef(header.read())

# Load the shared object with ffi.dlopen (cffi ABI mode); its symbols are
# looked up by the names declared in the header above.
_library_path = os.path.join(root, "build/demo/libworsa.so")
orsa = ffi.dlopen(_library_path)
# NumPy dtype -> cffi pointer type used to view that dtype's buffer from C.
# ``np.bool_`` is used rather than the ``np.bool`` alias, which was deprecated in
# NumPy 1.20 and removed in 1.24; ``np.bool_`` works on every NumPy version.
# ``int*`` assumes the C ``int`` of the target library is 32 bits wide.
_POINTER_TYPES = (
    (np.float64, 'double*'),
    (np.float32, 'float*'),
    (np.bool_, 'bool*'),
    (np.int32, 'int*'),
)


def P(array):
    """Expose the buffer of a NumPy array to C as a typed, writable pointer.

    Parameters:
        array: a NumPy array whose dtype is float64, float32, bool or int32.
            Its buffer must be writable, because it is requested with
            ``require_writable=True``.

    Returns:
        A cffi cdata pointer (``double*``, ``float*``, ``bool*`` or ``int*``)
        that views the array's memory without copying it.

    Raises:
        TypeError: if the dtype is not one of the four above. Previously such
            arrays were silently reinterpreted as ``double*``.
    """
    for dtype, typestr in _POINTER_TYPES:
        if array.dtype == dtype:
            break
    else:
        raise TypeError(
            "P: unsupported array dtype {}; expected float64, float32, "
            "bool or int32".format(array.dtype))
    # The cdecl argument and require_writable of ffi.from_buffer need
    # cffi >= 1.12 (not 0.12). NOTE(review): confirm whether pointer cdecls
    # such as 'double*' need a later cffi release.
    return ffi.from_buffer(typestr, array, require_writable=True)
def STR(ptr):
    """Return the Python value of the C string referenced by the cdata *ptr*.

    This is a thin alias for ``ffi.string``. For a ``char *`` argument the
    result is ``bytes``; no decoding is done here.
    """
    text = ffi.string(ptr)
    return text
class a(object):
    """Empty attribute container.

    The module instantiates it as a plain namespace and fills it with the
    wrapped library functions via ``setattr``.
    """
def wrap(a):
    """Recursively convert Python call arguments into values cffi can accept.

    - NumPy arrays become typed pointers via ``P``.
    - Lists and tuples are converted element by element. A list comes back
      as a plain ``list`` and a tuple as a plain ``tuple``.
    - ``str`` values are encoded to UTF-8 ``bytes``.
    - Anything else is returned unchanged.
    """
    if isinstance(a, np.ndarray):
        return P(a)
    if isinstance(a, (list, tuple)):
        converted = [wrap(item) for item in a]
        return converted if isinstance(a, list) else tuple(converted)
    if isinstance(a, str):
        # Encoding does not allocate a C buffer for the string. Per the
        # original author, this is only suitable when the C side writes into
        # an array, not through a pointer. NOTE(review): the C callee should
        # presumably treat these bytes as read-only — confirm against the API.
        return a.encode('utf-8')
    return a
def buildnew(f):
    """Wrap *f* so that its positional arguments pass through ``wrap`` first.

    The returned callable accepts positional arguments only. It converts the
    whole argument tuple with ``wrap`` and forwards the result to *f*.
    """
    def call_wrapped(*args):
        converted = wrap(args)
        return f(*converted)
    return call_wrapped
# Re-export cffi's NULL pointer so callers do not need access to ``ffi``.
NULL = ffi.NULL


def _is_c_function(value):
    """Return True if *value*, read from the cffi library, is a C function."""
    if isinstance(value, ffi.CData):
        # Functions from ffi.dlopen are cdata function pointers (ctype kind
        # 'function'). Other cdata (global arrays, structs, pointers) must be
        # passed through, not wrapped.
        return ffi.typeof(value).kind == 'function'
    # Constants and enum values are plain Python objects, which are not callable.
    return callable(value)


# Build a namespace that mirrors the library. Functions are wrapped so they
# accept NumPy arrays, str and nested list/tuple arguments. Every other
# attribute is copied as-is; previously constants were also wrapped in a
# lambda, which made them unusable. NOTE(review): non-function values are
# read once, here at import time, so later changes to C globals are not
# reflected.
orsa2 = a()
for n in dir(orsa):
    attr = getattr(orsa, n)
    setattr(orsa2, n, buildnew(attr) if _is_c_function(attr) else attr)
orsa = orsa2