Skip to content

Commit 0e45022

Browse files
authored
Add smoke test Using a simple RN50 with torch.compile (#7359)
1 parent 924d373 commit 0e45022

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

test/smoke_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55

66
import torch
7+
import torch.nn as nn
78
import torchvision
89
from torchvision.io import read_image
910
from torchvision.models import resnet50, ResNet50_Weights
@@ -26,6 +27,12 @@ def smoke_test_torchvision_read_decode() -> None:
2627
if img_png.ndim != 3 or img_png.numel() < 100:
2728
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
2829

30+
def smoke_test_compile() -> None:
31+
model = resnet50().cuda()
32+
model = torch.compile(model)
33+
x = torch.randn(1, 3, 224, 224, device="cuda")
34+
out = model(x)
35+
print(f"torch.compile model output: {out.shape}")
2936

3037
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
3138
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
@@ -54,14 +61,18 @@ def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
5461

5562
def main() -> None:
5663
print(f"torchvision: {torchvision.__version__}")
64+
print(f"torch.cuda.is_available: {torch.cuda.is_available()}")
5765
smoke_test_torchvision()
5866
smoke_test_torchvision_read_decode()
5967
smoke_test_torchvision_resnet50_classify()
6068
if torch.cuda.is_available():
6169
smoke_test_torchvision_resnet50_classify("cuda")
70+
smoke_test_compile()
71+
6272
if torch.backends.mps.is_available():
6373
smoke_test_torchvision_resnet50_classify("mps")
6474

6575

76+
6677
if __name__ == "__main__":
6778
main()

0 commit comments

Comments
 (0)