32
32
class FeedForward (nn .Module ):
33
33
"""MLP based model"""
34
34
35
- def __init__ (self ):
35
+ def __init__ (self , size ):
36
36
super (FeedForward , self ).__init__ ()
37
- self .w1 = nn .Linear (16 , 32 , bias = False )
38
- self .w2 = nn .Linear (16 , 32 , bias = False )
39
- self .out_proj = nn .Linear (32 , 16 , bias = False )
37
+ self .w1 = nn .Linear (size , size * 2 , bias = False )
38
+ self .w2 = nn .Linear (size , size * 2 , bias = False )
39
+ self .out_proj = nn .Linear (size * 2 , size , bias = False )
40
40
41
41
def forward (self , x ):
42
42
x = F .silu (self .w1 (x )) * self .w2 (x )
@@ -45,9 +45,9 @@ def forward(self, x):
45
45
46
46
47
47
class ToyModel(nn.Module):
    """Thin wrapper around a single FeedForward block."""

    def __init__(self, size):
        """Create the model.

        Args:
            size: feature dimension forwarded to ``FeedForward``.
        """
        # Python 3 zero-argument super() replaces the legacy
        # super(ToyModel, self) spelling; behavior is identical.
        super().__init__()
        self.ffn = FeedForward(size)

    def forward(self, x):
        """Run ``x`` through the feed-forward block and return the result."""
        return self.ffn(x)
@@ -56,7 +56,7 @@ def forward(self, x):
56
56
def _test_lowp_mlp_tensor_parallelism_base (
57
57
mesh : DeviceMesh ,
58
58
config : Union [Float8LinearConfig , MXLinearConfig ],
59
- size = 16 ,
59
+ size = 32 ,
60
60
compile : bool = False ,
61
61
allgather_in_lowp : bool = False ,
62
62
):
@@ -67,7 +67,7 @@ def _test_lowp_mlp_tensor_parallelism_base(
67
67
if isinstance (config , MXLinearConfig ):
68
68
convert_model_func = quantize_
69
69
70
- toy_model = ToyModel ().to (device )
70
+ toy_model = ToyModel (size ).to (device )
71
71
toy_model_fp8 = copy .deepcopy (toy_model )
72
72
convert_model_func (toy_model_fp8 , config = config )
73
73
0 commit comments