Test moe compile #224

Open · wants to merge 7 commits into base: main
mixtral-moe/README.md (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@
## Downloading Weights

```bash
export MODEL_REPO=mistralai/Mixtral-8x7B-v0.1
export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1
python scripts/download.py --repo_id $MODEL_REPO
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
```
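For reference, once the checkpoint is converted, the benchmark touched by this PR can be run roughly as follows. This is only a sketch: `--batch_size`, `--num_samples`, `--max_new_tokens`, and `--checkpoint_path` appear in `generate.py`'s argparser in this diff, `--compile` is the pre-existing gpt-fast flag, and the values are illustrative.

```bash
export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1
# --batch_size is the flag added by this PR; all values below are examples only
python generate.py --compile --batch_size 8 --num_samples 2 --max_new_tokens 200 \
  --checkpoint_path checkpoints/$MODEL_REPO/model.pth
```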
mixtral-moe/generate.py (143 changes: 111 additions & 32 deletions)
@@ -25,7 +25,7 @@ def device_sync(device):
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future

torch._dynamo.config.capture_scalar_outputs = True
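# capture_scalar_outputs lets Dynamo trace Tensor.item()/scalar outputs as unbacked symints instead of graph-breaking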

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
@@ -52,7 +52,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
probs = logits_to_probs(logits[0, -1], temperature, top_k)
probs = logits_to_probs(logits[:, -1], temperature, top_k)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs

@@ -74,11 +74,13 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
next_token, next_prob = decode_one_token(
model, cur_token, input_pos, **sampling_kwargs
)
next_token, next_prob = next_token.clone(), next_prob.clone()

input_pos += 1
new_tokens.append(next_token.clone())
callback(new_tokens[-1])
new_probs.append(next_prob.clone())
cur_token = next_token.view(1, -1)
cur_token = next_token

return new_tokens, new_probs

@@ -91,6 +93,7 @@ def generate(
model: Transformer,
prompt: torch.Tensor,
max_new_tokens: int,
batch_size: int,
*,
interactive: bool,
callback = lambda x: x,
@@ -99,32 +102,30 @@
"""
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
"""
device, dtype = prompt.device, prompt.dtype


T = prompt.size(-1)
max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350
new_tokens = max_seq_length - T

# duplicate prompt for batch_size
prompt = prompt.repeat(batch_size, 1)

# create an empty tensor of the expected final shape and fill in the current tokens
T = prompt.size(0)
T_new = T + max_new_tokens
if interactive:
max_seq_length = 350
else:
max_seq_length = min(T_new, model.config.block_size)
seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device)
seq[:, :T] = prompt

device, dtype = prompt.device, prompt.dtype
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)

# create an empty tensor of the expected final shape and fill in the current tokens
empty = torch.empty(T_new, dtype=dtype, device=device)
empty[:T] = prompt
seq = empty
input_pos = torch.arange(0, T, device=device)

next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
seq[T] = next_token
next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs)
seq[:, T] = next_token.squeeze()

input_pos = torch.tensor([T], device=device, dtype=torch.int)

generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
seq[T + 1:] = torch.cat(generated_tokens)
generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1)

return seq

@@ -144,8 +145,12 @@ def _load_model(checkpoint_path, device, precision, use_tp):
simple_quantizer = WeightOnlyBit8QuantHandler(model, torch.int8)
model = simple_quantizer.convert_for_runtime()

checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
try:
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
except:
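# if the checkpoint can't be loaded, fall back to a freshly initialized model from the config name (presumably so compile can still be exercised without real weights)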
model = Transformer.from_name(checkpoint_path.parent.name)


if use_tp:
from tp import apply_tp
@@ -162,6 +167,7 @@ def main(
interactive: bool = False,
num_samples: int = 5,
max_new_tokens: int = 100,
batch_size: int = 1,
top_k: int = 200,
temperature: float = 0.8,
checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
@@ -172,8 +178,7 @@
) -> None:
"""Generates text samples based on a pre-trained Transformer model and tokenizer.
"""
assert checkpoint_path.is_file(), checkpoint_path

# assert checkpoint_path.is_file(), checkpoint_path
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), str(tokenizer_path)

@@ -202,13 +207,81 @@ def main(

torch.manual_seed(1234)
model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])


import torchao
from torchao.quantization import quantize_, Int8WeightOnlyConfig


def filter(model, fqn):
return isinstance(model, torch.nn.Linear) and "gate" not in fqn
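# int8 weight-only quantization for the dense nn.Linear layers; the MoE router ("gate") projections are left unquantized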

quantize_(model, Int8WeightOnlyConfig(), filter_fn=filter)


from torchao.quantization.quant_primitives import MappingType
from torchao.dtypes import to_affine_quantized_intx

def moe_filter(module, fqn):
return "MOEFeedForwardAOQuantizable" in str(type(module))

def cond_ffn_filter(module, fqn):
return "ConditionalFeedForwardAOQuantizable" in str(type(module))
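# both filters match the AO-quantizable MoE wrapper modules by class name; cond_ffn_filter feeds the (currently commented-out) conversion below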

def quant_convert_fn(module, config):
def quant_tensor(weight):
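# symmetric int8 quantization, one scale per row: the quantization block spans the full last dimension only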
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
block_size = [1 for x in range(weight.dim())]
block_size[-1] = weight.shape[-1]
block_size = tuple(block_size)
new_weight = to_affine_quantized_intx(
weight,
mapping_type,
block_size,
target_dtype,
eps=eps,
zero_point_dtype=zero_point_dtype,
)
return new_weight
assert "ConditionalFeedForwardAOQuantizable" in str(type(module))
assert hasattr(module, "w1")
assert hasattr(module, "w2")
assert hasattr(module, "w3")

group_size = None if config.group_size is None else config.group_size
for weight_attr in ["w1", "w2", "w3"]:
param = getattr(module, weight_attr)
new_param = quant_tensor(param)
new_param = torch.nn.Parameter(new_param, requires_grad=False)
setattr(module, weight_attr, new_param)
del param
return module

from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

# _replace_with_custom_fn_if_matches_filter(
# model,
# quant_convert_fn,
# cond_ffn_filter,
# extra_args=(Int8WeightOnlyConfig(),)
# )



if compile:
torch._inductor.config.assert_indirect_indexing = False
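# assert_indirect_indexing=False drops the device-side bounds asserts Inductor emits for indirect (gather/scatter) indexing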
# torch._dynamo.config.capture_dynamic_output_shape_ops = True

global decode_one_token, prefill
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
if batch_size > 1: # the MoE code has a graph break on the multi-token path, so it can't be compiled with fullgraph=True
# decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead")
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead")
else:
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)

# Uncomment to squeeze more perf out of prefill
if args.compile_prefill:
prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

@@ -255,6 +328,7 @@ def callback(x):
model,
encoded,
max_new_tokens,
batch_size,
interactive=interactive,
callback=callback,
temperature=temperature,
@@ -272,16 +346,19 @@ def callback(x):
t = time.perf_counter() - t0

if not interactive:
print(tokenizer.decode(y.tolist()))
print(tokenizer.decode(y[0].tolist()))
else:
print()
tokens_generated = y.size(0) - prompt_length
tokens_generated = y.size(-1) - prompt_length
tokens_sec = tokens_generated / t
aggregate_metrics['tokens_per_sec'].append(tokens_sec)
print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")

print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
print(f"Average tokens/sec: {tokpersec:.2f}")
if batch_size > 1:
print(f"Average aggregate tokens/sec across the batch: {batch_size*tokpersec:.2f}")
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")


@@ -291,8 +368,10 @@ def callback(x):

parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
# parser.add_argument('--num_samples', type=int, default=1, help='Number of samples.')
parser.add_argument('--num_samples', type=int, default=2, help='Number of samples.')
parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
@@ -303,6 +382,6 @@ def callback(x):

args = parser.parse_args()
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device
)