
Large batchsize causes Windows Error 0xe06d7363 #136

Open
GUZZ07 opened this issue Oct 31, 2024 · 4 comments
@GUZZ07

GUZZ07 commented Oct 31, 2024

Describe the bug
I tried a simple net for MNIST classification. An error, [WinError -529697949] Windows Error 0xe06d7363, occurred when I switched from xpu to npu. I eventually found it was caused by a large batch size that works normally on xpu but causes the error on npu.

To Reproduce
Steps to reproduce the behavior:
Run the following Python snippet with the batchSize variable set to different values:

import torch
import torch.nn as nn
import intel_npu_acceleration_library

from torchsummary import summary

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Conv2d(1,10,5),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Conv2d(10, 20, 5),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Flatten(),
            nn.Linear(320, 10)
        )
    def forward(self, x: torch.Tensor):
        x = self.linear_relu_stack(x)
        x = nn.functional.log_softmax(x, dim=1)
        return x
        
model = LeNet().eval()
model = intel_npu_acceleration_library.compile(model, dtype=torch.float32, training=False)
batchSize = 320
x = torch.randn(batchSize, 1, 28, 28)
y = model(x)
print(torch.isnan(y).sum())
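
To observe all three behaviors in one run, the batch sizes reported below can be swept in a loop. This is a minimal sketch, assuming the compiled model from the snippet above is still in scope; the batch-size list is taken from the outputs reported below.

# Sketch: sweep the batch sizes reported in this issue; assumes `torch` and the
# compiled `model` from the snippet above are already defined.
for batchSize in (288, 320, 352, 384, 416, 512):
    x = torch.randn(batchSize, 1, 28, 28)
    try:
        y = model(x)
        print(batchSize, "->", torch.isnan(y).sum())
    except OSError as err:
        # batchSize = 512 reportedly raises [WinError -529697949] Windows Error 0xe06d7363
        print(batchSize, "->", err)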

Observed behavior

Behavior when batchSize = 512

Exception:

Exception has occurred: OSError
[WinError -529697949] Windows Error 0xe06d7363
  File "D:\tofouts\openvino\batchSizeErr.py", line 21, in forward
    x = self.linear_relu_stack(x)
  File "D:\tofouts\openvino\batchSizeErr.py", line 29, in <module>
    y = model(x)
OSError: [WinError -529697949] Windows Error 0xe06d7363

Console output:

(here is a huge amount of "memref<1x20x8x8xf16, @DDR>")...memref<1x20x8x8xf16, @DDR>, memref<1x20x8x8xf16, @DDR>, memref<1x20x8x8xf16, @DDR>, memref<1x20x8x8xf16, @DDR>) outputs(%alloc_18 : memref<512x20x8x8xf16, @DDR>) -> memref<512x20x8x8xf16, @DDR>
  %1543 = VPUIP.ShapeCast {shape = [1, 10240, 8, 8]} inputs(%1542 : memref<512x20x8x8xf16, @DDR>) -> memref<1x10240x8x8xf16, @DDR>
  %1544 = VPUIP.PermuteCast {dst_order = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, mem_perm = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>} inputs(%1543 : memref<1x10240x8x8xf16, @DDR>) -> memref<1x10240x8x8xf16, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, @DDR>
  %1545 = VPUIP.SubView %1544 [0, 5120, 0, 0] [1, 5120, 8, 8] : memref<1x10240x8x8xf16, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, @DDR> to memref<1x5120x8x8xf16, {order = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, strides = [655360, 1, 81920, 10240]}, @DDR>
  %1546 = VPUIP.NNDMA inputs(%1545 : memref<1x5120x8x8xf16, {order = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, strides = [655360, 1, 81920, 10240]}, @DDR>) outputs(%1539 : !VPUIP.DistributedBuffer<1x5120x8x8xf16, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, @CMX_NN, {mode = "OVERLAPPED", num_tiles = [1, 1, 2, 1], num_clusters = 2 : i64, uniform_distributed_segments, compute_shapes = [[1, 5120, 4, 8], [1, 5120, 4, 8]], compute_offsets = [[0, 0, 0, 0], [0, 0, 4, 0]], memory_shapes = [[1, 5120, 4, 8], [1, 5120, 4, 8]], memory_offsets = [[0, 0, 0, 0], [0, 0, 4, 0]]}>) -> !VPUIP.DistributedBuffer<1x5120x8x8xf16, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, @CMX_NN, {mode = "OVERLAPPED", num_tiles = [1, 1, 2, 1], num_clusters = 2 : i64, uniform_distributed_segments, compute_shapes = [[1, 5120, 4, 8], [1, 5120, 4, 8]], compute_offsets = [[0, 0, 0, 0], [0, 0, 4, 0]], memory_shapes = [[1, 5120, 4, 8], [1, 5120, 4, 8]], memory_offsets = [[0, 0, 0, 0], [0, 0, 4, 0]]}>
  async.yield %1546 : !VPUIP.DistributedBuffer<1x5120x8x8xf16, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, @CMX_NN, {mode = "OVERLAPPED", num_tiles = [1, 1, 2, 1], num_clusters = 2 : i64, uniform_distributed_segments, compute_shapes = [[1, 5120, 4, 8], [1, 5120, 4, 8]], compute_offsets = [[0, 0, 0, 0], [0, 0, 4, 0]], memory_shapes = [[1, 5120, 4, 8], [1, 5120, 4, 8]], memory_offsets = [[0, 0, 0, 0], [0, 0, 4, 0]]}>
},
[ERROR] 04:10:02.991 [vpux-compiler] Got Diagnostic at loc(fused<{name = "main", type = "Func"}>["main"]) : FeasibleAllocation Pass failed : Scheduler failure, cannot schedule anything and there is no buffer to spill
loc(fused<{name = "main", type = "Func"}>["main"]): error: FeasibleAllocation Pass failed : Scheduler failure, cannot schedule anything and there is no buffer to spill
[ERROR] 04:10:02.993 [vpux-compiler] Failed Pass FeasibleAllocation on Operation loc(fused<{name = "main", type = "Func"}>["main"])
[ERROR] 04:10:02.993 [vpux-compiler] Failed Pass mlir::detail::OpToOpPassAdaptor on Operation loc(fused<{name = "module", type = "Module"}>["module"])

Output when batchSize = 416:

[ERROR] 04:09:10.366 [add-enqueue-ops]     Enqueue barrier '%8470 = VPUMI40XX.ConfigureBarrier {consumer_count = 5 : ui8, producer_count = 39 : ui8}(%8468, %8465, %8464 : !VPURegMapped.Index<0:0:39>, !VPURegMapped.Index<0:0:36>, !VPURegMapped.Index<0:0:35>) <23, 85> -> !VPURegMapped.Index<0:0:41>' depends topologically on task to be enqueued itself which updates barrier '%8470 = VPUMI40XX.ConfigureBarrier {consumer_count = 5 : ui8, producer_count = 39 : ui8}(%8468, %8465, %8464 : !VPURegMapped.Index<0:0:39>, !VPURegMapped.Index<0:0:36>, !VPURegMapped.Index<0:0:35>) <23, 85> -> !VPURegMapped.Index<0:0:41>'
[ERROR] 04:09:13.056 [add-enqueue-ops]     Enqueue barrier '%7011 = VPUMI40XX.ConfigureBarrier {consumer_count = 5 : ui8, producer_count = 7 : ui8}(%7010, %7009, %7008, %7007 : !VPURegMapped.Index<0:0:88>, !VPURegMapped.Index<0:0:87>, !VPURegMapped.Index<0:0:86>, !VPURegMapped.Index<0:0:85>) <31, 123> -> !VPURegMapped.Index<0:0:89>' depends topologically on task to be enqueued itself which updates barrier '%7011 = VPUMI40XX.ConfigureBarrier {consumer_count = 5 : ui8, producer_count = 7 : ui8}(%7010, %7009, %7008, %7007 : !VPURegMapped.Index<0:0:88>, !VPURegMapped.Index<0:0:87>, !VPURegMapped.Index<0:0:86>, !VPURegMapped.Index<0:0:85>) <31, 123> -> !VPURegMapped.Index<0:0:89>'
tensor(0)

Output when batchSize = 288, 320, 352, or 384:

tensor(0)

System info:

  • OS: Windows 11 24H2 26100.2161
  • Processor: Intel(R) Core(TM) Ultra 7 258V
  • NPU Driver: 32.0.100.3053
  • pip packages: intel_npu_acceleration_library==1.3.0, torch==2.5.1

Additional context
The console prints the error messages on the first run with batchSize = 416 but does not throw an exception; on a second run with the same batchSize the error messages disappear.
I also tried rewriting the network without nn.Sequential: the error occurred when passing the input x to the second convolutional layer, right after the first pooling layer, and removing that pooling layer made both the exception and the error messages disappear (see the sketch below).
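
A minimal sketch of such an unrolled variant (reconstructed here for illustration, not the exact rewritten code), with the reported failure point marked; it assumes the torch/nn imports from the snippet above:

class LeNetUnrolled(nn.Module):
    # Same layers as the model above, but written out without nn.Sequential
    # so the failing step can be localized.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(320, 10)

    def forward(self, x: torch.Tensor):
        x = self.pool1(torch.relu(self.conv1(x)))
        # The exception reportedly surfaces here: passing x to the second
        # convolution right after the first pooling layer.
        x = self.pool2(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        return nn.functional.log_softmax(self.fc(x), dim=1)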

@vaikunth-coder27

I tried to execute the same script to reproduce the behavior, but on a similar system setup the output remained the same regardless of the batch size: "OSError: [WinError -529697949] Windows Error 0xe06d7363".

@alessandropalla
Contributor

alessandropalla commented Dec 14, 2024

Can you try the following code?

import torch
import torch.nn as nn
import intel_npu_acceleration_library

from torchsummary import summary

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Conv2d(1,10,5),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Conv2d(10, 20, 5),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Flatten(),
            nn.Linear(320, 10)
        )
    def forward(self, x: torch.Tensor):
        x = self.linear_relu_stack(x)
        x = nn.functional.log_softmax(x, dim=1)
        return x
        
model = LeNet().eval()
model = model.to('npu')
batchSize = 320
x = torch.randn(batchSize, 1, 28, 28)
y = model(x)
print(torch.isnan(y).sum())

@alessandropalla
Contributor

This model is worth compiling and running in graph mode, which is why I suggested the .to('npu') method: it compiles the model and runs it in one go, whereas intel_npu_acceleration_library.compile() effectively runs it in kernel mode.
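
In code terms, the two entry points discussed in this thread look like this (a minimal sketch, assuming the LeNet module defined earlier and the intel_npu_acceleration_library 1.3.0 API shown above):

import torch
import intel_npu_acceleration_library

model = LeNet().eval()  # LeNet as defined earlier in this thread

# Kernel mode: wraps the PyTorch module and offloads individual kernels to the NPU.
kernel_model = intel_npu_acceleration_library.compile(model, dtype=torch.float32, training=False)

# Graph mode: compiles the whole model and runs it on the NPU in one go.
graph_model = model.to('npu')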

@GUZZ07
Author

GUZZ07 commented Dec 15, 2024

Can you try the following code?

(snippet quoted from the previous comment)

Yes, this snippet now runs successfully and outputs tensor(0) with batchSize = 1024.
