4
4
from pathlib import Path
5
5
6
6
import torch
7
+ import torch .nn as nn
7
8
import torchvision
8
9
from torchvision .io import read_image
9
10
from torchvision .models import resnet50 , ResNet50_Weights
@@ -26,6 +27,12 @@ def smoke_test_torchvision_read_decode() -> None:
26
27
if img_png .ndim != 3 or img_png .numel () < 100 :
27
28
raise RuntimeError (f"Unexpected shape of img_png: { img_png .shape } " )
28
29
30
+ def smoke_test_compile () -> None :
31
+ model = resnet50 ().cuda ()
32
+ model = torch .compile (model )
33
+ x = torch .randn (1 , 3 , 224 , 224 , device = "cuda" )
34
+ out = model (x )
35
+ print (f"torch.compile model output: { out .shape } " )
29
36
30
37
def smoke_test_torchvision_resnet50_classify (device : str = "cpu" ) -> None :
31
38
img = read_image (str (SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg" )).to (device )
@@ -54,14 +61,18 @@ def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
54
61
55
62
def main () -> None :
56
63
print (f"torchvision: { torchvision .__version__ } " )
64
+ print (f"torch.cuda.is_available: { torch .cuda .is_available ()} " )
57
65
smoke_test_torchvision ()
58
66
smoke_test_torchvision_read_decode ()
59
67
smoke_test_torchvision_resnet50_classify ()
60
68
if torch .cuda .is_available ():
61
69
smoke_test_torchvision_resnet50_classify ("cuda" )
70
+ smoke_test_compile ()
71
+
62
72
if torch .backends .mps .is_available ():
63
73
smoke_test_torchvision_resnet50_classify ("mps" )
64
74
65
75
76
+
66
77
if __name__ == "__main__" :
67
78
main ()
0 commit comments