@@ -305,16 +305,25 @@ class InsertGPUAllocsPass final
305305 filter.insert (copy);
306306 }
307307
308- if (allocType != memrefType)
309- allocResult = builder.create <mlir::memref::CastOp>(loc, memrefType,
310- allocResult);
311-
312- op.replaceAllUsesExcept (allocResult, filter);
313- builder.setInsertionPoint (term);
314- if (access.hostRead && access.deviceWrite ) {
315- builder.create <mlir::memref::CopyOp>(loc, allocResult, op);
308+ if (allocType != memrefType) {
309+ mlir::Value castedAllocResult = builder.create <mlir::memref::CastOp>(
310+ loc, memrefType, allocResult);
311+
312+ op.replaceAllUsesExcept (castedAllocResult, filter);
313+ builder.setInsertionPoint (term);
314+ if (access.hostRead && access.deviceWrite ) {
315+ builder.create <mlir::memref::CopyOp>(loc, castedAllocResult, op);
316+ }
317+ builder.create <mlir::gpu::DeallocOp>(loc, std::nullopt ,
318+ castedAllocResult);
319+ } else {
320+ op.replaceAllUsesExcept (allocResult, filter);
321+ builder.setInsertionPoint (term);
322+ if (access.hostRead && access.deviceWrite ) {
323+ builder.create <mlir::memref::CopyOp>(loc, allocResult, op);
324+ }
325+ builder.create <mlir::gpu::DeallocOp>(loc, std::nullopt , allocResult);
316326 }
317- builder.create <mlir::gpu::DeallocOp>(loc, std::nullopt , allocResult);
318327 } else if (m_clientAPI == " vulkan" ) {
319328 auto gpuAlloc =
320329 builder.create <mlir::memref::AllocOp>(loc, allocType, dims);
@@ -325,14 +334,21 @@ class InsertGPUAllocsPass final
325334 filter.insert (copy);
326335 }
327336
328- if (allocType != memrefType)
329- allocResult = builder.create <mlir::memref::CastOp>(loc, memrefType,
330- allocResult);
337+ if (allocType != memrefType) {
338+ mlir::Value castedAllocResult = builder.create <mlir::memref::CastOp>(
339+ loc, memrefType, allocResult);
331340
332- op.replaceAllUsesExcept (allocResult, filter);
333- builder.setInsertionPoint (term);
334- if (access.hostRead && access.deviceWrite ) {
335- builder.create <mlir::memref::CopyOp>(loc, allocResult, op);
341+ op.replaceAllUsesExcept (castedAllocResult, filter);
342+ builder.setInsertionPoint (term);
343+ if (access.hostRead && access.deviceWrite ) {
344+ builder.create <mlir::memref::CopyOp>(loc, castedAllocResult, op);
345+ }
346+ } else {
347+ op.replaceAllUsesExcept (allocResult, filter);
348+ builder.setInsertionPoint (term);
349+ if (access.hostRead && access.deviceWrite ) {
350+ builder.create <mlir::memref::CopyOp>(loc, allocResult, op);
351+ }
336352 }
337353 }
338354 };
0 commit comments