float8 rowwise vanilla TP low throughput #1207

Open
@danielvegamyhre

Bug description

Llama 3 8B on 4x H100 with per-op selective activation checkpointing (SAC), using FSDP=2, TP=2:

  • bf16: 5378 TPS (tokens/sec), 45.68 GiB peak memory
  • float8 rowwise: 5189 TPS (tokens/sec), 45.67 GiB peak memory
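For scale, the relative difference between the two runs can be computed directly from the numbers above. A minimal Python sketch; the helper `relative_change` is hypothetical and only does the arithmetic on the reported values:

```python
def relative_change(baseline: float, candidate: float) -> float:
    """Return (candidate - baseline) / baseline as a fraction."""
    return (candidate - baseline) / baseline


# Throughput (tokens/sec) and peak memory (GiB), copied from the runs above.
bf16 = {"tps": 5378, "peak_mem_gib": 45.68}
fp8_rowwise = {"tps": 5189, "peak_mem_gib": 45.67}

tps_delta = relative_change(bf16["tps"], fp8_rowwise["tps"])
mem_delta = relative_change(bf16["peak_mem_gib"], fp8_rowwise["peak_mem_gib"])

print(f"throughput: {tps_delta:+.2%}")   # throughput: -3.51%
print(f"peak memory: {mem_delta:+.2%}")  # peak memory: -0.02%
```

So float8 rowwise is roughly 3.5% slower than bf16 in this setup, with essentially unchanged peak memory.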

Versions

  • torch 2.8.0a0+gite21ad6e
  • torchtitan @ HEAD
  • torchao 0.11.0
