1
1
program test_linalg
2
+
2
3
use stdlib_experimental_error, only: check
3
4
use stdlib_experimental_kinds, only: sp, dp, qp, int8, int16, int32, int64
4
5
use stdlib_experimental_linalg, only: diag, eye, trace
6
+
5
7
implicit none
8
+
9
+ real (sp), parameter :: sptol = 1000 * epsilon (1._sp )
10
+ real (dp), parameter :: dptol = 1000 * epsilon (1._dp )
11
+ real (qp), parameter :: qptol = 1000 * epsilon (1._qp )
12
+
6
13
logical :: warn
14
+
15
+ ! whether calls to check issue a warning
16
+ ! or stop execution
7
17
warn = .false.
8
18
9
19
!
@@ -59,12 +69,12 @@ subroutine test_eye
59
69
msg= " all(eye(5) == diag([(1,i=1,5)] failed." ,warn= warn)
60
70
61
71
rye = eye(6 )
62
- call check(sum (rye - diag([(1.0_sp ,i= 1 ,6 )])) < epsilon (rye) , &
63
- msg= " sum(rye - diag([(1.0_sp,i=1,6)])) < epsilon(rye) failed." ,warn= warn)
72
+ call check(sum (rye - diag([(1.0_sp ,i= 1 ,6 )])) < sptol , &
73
+ msg= " sum(rye - diag([(1.0_sp,i=1,6)])) < sptol failed." ,warn= warn)
64
74
65
75
cye = eye(7 )
66
- call check(abs (trace(cye) - complex (7.0_sp ,0.0_sp )) < epsilon ( 1.0_sp ) , &
67
- msg= " abs(trace(cye) - complex(7.0_sp,0.0_sp)) < epsilon(1.0_sp) failed." ,warn= warn)
76
+ call check(abs (trace(cye) - complex (7.0_sp ,0.0_sp )) < sptol , &
77
+ msg= " abs(trace(cye) - complex(7.0_sp,0.0_sp)) < sptol failed." ,warn= warn)
68
78
end subroutine
69
79
70
80
subroutine test_diag_rsp
@@ -95,8 +105,8 @@ subroutine test_diag_rsp_k
95
105
call check(all (a == b), &
96
106
msg= " all(a == b) failed." ,warn= warn)
97
107
98
- call check(sum (diag(a,- 1 )) - (n-1 ) < epsilon ( 1.0_sp ) , &
99
- msg= " sum(diag(a,-1)) - (n-1) < epsilon(1.0_sp) failed." ,warn= warn)
108
+ call check(sum (diag(a,- 1 )) - (n-1 ) < sptol , &
109
+ msg= " sum(diag(a,-1)) - (n-1) < sptol failed." ,warn= warn)
100
110
101
111
call check(all (a == transpose (diag([(1._sp ,i= 1 ,n-1 )],1 ))), &
102
112
msg= " all(a == transpose(diag([(1._sp,i=1,n-1)],1))) failed" ,warn= warn)
@@ -151,10 +161,10 @@ subroutine test_diag_csp
151
161
call check(all (a == b), &
152
162
msg= " all(a == b) failed." ,warn= warn)
153
163
154
- call check(all (abs (real (diag(a)) - [(i,i= 1 ,n)]) < epsilon ( 1.0_sp ) ), &
155
- msg= " all(abs(real(diag(a)) - [(i,i=1,n)]) < epsilon(1.0_sp) )" , warn= warn)
156
- call check(all (abs (aimag (diag(a)) - [(1 ,i= 1 ,n)]) < epsilon ( 1.0_sp ) ), &
157
- msg= " all(abs(aimag(diag(a)) - [(1,i=1,n)]) < epsilon(1.0_sp) )" , warn= warn)
164
+ call check(all (abs (real (diag(a)) - [(i,i= 1 ,n)]) < sptol ), &
165
+ msg= " all(abs(real(diag(a)) - [(i,i=1,n)]) < sptol )" , warn= warn)
166
+ call check(all (abs (aimag (diag(a)) - [(1 ,i= 1 ,n)]) < sptol ), &
167
+ msg= " all(abs(aimag(diag(a)) - [(1,i=1,n)]) < sptol )" , warn= warn)
158
168
end subroutine
159
169
160
170
subroutine test_diag_cdp
@@ -204,7 +214,6 @@ subroutine test_diag_int16
204
214
msg= " all(diag(a) == pack(a,mask))" , warn= warn)
205
215
call check(all (diag(diag(a)) == merge (a,0_int16 ,mask)), &
206
216
msg= " all(diag(diag(a)) == merge(a,0_int16,mask)) failed." , warn= warn)
207
- a = unpack (int ([1 ,2 ,3 ,4 ],int16),eye(n)==1 ,a)
208
217
end subroutine
209
218
subroutine test_diag_int32
210
219
integer , parameter :: n = 3
@@ -261,8 +270,8 @@ subroutine test_trace_rsp
261
270
integer :: i
262
271
write (* ,* ) " test_trace_rsp"
263
272
a = reshape ([(i,i= 1 ,n** 2 )],[n,n])
264
- call check(abs (trace(a) - sum (diag(a))) < epsilon ( 1.0_sp ) , &
265
- msg= " abs(trace(a) - sum(diag(a))) < epsilon(1.0_sp) failed." ,warn= warn)
273
+ call check(abs (trace(a) - sum (diag(a))) < sptol , &
274
+ msg= " abs(trace(a) - sum(diag(a))) < sptol failed." ,warn= warn)
266
275
end subroutine
267
276
268
277
subroutine test_trace_rsp_nonsquare
@@ -278,8 +287,8 @@ subroutine test_trace_rsp_nonsquare
278
287
a = reshape ([(i,i= 1 ,n* (n+1 ))],[n,n+1 ])
279
288
ans = sum ([1._sp ,6._sp ,11._sp ,16._sp ])
280
289
281
- call check(abs (trace(a) - ans) < epsilon ( 1.0_sp ) , &
282
- msg= " abs(trace(a) - ans) < epsilon(1.0_sp) failed." ,warn= warn)
290
+ call check(abs (trace(a) - ans) < sptol , &
291
+ msg= " abs(trace(a) - ans) < sptol failed." ,warn= warn)
283
292
end subroutine
284
293
285
294
subroutine test_trace_rdp
@@ -288,8 +297,8 @@ subroutine test_trace_rdp
288
297
integer :: i
289
298
write (* ,* ) " test_trace_rdp"
290
299
a = reshape ([(i,i= 1 ,n** 2 )],[n,n])
291
- call check(abs (trace(a) - sum (diag(a))) < epsilon ( 1.0_dp ) , &
292
- msg= " abs(trace(a) - sum(diag(a))) < epsilon(1.0_dp) failed." ,warn= warn)
300
+ call check(abs (trace(a) - sum (diag(a))) < dptol , &
301
+ msg= " abs(trace(a) - sum(diag(a))) < dptol failed." ,warn= warn)
293
302
end subroutine
294
303
295
304
subroutine test_trace_rdp_nonsquare
@@ -305,8 +314,8 @@ subroutine test_trace_rdp_nonsquare
305
314
a = reshape ([(i** 2 ,i= 1 ,n* (n-1 ))],[n,n-1 ])
306
315
ans = sum ([1._dp ,36._dp ,121._dp ])
307
316
308
- call check(abs (trace(a) - ans) < epsilon ( 1.0_dp ) , &
309
- msg= " abs(trace(a) - ans) < epsilon(1.0_sp) failed." ,warn= warn)
317
+ call check(abs (trace(a) - ans) < dptol , &
318
+ msg= " abs(trace(a) - ans) < dptol failed." ,warn= warn)
310
319
end subroutine
311
320
312
321
subroutine test_trace_rqp
@@ -315,8 +324,8 @@ subroutine test_trace_rqp
315
324
integer :: i
316
325
write (* ,* ) " test_trace_rqp"
317
326
a = reshape ([(i,i= 1 ,n** 2 )],[n,n])
318
- call check(abs (trace(a) - sum (diag(a))) < epsilon ( 1.0_qp ) , &
319
- msg= " abs(trace(a) - sum(diag(a))) < epsilon(1.0_qp) failed." ,warn= warn)
327
+ call check(abs (trace(a) - sum (diag(a))) < qptol , &
328
+ msg= " abs(trace(a) - sum(diag(a))) < qptol failed." ,warn= warn)
320
329
end subroutine
321
330
322
331
@@ -336,8 +345,8 @@ subroutine test_trace_csp
336
345
b = re + im* i_
337
346
338
347
! tr(A + B) = tr(A) + tr(B)
339
- call check(abs (trace(a+ b) - (trace(a) + trace(b))) < 10 * epsilon ( 1.0_sp ) , &
340
- msg= " abs(trace(a+b) - (trace(a) + trace(b))) < 10*epsilon(1.0_sp) failed." ,warn= warn)
348
+ call check(abs (trace(a+ b) - (trace(a) + trace(b))) < sptol , &
349
+ msg= " abs(trace(a+b) - (trace(a) + trace(b))) < sptol failed." ,warn= warn)
341
350
end subroutine
342
351
343
352
subroutine test_trace_cdp
@@ -350,8 +359,8 @@ subroutine test_trace_cdp
350
359
a = reshape ([(j + (n** 2 - (j-1 ))* i_,j= 1 ,n** 2 )],[n,n])
351
360
ans = complex (15 ,15 ) ! (1 + 5 + 9) + (9 + 5 + 1)i
352
361
353
- call check(abs (trace(a) - ans) < epsilon ( 1.0_dp ) , &
354
- msg= " abs(trace(a) - ans) < epsilon(1.0_dp) failed." ,warn= warn)
362
+ call check(abs (trace(a) - ans) < dptol , &
363
+ msg= " abs(trace(a) - ans) < dptol failed." ,warn= warn)
355
364
end subroutine
356
365
357
366
subroutine test_trace_cqp
@@ -360,8 +369,8 @@ subroutine test_trace_cqp
360
369
complex (qp), parameter :: i_ = complex (0 ,1 )
361
370
write (* ,* ) " test_trace_cqp"
362
371
a = 3 * eye(n) + 4 * eye(n)* i_ ! pythagorean triple
363
- call check(abs (trace(a)) - 3 * 5.0_qp < epsilon ( 1.0_qp ) , &
364
- msg= " abs(trace(a)) - 3*5.0_qp < epsilon(1.0_qp) failed." ,warn= warn)
372
+ call check(abs (trace(a)) - 3 * 5.0_qp < qptol , &
373
+ msg= " abs(trace(a)) - 3*5.0_qp < qptol failed." ,warn= warn)
365
374
end subroutine
366
375
367
376
0 commit comments