Skip to content

Commit 31a1ccd

Browse files
committed
Simplifying OpenACC data transfers around the call to mpas_reconstruct_2d
This commit introduces two OpenACC data transfer routines, mpas_reconstruct_2d_h2d and mpas_reconstruct_2d_d2h in order to remove the data transfers from the mpas_reconstruct_2d routine itself. This also allows us to remove extraneous data movements within the atm_srk3 routine. mpas_reconstruct_2d_h2d and mpas_reconstruct_2d_d2h are called before and after the call to mpas_reconstruct in atm_mpas_init_block. And the reconstructed vector fields are also copied to and from the device before and after every dynamics call in mpas_atm_pre_dynamics_h2d and mpas_atm_post_dynamics_d2h.
1 parent b30677d commit 31a1ccd

File tree

3 files changed

+148
-11
lines changed

3 files changed

+148
-11
lines changed

src/core_atmosphere/dynamics/mpas_atm_time_integration.F

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,7 @@ subroutine mpas_atm_pre_dynamics_h2d(domain)
846846

847847

848848
#ifdef MPAS_OPENACC
849+
type (mpas_pool_type), pointer :: mesh
849850
type (mpas_pool_type), pointer :: state
850851
type (mpas_pool_type), pointer :: diag
851852
type (mpas_pool_type), pointer :: tend
@@ -879,6 +880,10 @@ subroutine mpas_atm_pre_dynamics_h2d(domain)
879880
real (kind=RKIND), dimension(:,:,:), pointer :: scalars_1, scalars_2
880881
real (kind=RKIND), dimension(:,:), pointer :: ruAvg, wwAvg, ruAvg_split, wwAvg_split
881882

883+
integer, pointer :: nCells_ptr
884+
integer :: nCells
885+
real (kind=RKIND), dimension(:,:), pointer :: uReconstructZonal, uReconstructMeridional, uReconstructX, uReconstructY, uReconstructZ
886+
882887
real (kind=RKIND), dimension(:,:), pointer :: tend_ru, tend_rt, tend_rho, tend_rw, rt_diabatic_tend
883888
real (kind=RKIND), dimension(:,:), pointer :: tend_u_euler, tend_w_euler, tend_theta_euler
884889
real(kind=RKIND), dimension(:,:), pointer :: tend_w_pgf, tend_w_buoy
@@ -892,11 +897,13 @@ subroutine mpas_atm_pre_dynamics_h2d(domain)
892897

893898
real (kind=RKIND), dimension(:,:,:), pointer :: lbc_scalars, lbc_tend_scalars
894899

900+
nullify(mesh)
895901
nullify(state)
896902
nullify(diag)
897903
nullify(tend)
898904
nullify(tend_physics)
899905
nullify(lbc)
906+
call mpas_pool_get_subpool(domain % blocklist % structs, 'mesh', mesh)
900907
call mpas_pool_get_subpool(domain % blocklist % structs, 'state', state)
901908
call mpas_pool_get_subpool(domain % blocklist % structs, 'diag', diag)
902909
call mpas_pool_get_subpool(domain % blocklist % structs, 'tend', tend)
@@ -1006,6 +1013,19 @@ subroutine mpas_atm_pre_dynamics_h2d(domain)
10061013
call mpas_pool_get_array(diag, 'wwAvg_split', wwAvg_split)
10071014
!$acc enter data copyin(wwAvg_split)
10081015

1016+
call mpas_pool_get_dimension(mesh, 'nCellsSolve', nCells_ptr)
1017+
nCells = nCells_ptr
1018+
call mpas_pool_get_array(diag, 'uReconstructX', uReconstructX)
1019+
!$acc enter data create(uReconstructX(:,1:nCells))
1020+
call mpas_pool_get_array(diag, 'uReconstructY', uReconstructY)
1021+
!$acc enter data create(uReconstructY(:,1:nCells))
1022+
call mpas_pool_get_array(diag, 'uReconstructZ', uReconstructZ)
1023+
!$acc enter data create(uReconstructZ(:,1:nCells))
1024+
call mpas_pool_get_array(diag, 'uReconstructZonal', uReconstructZonal)
1025+
!$acc enter data create(uReconstructZonal(:,1:nCells))
1026+
call mpas_pool_get_array(diag, 'uReconstructMeridional', uReconstructMeridional)
1027+
!$acc enter data create(uReconstructMeridional(:,1:nCells))
1028+
10091029
call mpas_pool_get_array(state, 'u', u_1, 1)
10101030
!$acc enter data copyin(u_1)
10111031
call mpas_pool_get_array(state, 'u', u_2, 2)
@@ -1108,6 +1128,7 @@ subroutine mpas_atm_post_dynamics_d2h(domain)
11081128

11091129

11101130
#ifdef MPAS_OPENACC
1131+
type (mpas_pool_type), pointer :: mesh
11111132
type (mpas_pool_type), pointer :: state
11121133
type (mpas_pool_type), pointer :: diag
11131134
type (mpas_pool_type), pointer :: tend
@@ -1141,6 +1162,10 @@ subroutine mpas_atm_post_dynamics_d2h(domain)
11411162
real (kind=RKIND), dimension(:,:,:), pointer :: scalars_1, scalars_2
11421163
real (kind=RKIND), dimension(:,:), pointer :: ruAvg, wwAvg, ruAvg_split, wwAvg_split
11431164

1165+
integer, pointer :: nCells_ptr
1166+
integer :: nCells
1167+
real (kind=RKIND), dimension(:,:), pointer :: uReconstructZonal, uReconstructMeridional, uReconstructX, uReconstructY, uReconstructZ
1168+
11441169
real (kind=RKIND), dimension(:,:), pointer :: tend_ru, tend_rt, tend_rho, tend_rw, rt_diabatic_tend
11451170
real (kind=RKIND), dimension(:,:), pointer :: tend_u_euler, tend_w_euler, tend_theta_euler
11461171
real(kind=RKIND), dimension(:,:), pointer :: tend_w_pgf, tend_w_buoy
@@ -1154,11 +1179,13 @@ subroutine mpas_atm_post_dynamics_d2h(domain)
11541179

11551180
real (kind=RKIND), dimension(:,:,:), pointer :: lbc_scalars, lbc_tend_scalars
11561181

1182+
nullify(mesh)
11571183
nullify(state)
11581184
nullify(diag)
11591185
nullify(tend)
11601186
nullify(tend_physics)
11611187
nullify(lbc)
1188+
call mpas_pool_get_subpool(domain % blocklist % structs, 'mesh', mesh)
11621189
call mpas_pool_get_subpool(domain % blocklist % structs, 'state', state)
11631190
call mpas_pool_get_subpool(domain % blocklist % structs, 'diag', diag)
11641191
call mpas_pool_get_subpool(domain % blocklist % structs, 'tend', tend)
@@ -1268,6 +1295,19 @@ subroutine mpas_atm_post_dynamics_d2h(domain)
12681295
call mpas_pool_get_array(diag, 'wwAvg_split', wwAvg_split)
12691296
!$acc exit data copyout(wwAvg_split)
12701297

1298+
call mpas_pool_get_dimension(mesh, 'nCellsSolve', nCells_ptr)
1299+
nCells = nCells_ptr
1300+
call mpas_pool_get_array(diag, 'uReconstructX', uReconstructX)
1301+
!$acc exit data copyout(uReconstructX(:,1:nCells))
1302+
call mpas_pool_get_array(diag, 'uReconstructY', uReconstructY)
1303+
!$acc exit data copyout(uReconstructY(:,1:nCells))
1304+
call mpas_pool_get_array(diag, 'uReconstructZ', uReconstructZ)
1305+
!$acc exit data copyout(uReconstructZ(:,1:nCells))
1306+
call mpas_pool_get_array(diag, 'uReconstructZonal', uReconstructZonal)
1307+
!$acc exit data copyout(uReconstructZonal(:,1:nCells))
1308+
call mpas_pool_get_array(diag, 'uReconstructMeridional', uReconstructMeridional)
1309+
!$acc exit data copyout(uReconstructMeridional(:,1:nCells))
1310+
12711311
call mpas_pool_get_array(state, 'u', u_1, 1)
12721312
!$acc exit data copyout(u_1)
12731313
call mpas_pool_get_array(state, 'u', u_2, 2)

src/core_atmosphere/mpas_atm_core.F

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,13 +543,17 @@ subroutine atm_mpas_init_block(dminfo, stream_manager, block, mesh, dt)
543543
call mpas_pool_get_array(diag, 'uReconstructZ', uReconstructZ)
544544
call mpas_pool_get_array(diag, 'uReconstructZonal', uReconstructZonal)
545545
call mpas_pool_get_array(diag, 'uReconstructMeridional', uReconstructMeridional)
546+
call mpas_reconstruct_2d_h2d(mesh, u, uReconstructX, uReconstructY, uReconstructZ, &
547+
uReconstructZonal, uReconstructMeridional)
546548
call mpas_reconstruct(mesh, u, &
547549
uReconstructX, &
548550
uReconstructY, &
549551
uReconstructZ, &
550552
uReconstructZonal, &
551553
uReconstructMeridional &
552554
)
555+
call mpas_reconstruct_2d_d2h(mesh, u, uReconstructX, uReconstructY, uReconstructZ, &
556+
uReconstructZonal, uReconstructMeridional)
553557
554558
#ifdef DO_PHYSICS
555559
!proceed with initialization of physics parameterization if moist_physics is set to true:

src/operators/mpas_vector_reconstruction.F

Lines changed: 104 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -258,16 +258,6 @@ subroutine mpas_reconstruct_2d(meshPool, u, uReconstructX, uReconstructY, uRecon
258258

259259
call mpas_pool_get_config(meshPool, 'on_a_sphere', on_a_sphere)
260260

261-
MPAS_ACC_TIMER_START('mpas_reconstruct_2d [ACC_data_xfer]')
262-
! Only use sections needed, nCells may be all cells or only non-halo cells
263-
!$acc enter data copyin(coeffs_reconstruct(:,:,1:nCells),nEdgesOnCell(1:nCells), &
264-
!$acc edgesOnCell(:,1:nCells),latCell(1:nCells),lonCell(1:nCells))
265-
!$acc enter data copyin(u(:,:))
266-
!$acc enter data create(uReconstructX(:,1:nCells),uReconstructY(:,1:nCells), &
267-
!$acc uReconstructZ(:,1:nCells),uReconstructZonal(:,1:nCells), &
268-
!$acc uReconstructMeridional(:,1:nCells))
269-
MPAS_ACC_TIMER_STOP('mpas_reconstruct_2d [ACC_data_xfer]')
270-
271261
! loop over cell centers
272262
!$omp do schedule(runtime)
273263
!$acc parallel default(present)
@@ -337,6 +327,109 @@ subroutine mpas_reconstruct_2d(meshPool, u, uReconstructX, uReconstructY, uRecon
337327
!$omp end do
338328
end if
339329

330+
end subroutine mpas_reconstruct_2d!}}}
331+
332+
333+
subroutine mpas_reconstruct_2d_h2d(meshPool, u, uReconstructX, uReconstructY, uReconstructZ, uReconstructZonal, uReconstructMeridional, includeHalos)!{{{
334+
335+
implicit none
336+
337+
type (mpas_pool_type), intent(in) :: meshPool !< Input: Mesh information
338+
real (kind=RKIND), dimension(:,:), intent(in) :: u !< Input: Velocity field on edges
339+
real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructX !< Output: X Component of velocity reconstructed to cell centers
340+
real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructY !< Output: Y Component of velocity reconstructed to cell centers
341+
real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructZ !< Output: Z Component of velocity reconstructed to cell centers
342+
real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructZonal !< Output: Zonal Component of velocity reconstructed to cell centers
343+
real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructMeridional !< Output: Meridional Component of velocity reconstructed to cell centers
344+
logical, optional, intent(in) :: includeHalos !< Input: Optional logical that allows reconstruction over halo regions
345+
346+
logical :: includeHalosLocal
347+
integer, dimension(:,:), pointer :: edgesOnCell
348+
integer, dimension(:), pointer :: nEdgesOnCell
349+
integer :: nCells
350+
integer, pointer :: nCells_ptr
351+
real(kind=RKIND), dimension(:), pointer :: latCell, lonCell
352+
real (kind=RKIND), dimension(:,:,:), pointer :: coeffs_reconstruct
353+
354+
if ( present(includeHalos) ) then
355+
includeHalosLocal = includeHalos
356+
else
357+
includeHalosLocal = .false.
358+
end if
359+
360+
! stored arrays used during compute procedure
361+
call mpas_pool_get_array(meshPool, 'coeffs_reconstruct', coeffs_reconstruct)
362+
363+
! temporary variables
364+
call mpas_pool_get_array(meshPool, 'nEdgesOnCell', nEdgesOnCell)
365+
call mpas_pool_get_array(meshPool, 'edgesOnCell', edgesOnCell)
366+
call mpas_pool_get_array(meshPool, 'latCell', latCell)
367+
call mpas_pool_get_array(meshPool, 'lonCell', lonCell)
368+
369+
if ( includeHalosLocal ) then
370+
call mpas_pool_get_dimension(meshPool, 'nCells', nCells_ptr)
371+
else
372+
call mpas_pool_get_dimension(meshPool, 'nCellsSolve', nCells_ptr)
373+
end if
374+
nCells = nCells_ptr
375+
376+
MPAS_ACC_TIMER_START('mpas_reconstruct_2d [ACC_data_xfer]')
377+
! Only use sections needed, nCells may be all cells or only non-halo cells
378+
!$acc enter data copyin(coeffs_reconstruct(:,:,1:nCells),nEdgesOnCell(1:nCells), &
379+
!$acc edgesOnCell(:,1:nCells),latCell(1:nCells),lonCell(1:nCells))
380+
!$acc enter data copyin(u(:,:))
381+
!$acc enter data create(uReconstructX(:,1:nCells),uReconstructY(:,1:nCells), &
382+
!$acc uReconstructZ(:,1:nCells),uReconstructZonal(:,1:nCells), &
383+
!$acc uReconstructMeridional(:,1:nCells))
384+
MPAS_ACC_TIMER_STOP('mpas_reconstruct_2d [ACC_data_xfer]')
385+
386+
end subroutine mpas_reconstruct_2d_h2d
387+
388+
389+
390+
subroutine mpas_reconstruct_2d_d2h(meshPool, u, uReconstructX, uReconstructY, uReconstructZ, uReconstructZonal, uReconstructMeridional, includeHalos)!{{{
391+
392+
implicit none
393+
394+
type (mpas_pool_type), intent(in) :: meshPool !< Input: Mesh information
395+
real (kind=RKIND), dimension(:,:), intent(in) :: u !< Input: Velocity field on edges
396+
real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructX !< Output: X Component of velocity reconstructed to cell centers
397+
real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructY !< Output: Y Component of velocity reconstructed to cell centers
398+
real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructZ !< Output: Z Component of velocity reconstructed to cell centers
399+
real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructZonal !< Output: Zonal Component of velocity reconstructed to cell centers
400+
real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructMeridional !< Output: Meridional Component of velocity reconstructed to cell centers
401+
logical, optional, intent(in) :: includeHalos !< Input: Optional logical that allows reconstruction over halo regions
402+
403+
logical :: includeHalosLocal
404+
integer, dimension(:,:), pointer :: edgesOnCell
405+
integer, dimension(:), pointer :: nEdgesOnCell
406+
integer :: nCells
407+
integer, pointer :: nCells_ptr
408+
real(kind=RKIND), dimension(:), pointer :: latCell, lonCell
409+
real (kind=RKIND), dimension(:,:,:), pointer :: coeffs_reconstruct
410+
411+
if ( present(includeHalos) ) then
412+
includeHalosLocal = includeHalos
413+
else
414+
includeHalosLocal = .false.
415+
end if
416+
417+
! stored arrays used during compute procedure
418+
call mpas_pool_get_array(meshPool, 'coeffs_reconstruct', coeffs_reconstruct)
419+
420+
! temporary variables
421+
call mpas_pool_get_array(meshPool, 'nEdgesOnCell', nEdgesOnCell)
422+
call mpas_pool_get_array(meshPool, 'edgesOnCell', edgesOnCell)
423+
call mpas_pool_get_array(meshPool, 'latCell', latCell)
424+
call mpas_pool_get_array(meshPool, 'lonCell', lonCell)
425+
426+
if ( includeHalosLocal ) then
427+
call mpas_pool_get_dimension(meshPool, 'nCells', nCells_ptr)
428+
else
429+
call mpas_pool_get_dimension(meshPool, 'nCellsSolve', nCells_ptr)
430+
end if
431+
nCells = nCells_ptr
432+
340433
MPAS_ACC_TIMER_START('mpas_reconstruct_2d [ACC_data_xfer]')
341434
!$acc exit data delete(coeffs_reconstruct(:,:,1:nCells),nEdgesOnCell(1:nCells), &
342435
!$acc edgesOnCell(:,1:nCells),latCell(1:nCells),lonCell(1:nCells))
@@ -346,7 +439,7 @@ subroutine mpas_reconstruct_2d(meshPool, u, uReconstructX, uReconstructY, uRecon
346439
!$acc uReconstructMeridional(:,1:nCells))
347440
MPAS_ACC_TIMER_STOP('mpas_reconstruct_2d [ACC_data_xfer]')
348441

349-
end subroutine mpas_reconstruct_2d!}}}
442+
end subroutine mpas_reconstruct_2d_d2h
350443

351444

352445
!***********************************************************************

0 commit comments

Comments
 (0)