Skip to content

Commit ff41ffb

Browse files
authored
Merge pull request fortran-lang#192 from jvdp1/cov_correction
Addition of tests and clarification for cov
2 parents 4ac0208 + bf063a6 commit ff41ffb

File tree

2 files changed

+112
-56
lines changed

2 files changed

+112
-56
lines changed

src/stdlib_experimental_stats_cov.fypp

Lines changed: 60 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -178,52 +178,58 @@ contains
178178
, merge(size(x, 1), size(x, 2), mask = 1<dim))
179179

180180
integer :: i, j, n
181-
${t1}$ :: mean_(merge(size(x, 1), size(x, 2), mask = 1<dim))
182-
${t1}$ :: center(size(x, 1),size(x, 2))
181+
${t1}$ :: centeri_(merge(size(x, 2), size(x, 1), mask = 1<dim))
182+
${t1}$ :: centerj_(merge(size(x, 2), size(x, 1), mask = 1<dim))
183+
logical :: mask_(merge(size(x, 2), size(x, 1), mask = 1<dim))
183184

184-
mean_ = mean(x, dim, mask = mask)
185185
select case(dim)
186186
case(1)
187-
do i = 1, size(x, 1)
188-
center(i, :) = merge( x(i, :) - mean_,&
187+
do i = 1, size(res, 2)
188+
do j = 1, size(res, 1)
189+
mask_ = merge(.true., .false., mask(:, i) .and. mask(:, j))
190+
centeri_ = merge( x(:, i) - mean(x(:, i), mask = mask_),&
189191
#:if t1[0] == 'r'
190192
0._${k1}$,&
191193
#:else
192194
cmplx(0,0,kind=${k1}$),&
193195
#:endif
194-
mask(i, :))
195-
end do
196-
#:if t1[0] == 'r'
197-
res = matmul( transpose(center), center)
198-
#:else
199-
res = matmul( transpose(conjg(center)), center)
200-
#:endif
201-
do j = 1, size(res, 2)
202-
do i = 1, size(res, 1)
203-
n = count(merge(.true., .false., mask(:, i) .and. mask(:, j)))
204-
res(i, j) = res(i, j) / (n - merge(1, 0,&
196+
mask_)
197+
centerj_ = merge( x(:, j) - mean(x(:, j), mask = mask_),&
198+
#:if t1[0] == 'r'
199+
0._${k1}$,&
200+
#:else
201+
cmplx(0,0,kind=${k1}$),&
202+
#:endif
203+
mask_)
204+
205+
n = count(mask_)
206+
res(j, i) = dot_product( centerj_, centeri_)&
207+
/ (n - merge(1, 0,&
205208
optval(corrected, .true.) .and. n > 0))
206209
end do
207210
end do
208211
case(2)
209-
do i = 1, size(x, 2)
210-
center(:, i) = merge( x(:, i) - mean_,&
212+
do i = 1, size(res, 2)
213+
do j = 1, size(res, 1)
214+
mask_ = merge(.true., .false., mask(i, :) .and. mask(j, :))
215+
centeri_ = merge( x(i, :) - mean(x(i, :), mask = mask_),&
211216
#:if t1[0] == 'r'
212217
0._${k1}$,&
213218
#:else
214219
cmplx(0,0,kind=${k1}$),&
215220
#:endif
216-
mask(:, i))
217-
end do
218-
#:if t1[0] == 'r'
219-
res = matmul( center, transpose(center))
220-
#:else
221-
res = matmul( center, transpose(conjg(center)))
222-
#:endif
223-
do j = 1, size(res, 2)
224-
do i = 1, size(res, 1)
225-
n = count(merge(.true., .false., mask(i, :) .and. mask(j, :)))
226-
res(i, j) = res(i, j) / (n - merge(1, 0,&
221+
mask_)
222+
centerj_ = merge( x(j, :) - mean(x(j, :), mask = mask_),&
223+
#:if t1[0] == 'r'
224+
0._${k1}$,&
225+
#:else
226+
cmplx(0,0,kind=${k1}$),&
227+
#:endif
228+
mask_)
229+
230+
n = count(mask_)
231+
res(j, i) = dot_product( centeri_, centerj_)&
232+
/ (n - merge(1, 0,&
227233
optval(corrected, .true.) .and. n > 0))
228234
end do
229235
end do
@@ -246,36 +252,38 @@ contains
246252
, merge(size(x, 1), size(x, 2), mask = 1<dim))
247253

248254
integer :: i, j, n
249-
real(dp) :: mean_(merge(size(x, 1), size(x, 2), mask = 1<dim))
250-
real(dp) :: center(size(x, 1),size(x, 2))
255+
real(dp) :: centeri_(merge(size(x, 2), size(x, 1), mask = 1<dim))
256+
real(dp) :: centerj_(merge(size(x, 2), size(x, 1), mask = 1<dim))
257+
logical :: mask_(merge(size(x, 2), size(x, 1), mask = 1<dim))
251258

252-
mean_ = mean(x, dim, mask = mask)
253259
select case(dim)
254260
case(1)
255-
do i = 1, size(x, 1)
256-
center(i, :) = merge( x(i, :) - mean_,&
257-
0._dp,&
258-
mask(i, :))
259-
end do
260-
res = matmul( transpose(center), center)
261-
do j = 1, size(res, 2)
262-
do i = 1, size(res, 1)
263-
n = count(merge(.true., .false., mask(:, i) .and. mask(:, j)))
264-
res(i, j) = res(i, j) / (n - merge(1, 0,&
261+
do i = 1, size(res, 2)
262+
do j = 1, size(res, 1)
263+
mask_ = merge(.true., .false., mask(:, i) .and. mask(:, j))
264+
centeri_ = merge( x(:, i) - mean(x(:, i), mask = mask_),&
265+
0._dp, mask_)
266+
centerj_ = merge( x(:, j) - mean(x(:, j), mask = mask_),&
267+
0._dp, mask_)
268+
269+
n = count(mask_)
270+
res(j, i) = dot_product( centerj_, centeri_)&
271+
/ (n - merge(1, 0,&
265272
optval(corrected, .true.) .and. n > 0))
266273
end do
267274
end do
268275
case(2)
269-
do i = 1, size(x, 2)
270-
center(:, i) = merge( x(:, i) - mean_,&
271-
0._dp,&
272-
mask(:, i))
273-
end do
274-
res = matmul( center, transpose(center))
275-
do j = 1, size(res, 2)
276-
do i = 1, size(res, 1)
277-
n = count(merge(.true., .false., mask(i, :) .and. mask(j, :)))
278-
res(i, j) = res(i, j) / (n - merge(1, 0,&
276+
do i = 1, size(res, 2)
277+
do j = 1, size(res, 1)
278+
mask_ = merge(.true., .false., mask(i, :) .and. mask(j, :))
279+
centeri_ = merge( x(i, :) - mean(x(i, :), mask = mask_),&
280+
0._dp, mask_)
281+
centerj_ = merge( x(j, :) - mean(x(j, :), mask = mask_),&
282+
0._dp, mask_)
283+
284+
n = count(mask_)
285+
res(j, i) = dot_product( centeri_, centerj_)&
286+
/ (n - merge(1, 0,&
279287
optval(corrected, .true.) .and. n > 0))
280288
end do
281289
end do

src/tests/stats/test_cov.f90

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,14 @@ subroutine test_sp(x, x2)
126126
) < sptol)&
127127
, 'sp check 16')
128128

129+
call check( all( abs( cov(x2, 1, mask = x2 < 1000) - cov(x2, 1))&
130+
< sptol)&
131+
, 'sp check 17')
132+
133+
call check( all( abs( cov(x2, 2, mask = x2 < 1000) - cov(x2, 2))&
134+
< sptol)&
135+
, 'sp check 18')
136+
129137
end subroutine test_sp
130138

131139
subroutine test_dp(x, x2)
@@ -213,6 +221,13 @@ subroutine test_dp(x, x2)
213221
) < dptol)&
214222
, 'dp check 16')
215223

224+
call check( all( abs( cov(x2, 1, mask = x2 < 1000) - cov(x2, 1))&
225+
< dptol)&
226+
, 'dp check 17')
227+
228+
call check( all( abs( cov(x2, 2, mask = x2 < 1000) - cov(x2, 2))&
229+
< dptol)&
230+
, 'dp check 18')
216231

217232
end subroutine test_dp
218233

@@ -301,6 +316,14 @@ subroutine test_int32(x, x2)
301316
) < dptol)&
302317
, 'int32 check 16')
303318

319+
call check( all( abs( cov(x2, 1, mask = x2 < 1000) - cov(x2, 1))&
320+
< dptol)&
321+
, 'int32 check 17')
322+
323+
call check( all( abs( cov(x2, 2, mask = x2 < 1000) - cov(x2, 2))&
324+
< dptol)&
325+
, 'int32 check 18')
326+
304327
end subroutine test_int32
305328

306329
subroutine test_int64(x, x2)
@@ -388,6 +411,14 @@ subroutine test_int64(x, x2)
388411
) < dptol)&
389412
, 'int64 check 16')
390413

414+
call check( all( abs( cov(x2, 1, mask = x2 < 1000) - cov(x2, 1))&
415+
< dptol)&
416+
, 'int64 check 17')
417+
418+
call check( all( abs( cov(x2, 2, mask = x2 < 1000) - cov(x2, 2))&
419+
< dptol)&
420+
, 'int64 check 18')
421+
391422
end subroutine test_int64
392423

393424
subroutine test_csp(x, x2)
@@ -459,20 +490,29 @@ subroutine test_csp(x, x2)
459490
! call check( ieee_is_nan(real(cd(3,3)))&
460491
! , 'csp check 10 bis')
461492

493+
494+
call check( all( abs( cov(x2, 1, mask = aimag(x2) < 8) - cov(x2, 1))&
495+
< sptol)&
496+
, 'csp check 11')
497+
498+
call check( all( abs( cov(x2, 2, mask = aimag(x2) < 8) - cov(x2, 2))&
499+
< sptol)&
500+
, 'csp check 12')
501+
462502
call check( all( abs( cov(x2, 2, mask = aimag(x2) < 6) - reshape([&
463503
(4._sp,0._sp), (0._sp,2._sp)&
464504
,(0._sp,-2._sp), (2._sp,0._sp)]&
465505
,[ size(x2, 1), size(x2, 1)])&
466506
) < sptol)&
467-
, 'csp check 11')
507+
, 'csp check 13')
468508

469509
call check( all( abs( cov(x2, 2, mask = aimag(x2) < 6, corrected = .false.) -&
470510
reshape([&
471511
(2.6666666666666666_sp,0._sp), (0._sp,1._sp)&
472512
,(0._sp,-1._sp), (1._sp,0._sp)]&
473513
,[ size(x2, 1), size(x2, 1)])&
474514
) < sp)&
475-
, 'csp check 12')
515+
, 'csp check 14')
476516

477517
end subroutine test_csp
478518

@@ -546,20 +586,28 @@ subroutine test_cdp(x, x2)
546586
! call check( ieee_is_nan(real(cd(3,3)))&
547587
! , 'cdp check 10 bis')
548588

589+
call check( all( abs( cov(x2, 1, mask = aimag(x2) < 8) - cov(x2, 1))&
590+
< dptol)&
591+
, 'cdp check 11')
592+
593+
call check( all( abs( cov(x2, 2, mask = aimag(x2) < 8) - cov(x2, 2))&
594+
< dptol)&
595+
, 'cdp check 12')
596+
549597
call check( all( abs( cov(x2, 2, mask = aimag(x2) < 6) - reshape([&
550598
(4._dp,0._dp), (0._dp,2._dp)&
551599
,(0._dp,-2._dp), (2._dp,0._dp)]&
552600
,[ size(x2, 1), size(x2, 1)])&
553601
) < dptol)&
554-
, 'cdp check 11')
602+
, 'cdp check 13')
555603

556604
call check( all( abs( cov(x2, 2, mask = aimag(x2) < 6, corrected = .false.) -&
557605
reshape([&
558606
(2.6666666666666666_dp,0._dp), (0._dp,1._dp)&
559607
,(0._dp,-1._dp), (1._dp,0._dp)]&
560608
,[ size(x2, 1), size(x2, 1)])&
561609
) < dptol)&
562-
, 'cdp check 12')
610+
, 'cdp check 14')
563611

564612
end subroutine test_cdp
565613

0 commit comments

Comments
 (0)