Skip to content

Commit 8ca5d50

Browse files
committed
complex types
1 parent 6431620 commit 8ca5d50

File tree

2 files changed

+174
-104
lines changed

2 files changed

+174
-104
lines changed

src/linalg/cblas.mojo

+156-104
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,35 @@ from sys.ffi import DLHandle
22
from os.path import isfile
33
from os.env import getenv
44

5+
alias F32 = Float32
6+
alias F64 = Float64
7+
8+
9+
@value
10+
struct C32:
11+
var real: F32
12+
var imaginary: F32
13+
14+
fn __init__(inout self, r: F32, i: F32):
15+
self.real = r
16+
self.imaginary = i
17+
18+
19+
@value
20+
struct C64:
21+
var real: F64
22+
var imaginary: F64
23+
24+
fn __init__(inout self, r: F64, i: F64):
25+
self.real = r
26+
self.imaginary = i
27+
28+
29+
alias PF32 = UnsafePointer[F32]
30+
alias PF64 = UnsafePointer[F64]
31+
alias PC32 = UnsafePointer[C32]
32+
alias PC64 = UnsafePointer[C64]
33+
534

635
struct CBLAS:
736
# enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
@@ -21,59 +50,76 @@ struct CBLAS:
2150
alias CblasLeft = 141
2251
alias CblasRight = 142
2352

24-
alias PF32 = UnsafePointer[Float32]
25-
alias PF64 = UnsafePointer[Float64]
26-
alias PC32 = UnsafePointer[(Float32, Float32)]
27-
alias PC64 = UnsafePointer[(Float64, Float64)]
28-
29-
alias SDsdotType = fn (
30-
Int, Float32, Self.PF32, Int, Self.PF32, Int
31-
) -> Float32
32-
alias DSdotType = fn (
33-
Int, Float32, Self.PF32, Int, Self.PF32, Int
34-
) -> Float64
35-
alias SDotType = fn (Int, Self.PF32, Int, Self.PF32, Int) -> Float32
36-
alias DDotType = fn (Int, Self.PF64, Int, Self.PF64, Int) -> Float64
37-
alias CDotSubType = fn (
38-
Int, Self.PC32, Int, Self.PC32, Int, Self.PC32
39-
) -> None
40-
alias ZDotSubType = fn (
41-
Int, Self.PC64, Int, Self.PC64, Int, Self.PC64
42-
) -> None
43-
alias SReductType = fn (Int, Self.PF32, Int) -> Float32
44-
alias DReductType = fn (Int, Self.PF64, Int) -> Float64
45-
alias CReductType = fn (Int, Self.PC32, Int) -> Float32
46-
alias ZReductType = fn (Int, Self.PC64, Int) -> Float64
47-
alias SWhichType = fn (Int, Self.PF32, Int) -> Int
48-
alias DWhichType = fn (Int, Self.PF64, Int) -> Int
49-
alias CWhichType = fn (Int, Self.PC32, Int) -> Int
50-
alias ZWhichType = fn (Int, Self.PC64, Int) -> Int
51-
alias SSwapType = fn (Int, Self.PF32, Int, Self.PF32, Int) -> None
52-
alias SAxpy = fn (Int, Float32, Self.PF32, Int, Self.PF32, Int) -> None
53-
54-
var sdsdot: Self.SDsdotType
55-
var dsdot: Self.DSdotType
53+
# float cblas_sdsdot(const int N, const float alpha, const float *X, const int incX, const float *Y, const int incY);
54+
alias SDSDotType = fn (Int, F32, PF32, Int, PF32, Int) -> F32
55+
# double cblas_dsdot(const int N, const float *X, const int incX, const float *Y, const int incY);
56+
alias DSDotType = fn (Int, F32, PF32, Int, PF32, Int) -> F64
57+
# float cblas_sdot(const int N, const float *X, const int incX, const float *Y, const int incY);
58+
alias SDotType = fn (Int, PF32, Int, PF32, Int) -> F32
59+
# double cblas_ddot(const int N, const double *X, const int incX, const double *Y, const int incY);
60+
alias DDotType = fn (Int, PF64, Int, PF64, Int) -> F64
61+
# void cblas_cdotu_sub(const int N, const void *X, const int incX, const void *Y, const int incY, void *dotu);
62+
alias CDotUSubType = fn (Int, PC32, Int, PC32, Int, PC32) -> None
63+
# void cblas_cdotc_sub(const int N, const void *X, const int incX, const void *Y, const int incY, void *dotc);
64+
alias CDotCSubType = fn (Int, PC32, Int, PC32, Int, PC32) -> None
65+
# void cblas_zdotu_sub(const int N, const void *X, const int incX, const void *Y, const int incY, void *dotu);
66+
alias ZDotUSubType = fn (Int, PC64, Int, PC64, Int, PC64) -> None
67+
# void cblas_zdotc_sub(const int N, const void *X, const int incX, const void *Y, const int incY, void *dotc);
68+
alias ZDotCSubType = fn (Int, PC64, Int, PC64, Int, PC64) -> None
69+
# float cblas_snrm2(const int N, const float *X, const int incX);
70+
alias SNrm2Type = fn (Int, PF32, Int) -> F32
71+
# float cblas_sasum(const int N, const float *X, const int incX);
72+
alias SASumType = fn (Int, PF32, Int) -> F32
73+
# double cblas_dnrm2(const int N, const double *X, const int incX);
74+
alias DNrm2Type = fn (Int, PF64, Int) -> F64
75+
# double cblas_dasum(const int N, const double *X, const int incX);
76+
alias DASumType = fn (Int, PF64, Int) -> F64
77+
# float cblas_scnrm2(const int N, const void *X, const int incX);
78+
alias SCNrm2Type = fn (Int, PC32, Int) -> F32
79+
# float cblas_scasum(const int N, const void *X, const int incX);
80+
alias SCASumType = fn (Int, PC32, Int) -> F32
81+
# double cblas_dznrm2(const int N, const void *X, const int incX);
82+
alias DZNrm2Type = fn (Int, PC64, Int) -> F64
83+
# double cblas_dzasum(const int N, const void *X, const int incX);
84+
alias DZASumType = fn (Int, PC64, Int) -> F64
85+
# CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX);
86+
alias ISAMaxType = fn (Int, PF32, Int) -> Int
87+
# CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX);
88+
alias IDAMaxType = fn (Int, PF64, Int) -> Int
89+
# CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX);
90+
alias ICAMaxType = fn (Int, PC32, Int) -> Int
91+
# CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX);
92+
alias IZAMaxType = fn (Int, PC64, Int) -> Int
93+
# void cblas_sswap(const int N, float *X, const int incX, float *Y, const int incY);
94+
alias SSwapType = fn (Int, PF32, Int, PF32, Int) -> None
95+
# void cblas_scopy(const int N, const float *X, const int incX, float *Y, const int incY);
96+
alias SCopyType = fn (Int, PF32, Int, PF32, Int) -> None
97+
# void cblas_saxpy(const int N, const float alpha, const float *X, const int incX, float *Y, const int incY);
98+
alias SAxpyType = fn (Int, F32, PF32, Int, PF32, Int) -> None
99+
100+
var sdsdot: Self.SDSDotType
101+
var dsdot: Self.DSDotType
56102
var sdot: Self.SDotType
57103
var ddot: Self.DDotType
58-
var cdotc_sub: Self.CDotSubType
59-
var cdotu_sub: Self.CDotSubType
60-
var zdotc_sub: Self.ZDotSubType
61-
var zdotu_sub: Self.ZDotSubType
62-
var snrm2: Self.SReductType
63-
var sasum: Self.SReductType
64-
var dnrm2: Self.DReductType
65-
var dasum: Self.DReductType
66-
var scnrm2: Self.CReductType
67-
var scasum: Self.CReductType
68-
var dznrm2: Self.ZReductType
69-
var dzasum: Self.ZReductType
70-
var isamax: Self.SWhichType
71-
var idamax: Self.DWhichType
72-
var icamax: Self.CWhichType
73-
var izamax: Self.ZWhichType
104+
var cdotu_sub: Self.CDotUSubType
105+
var cdotc_sub: Self.CDotCSubType
106+
var zdotu_sub: Self.ZDotUSubType
107+
var zdotc_sub: Self.ZDotCSubType
108+
var snrm2: Self.SNrm2Type
109+
var sasum: Self.SASumType
110+
var dnrm2: Self.DNrm2Type
111+
var dasum: Self.DASumType
112+
var scnrm2: Self.SCNrm2Type
113+
var scasum: Self.SCASumType
114+
var dznrm2: Self.DZNrm2Type
115+
var dzasum: Self.DZASumType
116+
var isamax: Self.ISAMaxType
117+
var idamax: Self.IDAMaxType
118+
var icamax: Self.ICAMaxType
119+
var izamax: Self.IZAMaxType
74120
var sswap: Self.SSwapType
75-
var scopy: Self.SSwapType
76-
var saxpy: Self.SAxpy
121+
var scopy: Self.SCopyType
122+
var saxpy: Self.SAxpyType
77123

78124
var h: DLHandle # Lifetime???
79125

@@ -87,76 +133,82 @@ struct CBLAS:
87133
self.h = DLHandle(path)
88134
if not self.h:
89135
raise Error("Cannot open dynamic library")
90-
# float cblas_sdsdot(const int N, const float alpha, const float *X, const int incX, const float *Y, const int incY);
91-
self.sdsdot = self.h.get_function[Self.SDsdotType]("cblas_sdsdot")
92-
# double cblas_dsdot(const int N, const float *X, const int incX, const float *Y, const int incY);
93-
self.dsdot = self.h.get_function[Self.DSdotType]("cblas_dsdot")
94-
# float cblas_sdot(const int N, const float *X, const int incX, const float *Y, const int incY);
136+
137+
self.sdsdot = self.h.get_function[Self.SDSDotType]("cblas_sdsdot")
138+
self.dsdot = self.h.get_function[Self.DSDotType]("cblas_dsdot")
95139
self.sdot = self.h.get_function[Self.SDotType]("cblas_sdot")
96-
# double cblas_ddot(const int N, const double *X, const int incX, const double *Y, const int incY);
97140
self.ddot = self.h.get_function[Self.DDotType]("cblas_ddot")
98-
# void cblas_cdotu_sub(const int N, const void *X, const int incX, const void *Y, const int incY, void *dotu);
99-
self.cdotu_sub = self.h.get_function[Self.CDotSubType](
141+
self.cdotu_sub = self.h.get_function[Self.CDotUSubType](
100142
"cblas_cdotu_sub"
101143
)
102-
# void cblas_cdotc_sub(const int N, const void *X, const int incX, const void *Y, const int incY, void *dotc);
103-
self.cdotc_sub = self.h.get_function[Self.CDotSubType](
144+
self.cdotc_sub = self.h.get_function[Self.CDotCSubType](
104145
"cblas_cdotc_sub"
105146
)
106-
# void cblas_zdotu_sub(const int N, const void *X, const int incX, const void *Y, const int incY, void *dotu);
107-
self.zdotu_sub = self.h.get_function[Self.ZDotSubType](
147+
self.zdotu_sub = self.h.get_function[Self.ZDotUSubType](
108148
"cblas_zdotu_sub"
109149
)
110-
# void cblas_zdotc_sub(const int N, const void *X, const int incX, const void *Y, const int incY, void *dotc);
111-
self.zdotc_sub = self.h.get_function[Self.ZDotSubType](
150+
self.zdotc_sub = self.h.get_function[Self.ZDotCSubType](
112151
"cblas_zdotc_sub"
113152
)
114-
# float cblas_snrm2(const int N, const float *X, const int incX);
115-
self.snrm2 = self.h.get_function[Self.SReductType]("cblas_snrm2")
116-
# float cblas_sasum(const int N, const float *X, const int incX);
117-
self.sasum = self.h.get_function[Self.SReductType]("cblas_sasum")
118-
# double cblas_dnrm2(const int N, const double *X, const int incX);
119-
self.dnrm2 = self.h.get_function[Self.DReductType]("cblas_dnrm2")
120-
# double cblas_dasum(const int N, const double *X, const int incX);
121-
self.dasum = self.h.get_function[Self.DReductType]("cblas_dasum")
122-
# float cblas_scnrm2(const int N, const void *X, const int incX);
123-
self.scnrm2 = self.h.get_function[Self.CReductType]("cblas_scnrm2")
124-
# float cblas_scasum(const int N, const void *X, const int incX);
125-
self.scasum = self.h.get_function[Self.CReductType]("cblas_scasum")
126-
# double cblas_dznrm2(const int N, const void *X, const int incX);
127-
self.dznrm2 = self.h.get_function[Self.ZReductType]("cblas_dznrm2")
128-
# double cblas_dzasum(const int N, const void *X, const int incX);
129-
self.dzasum = self.h.get_function[Self.ZReductType]("cblas_dzasum")
130-
# CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX);
131-
self.isamax = self.h.get_function[Self.SWhichType]("cblas_isamax")
132-
# CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX);
133-
self.idamax = self.h.get_function[Self.DWhichType]("cblas_idamax")
134-
# CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX);
135-
self.icamax = self.h.get_function[Self.CWhichType]("cblas_icamax")
136-
# CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX);
137-
self.izamax = self.h.get_function[Self.ZWhichType]("cblas_izamax")
138-
# void cblas_sswap(const int N, float *X, const int incX, float *Y, const int incY);
153+
self.snrm2 = self.h.get_function[Self.SNrm2Type]("cblas_snrm2")
154+
self.sasum = self.h.get_function[Self.SASumType]("cblas_sasum")
155+
self.dnrm2 = self.h.get_function[Self.DNrm2Type]("cblas_dnrm2")
156+
self.dasum = self.h.get_function[Self.DASumType]("cblas_dasum")
157+
self.scnrm2 = self.h.get_function[Self.SCNrm2Type]("cblas_scnrm2")
158+
self.scasum = self.h.get_function[Self.SCASumType]("cblas_scasum")
159+
self.dznrm2 = self.h.get_function[Self.DZNrm2Type]("cblas_dznrm2")
160+
self.dzasum = self.h.get_function[Self.DZASumType]("cblas_dzasum")
161+
self.isamax = self.h.get_function[Self.ISAMaxType]("cblas_isamax")
162+
self.idamax = self.h.get_function[Self.IDAMaxType]("cblas_idamax")
163+
self.icamax = self.h.get_function[Self.ICAMaxType]("cblas_icamax")
164+
self.izamax = self.h.get_function[Self.IZAMaxType]("cblas_izamax")
139165
self.sswap = self.h.get_function[Self.SSwapType]("cblas_sswap")
140-
# void cblas_scopy(const int N, const float *X, const int incX, float *Y, const int incY);
141-
self.scopy = self.h.get_function[Self.SSwapType]("cblas_scopy")
142-
# void cblas_saxpy(const int N, const float alpha, const float *X, const int incX, float *Y, const int incY);
143-
self.saxpy = self.h.get_function[Self.SAxpy]("cblas_saxpy")
166+
self.scopy = self.h.get_function[Self.SCopyType]("cblas_scopy")
167+
self.saxpy = self.h.get_function[Self.SAxpyType]("cblas_saxpy")
168+
169+
170+
from testing import *
144171

145172

146173
def main():
147-
var n: Int = 100
148-
var a: Float64 = 1
149-
var x = UnsafePointer[Float64].alloc(n.value)
150-
var x_inc: Int = 1
151-
var y = UnsafePointer[Float64].alloc(n.value)
152-
var y_inc: Int = 1
174+
var n: Int = 3
175+
176+
var cblas = CBLAS("/opt/homebrew/opt/openblas/lib/libopenblas.dylib")
177+
178+
var x32 = PF32.alloc(n.value)
179+
var y32 = PF32.alloc(n.value)
153180

154181
for i in range(n):
155-
x[i] = i
156-
y[i] = i
182+
x32[i] = i
183+
y32[i] = i
157184

158-
var cblas = CBLAS("/opt/homebrew/opt/openblas/lib/libopenblas.dylib")
185+
sdsdot_res = cblas.sdsdot(n, 2, x32, 1, y32, 1)
186+
assert_equal(sdsdot_res, 7)
187+
188+
var dsdot_res = cblas.dsdot(n, 2, x32, 1, y32, 1)
189+
assert_equal(dsdot_res, 5)
190+
191+
var sdot_res = cblas.sdot(n, x32, 1, y32, 1)
192+
assert_equal(sdot_res, 5)
193+
194+
var x64 = PF64.alloc(n.value)
195+
var y64 = PF64.alloc(n.value)
159196

160-
var res = cblas.ddot(n, x, x_inc, y, y_inc)
197+
for i in range(n):
198+
x64[i] = i
199+
y64[i] = i
200+
201+
var ddot_res = cblas.ddot(n, x64, 1, y64, 1)
202+
assert_equal(ddot_res, 5)
203+
204+
var cres32 = PC32.alloc(1)
205+
206+
var xc32 = PC32.alloc(n.value)
207+
var yc32 = PC32.alloc(n.value)
208+
209+
for i in range(n):
210+
xc32[i] = C32(i, i)
211+
yc32[i] = C32(i, i)
161212

162-
print(res)
213+
cblas.cdotu_sub(n, xc32, 1, yc32, 1, cres32)
214+
print(cres32[0].real, cres32[0].imaginary)

test/test_cblas.mojo

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from testing import *
2+
from linalg.cblas import CBLAS
3+
4+
5+
def test_sdsdot():
6+
var n: Int = 3
7+
var x = UnsafePointer[Float64].alloc(n.value)
8+
var y = UnsafePointer[Float64].alloc(n.value)
9+
10+
for i in range(n):
11+
x[i] = i
12+
y[i] = i
13+
14+
var cblas = CBLAS("/opt/homebrew/opt/openblas/lib/libopenblas.dylib")
15+
16+
var res = cblas.ddot(n, x, 1, y, 1)
17+
18+
assert_equal(res, 14)

0 commit comments

Comments
 (0)