Skip to content

Commit 0dcd68c

Browse files
authored
[flang][cuda] Set allocator index for module allocatable variable (llvm#106777)
Descriptor for module variable with cuda attribute must be set with the correct allocator index. This patch updates the embox operation used in the global to carry the allocator index.
1 parent 10affaf commit 0dcd68c

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

flang/lib/Lower/ConvertVariable.cpp

+18-2
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,20 @@ void Fortran::lower::createGlobalInitialization(
478478
builder.restoreInsertionPoint(insertPt);
479479
}
480480

481+
static unsigned getAllocatorIdx(cuf::DataAttributeAttr dataAttr) {
482+
if (dataAttr) {
483+
if (dataAttr.getValue() == cuf::DataAttribute::Pinned)
484+
return kPinnedAllocatorPos;
485+
if (dataAttr.getValue() == cuf::DataAttribute::Device)
486+
return kDeviceAllocatorPos;
487+
if (dataAttr.getValue() == cuf::DataAttribute::Managed)
488+
return kManagedAllocatorPos;
489+
if (dataAttr.getValue() == cuf::DataAttribute::Unified)
490+
return kUnifiedAllocatorPos;
491+
}
492+
return kDefaultAllocator;
493+
}
494+
481495
/// Create the global op and its init if it has one
482496
static fir::GlobalOp defineGlobal(Fortran::lower::AbstractConverter &converter,
483497
const Fortran::lower::pft::Variable &var,
@@ -540,8 +554,10 @@ static fir::GlobalOp defineGlobal(Fortran::lower::AbstractConverter &converter,
540554
// Create unallocated/disassociated descriptor if no explicit init
541555
Fortran::lower::createGlobalInitialization(
542556
builder, global, [&](fir::FirOpBuilder &b) {
543-
mlir::Value box =
544-
fir::factory::createUnallocatedBox(b, loc, symTy, std::nullopt);
557+
mlir::Value box = fir::factory::createUnallocatedBox(
558+
b, loc, symTy,
559+
/*nonDeferredParams=*/std::nullopt,
560+
/*typeSourceBox=*/{}, getAllocatorIdx(dataAttr));
545561
b.create<fir::HasValueOp>(loc, box);
546562
});
547563
}

flang/test/Lower/CUDA/cuda-allocatable.cuf

+15
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,21 @@
22

33
! Test lowering of CUDA allocatable allocate/deallocate statements.
44

5+
module globals
6+
real, device, allocatable :: a_device(:)
7+
real, managed, allocatable :: a_managed(:)
8+
real, pinned, allocatable :: a_pinned(:)
9+
end module
10+
11+
! CHECK-LABEL: fir.global @_QMglobalsEa_device {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xf32>>>
12+
! CHECK: %{{.*}} = fir.embox %{{.*}}(%{{.*}}) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
13+
14+
! CHECK-LABEL: fir.global @_QMglobalsEa_managed {data_attr = #cuf.cuda<managed>} : !fir.box<!fir.heap<!fir.array<?xf32>>>
15+
! CHECK: %{{.*}} = fir.embox %{{.*}}(%{{.*}}) {allocator_idx = 3 : i32} : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
16+
17+
! CHECK-LABEL: fir.global @_QMglobalsEa_pinned {data_attr = #cuf.cuda<pinned>} : !fir.box<!fir.heap<!fir.array<?xf32>>>
18+
! CHECK: %{{.*}} = fir.embox %{{.*}}(%{{.*}}) {allocator_idx = 1 : i32} : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
19+
520
subroutine sub1()
621
real, allocatable, device :: a(:)
722
allocate(a(10))

0 commit comments

Comments
 (0)