@@ -148,21 +148,18 @@ subroutine mat_mult(myid,a,lena,b,lenb,c,lenc,a_b_c,debug)
148148 integer :: n_cont(2 ),kpart_next,ind_partN,k_off(2 )
149149 integer :: stat,ilen2(2 ),lenb_rem(2 )
150150 ! Remote variables to be allocated
151- integer (integ),allocatable :: ibpart_rem(:,: )
151+ integer (integ),allocatable :: ibpart_rem(:)
152152 type jagged_array_r
153153 real (double), allocatable :: values(:)
154154 end type jagged_array_r
155155 type (jagged_array_r) :: b_rem(2 )
156156 ! Remote variables which will point to part_array
157- type jagged_pointer_array_i
158- integer (integ),pointer :: values(:)
159- end type jagged_pointer_array_i
160- type (jagged_pointer_array_i) :: nbnab_rem(2 )
161- type (jagged_pointer_array_i) :: ibseq_rem(2 )
162- type (jagged_pointer_array_i) :: ibind_rem(2 )
163- type (jagged_pointer_array_i) :: ib_nd_acc_rem(2 )
164- type (jagged_pointer_array_i) :: ibndimj_rem(2 )
165- type (jagged_pointer_array_i) :: npxyz_rem(2 )
157+ integer (integ), dimension (:), pointer :: nbnab_rem
158+ integer (integ), dimension (:), pointer :: ibseq_rem
159+ integer (integ), dimension (:), pointer :: ibind_rem
160+ integer (integ), dimension (:), pointer :: ib_nd_acc_rem
161+ integer (integ), dimension (:), pointer :: ibndimj_rem
162+ integer (integ), dimension (:), pointer :: npxyz_rem
166163 ! Arrays for remote variables to point to
167164 integer , target :: part_array(3 * a_b_c% parts% mx_mem_grp+ &
168165 5 * a_b_c% parts% mx_mem_grp* a_b_c% bmat(1 )% mx_abs, 2 )
@@ -172,7 +169,7 @@ subroutine mat_mult(myid,a,lena,b,lenb,c,lenc,a_b_c,debug)
172169 integer , dimension (MPI_STATUS_SIZE) :: mpi_stat
173170 integer , allocatable :: recv_part(:)
174171 real (double) :: t0,t1
175- integer :: request(2 ,2 ), index_rec, index_wait
172+ integer :: request(2 ,2 ), index_rec, index_comp
176173
177174 logical :: new_partition(2 )
178175
@@ -181,9 +178,8 @@ subroutine mat_mult(myid,a,lena,b,lenb,c,lenc,a_b_c,debug)
181178 call start_timer(tmr_std_allocation)
182179 if (iprint_mat> 3.AND .myid== 0 ) t0 = mtime()
183180 ! Allocate memory for the elements
184- allocate (ibpart_rem(a_b_c% parts% mx_mem_grp* a_b_c% bmat(1 )% mx_abs, 2 ),STAT= stat)
181+ allocate (ibpart_rem(a_b_c% parts% mx_mem_grp* a_b_c% bmat(1 )% mx_abs),STAT= stat)
185182 if (stat/= 0 ) call cq_abort(' mat_mult: error allocating ibpart_rem' )
186- ! allocate(atrans(a_b_c%amat(1)%length),STAT=stat)
187183 allocate (atrans(lena),STAT= stat)
188184 if (stat/= 0 ) call cq_abort(' mat_mult: error allocating atrans' )
189185 allocate (recv_part(0 :a_b_c% comms% inode),STAT= stat)
@@ -234,59 +230,58 @@ subroutine mat_mult(myid,a,lena,b,lenb,c,lenc,a_b_c,debug)
234230
235231 ! These indices point to elements of all the 2-element vectors of the variables needed
236232 ! for the do_comms and m_kern_min/max calls. They alternate between the values of
237- ! (index_rec,index_wait )=(1,2) and (2,1) from iteration to iteration.
233+ ! (index_rec,index_comp )=(1,2) and (2,1) from iteration to iteration.
238234 ! index_rec points to the values being received in the current iteration in do_comms,
239- ! and index_wait points to the values received in the previous iteration, thus computation
235+ ! and index_comp points to the values received in the previous iteration, thus computation
240236 ! can start on them in m_kern_min/max
241237 ! These indices are also used to point to elements of the 2x2-element request() array,
242238 ! that contains the MPI request numbers for the non-blocking data receives. There are 2
243239 ! MPI_Irecv calls per call of do_comms, and request() keeps track of 2 sets of those calls,
244240 ! thus it's of size 2x2.
245241 ! request(:,index_rec) points to the requests from the current iteration MPI_Irecv,
246- ! and request(:,index_wait ) points to the requests from the previous iteration, that have
242+ ! and request(:,index_comp ) points to the requests from the previous iteration, that have
247243 ! to complete in order for the computation to start (thus the MPI_Wait).
248244 index_rec = mod (kpart,2 ) + 1
249- index_wait = mod (kpart+1 ,2 ) + 1
245+ index_comp = mod (kpart+1 ,2 ) + 1
250246
251247 ! Check that previous partition data have been received before starting computation
252248 if (kpart.gt. 2 ) then
253- if (request(1 ,index_wait ).ne. MPI_REQUEST_NULL) &
254- call MPI_Wait(request(1 ,index_wait ),MPI_STATUSES_IGNORE,ierr)
255- if (request(2 ,index_wait ).ne. MPI_REQUEST_NULL) &
256- call MPI_Wait(request(2 ,index_wait ),MPI_STATUSES_IGNORE,ierr)
249+ if (request(1 ,index_comp ).ne. MPI_REQUEST_NULL) &
250+ call MPI_Wait(request(1 ,index_comp ),MPI_STATUSES_IGNORE,ierr)
251+ if (request(2 ,index_comp ).ne. MPI_REQUEST_NULL) &
252+ call MPI_Wait(request(2 ,index_comp ),MPI_STATUSES_IGNORE,ierr)
257253 end if
258254
259255 ! If that previous partition was a periodic one, copy over arrays from previous index
260- if (.not. new_partition(index_wait )) then
261- part_array(:,index_wait ) = part_array(:,index_rec)
262- n_cont(index_wait ) = n_cont(index_rec)
263- ilen2(index_wait ) = ilen2(index_rec)
264- b_rem(index_wait ) = b_rem(index_rec)
265- lenb_rem(index_wait ) = lenb_rem(index_rec)
256+ if (.not. new_partition(index_comp )) then
257+ part_array(:,index_comp ) = part_array(:,index_rec)
258+ n_cont(index_comp ) = n_cont(index_rec)
259+ ilen2(index_comp ) = ilen2(index_rec)
260+ b_rem(index_comp ) = b_rem(index_rec)
261+ lenb_rem(index_comp ) = lenb_rem(index_rec)
266262 end if
267263
268264 ! Now point the _rem variables at the appropriate parts of
269265 ! the array where we have received the data
270266 offset = 0
271- nbnab_rem(index_wait) % values = > part_array(offset+1 :offset+ n_cont(index_wait),index_wait )
272- offset = offset+ n_cont(index_wait )
273- ibind_rem(index_wait) % values = > part_array(offset+1 :offset+ n_cont(index_wait),index_wait )
274- offset = offset+ n_cont(index_wait )
275- ib_nd_acc_rem(index_wait) % values = > part_array(offset+1 :offset+ n_cont(index_wait),index_wait )
276- offset = offset+ n_cont(index_wait )
277- ibseq_rem(index_wait) % values = > part_array(offset+1 :offset+ ilen2(index_wait),index_wait )
278- offset = offset+ ilen2(index_wait )
279- npxyz_rem(index_wait) % values = > part_array(offset+1 :offset+3 * ilen2(index_wait),index_wait )
280- offset = offset+3 * ilen2(index_wait )
281- ibndimj_rem(index_wait) % values = > part_array(offset+1 :offset+ ilen2(index_wait),index_wait )
282- if (offset+ ilen2(index_wait )>3 * a_b_c% parts% mx_mem_grp+ &
267+ nbnab_rem = > part_array(offset+1 :offset+ n_cont(index_comp),index_comp )
268+ offset = offset+ n_cont(index_comp )
269+ ibind_rem = > part_array(offset+1 :offset+ n_cont(index_comp),index_comp )
270+ offset = offset+ n_cont(index_comp )
271+ ib_nd_acc_rem = > part_array(offset+1 :offset+ n_cont(index_comp),index_comp )
272+ offset = offset+ n_cont(index_comp )
273+ ibseq_rem = > part_array(offset+1 :offset+ ilen2(index_comp),index_comp )
274+ offset = offset+ ilen2(index_comp )
275+ npxyz_rem = > part_array(offset+1 :offset+3 * ilen2(index_comp),index_comp )
276+ offset = offset+3 * ilen2(index_comp )
277+ ibndimj_rem = > part_array(offset+1 :offset+ ilen2(index_comp),index_comp )
278+ if (offset+ ilen2(index_comp )>3 * a_b_c% parts% mx_mem_grp+ &
283279 5 * a_b_c% parts% mx_mem_grp* a_b_c% bmat(1 )% mx_abs) then
284280 call cq_abort(' mat_mult: error pointing to part_array ' ,kpart-1 )
285281 end if
286282 ! Create ibpart_rem
287- call end_part_comms(myid,n_cont(index_wait),nbnab_rem(index_wait)% values, &
288- ibind_rem(index_wait)% values,npxyz_rem(index_wait)% values,&
289- ibpart_rem(:,index_wait),ncover_yz,a_b_c% gcs% ncoverz)
283+ call end_part_comms(myid,n_cont(index_comp),nbnab_rem, &
284+ ibind_rem,npxyz_rem,ibpart_rem,ncover_yz,a_b_c% gcs% ncoverz)
290285
291286 ! Receive the data from the current partition - non-blocking
292287 if (kpart.lt. a_b_c% ahalo% np_in_halo+1 ) then
@@ -303,17 +298,17 @@ subroutine mat_mult(myid,a,lena,b,lenb,c,lenc,a_b_c,debug)
303298
304299 ! Call the computation kernel on the previous partition
305300 if (a_b_c% mult_type.eq. 1 ) then ! C is full mult
306- call m_kern_max( k_off(index_wait ),kpart-1 ,ib_nd_acc_rem(index_wait) % values , ibind_rem(index_wait) % values , &
307- nbnab_rem(index_wait) % values ,ibpart_rem(:,index_wait), ibseq_rem(index_wait) % values , &
308- ibndimj_rem(index_wait) % values , atrans,b_rem(index_wait )% values,c,a_b_c% ahalo,a_b_c% chalo, &
301+ call m_kern_max( k_off(index_comp ),kpart-1 ,ib_nd_acc_rem, ibind_rem, &
302+ nbnab_rem,ibpart_rem, ibseq_rem, &
303+ ibndimj_rem, atrans,b_rem(index_comp )% values,c,a_b_c% ahalo,a_b_c% chalo, &
309304 a_b_c% ltrans,a_b_c% bmat(1 )% mx_abs,a_b_c% parts% mx_mem_grp, &
310- a_b_c% prim% mx_iprim, lena, lenb_rem(index_wait ), lenc)
305+ a_b_c% prim% mx_iprim, lena, lenb_rem(index_comp ), lenc)
311306 else if (a_b_c% mult_type.eq. 2 ) then ! A is partial mult
312- call m_kern_min( k_off(index_wait ),kpart-1 ,ib_nd_acc_rem(index_wait) % values , ibind_rem(index_wait) % values , &
313- nbnab_rem(index_wait) % values ,ibpart_rem(:,index_wait), ibseq_rem(index_wait) % values , &
314- ibndimj_rem(index_wait) % values , atrans,b_rem(index_wait )% values,c,a_b_c% ahalo,a_b_c% chalo, &
307+ call m_kern_min( k_off(index_comp ),kpart-1 ,ib_nd_acc_rem, ibind_rem, &
308+ nbnab_rem,ibpart_rem, ibseq_rem, &
309+ ibndimj_rem, atrans,b_rem(index_comp )% values,c,a_b_c% ahalo,a_b_c% chalo, &
315310 a_b_c% ltrans,a_b_c% bmat(1 )% mx_abs,a_b_c% parts% mx_mem_grp, &
316- a_b_c% prim% mx_iprim, lena, lenb_rem(index_wait ), lenc)
311+ a_b_c% prim% mx_iprim, lena, lenb_rem(index_comp ), lenc)
317312 end if
318313 ! $omp barrier
319314 end do main_loop
@@ -586,7 +581,7 @@ subroutine do_comms(k_off, kpart, part_array, n_cont, ilen2, a_b_c, b, recv_part
586581 integer , intent (in ) :: kpart
587582 type (matrix_mult), intent (in ) :: a_b_c
588583 real (double), intent (in ) :: b(:)
589- integer , allocatable , dimension (:), intent (inout ) :: recv_part
584+ integer , dimension (:), intent (inout ) :: recv_part
590585 real (double), allocatable , intent (inout ) :: b_rem(:)
591586 integer , intent (out ) :: lenb_rem
592587 integer , intent (in ) :: myid, ncover_yz
@@ -712,23 +707,23 @@ subroutine prefetch(this_part,ahalo,a_b_c,bmat,&
712707 ind_part,b,myid)
713708 else ! Else fetch the data
714709 ilen2 = a_b_c% ilen2rec(ipart,nnode)
715- if (.not. do_nonb_local) then ! Use blocking receive
716- call Mquest_get( prim% mx_ngonn, &
710+ if (do_nonb_local) then ! Use non-blocking receive
711+ if (.not. present (request)) call cq_abort(' Need to provide MPI request argument for non-blocking receive.' )
712+ call Mquest_get_nonb( prim% mx_ngonn, &
717713 a_b_c% ilen2rec(ipart,nnode),&
718714 a_b_c% ilen3rec(ipart,nnode),&
719715 n_cont,inode,ipart,myid,&
720716 bind_rem,b_rem,lenb_rem,bind,&
721717 a_b_c% istart(ipart,nnode), &
722- bmat(1 )% mx_abs,parts% mx_mem_grp,tag)
723- else ! Use non-blocking receive
724- if (.not. present (request)) call cq_abort(' Need to provide MPI request argument for non-blocking receive.' )
725- call Mquest_get_nonb( prim% mx_ngonn, &
718+ bmat(1 )% mx_abs,parts% mx_mem_grp,tag,request)
719+ else ! Use blocking receive
720+ call Mquest_get( prim% mx_ngonn, &
726721 a_b_c% ilen2rec(ipart,nnode),&
727722 a_b_c% ilen3rec(ipart,nnode),&
728723 n_cont,inode,ipart,myid,&
729724 bind_rem,b_rem,lenb_rem,bind,&
730725 a_b_c% istart(ipart,nnode), &
731- bmat(1 )% mx_abs,parts% mx_mem_grp,tag,request )
726+ bmat(1 )% mx_abs,parts% mx_mem_grp,tag)
732727 end if
733728 end if
734729 return
0 commit comments