R1 woq #2148

Draft · wants to merge 25 commits into base: dev/ds_r1

6 changes: 3 additions & 3 deletions neural_compressor/common/utils/__init__.py
@@ -28,16 +28,16 @@
VLLM_TP_SIZE = int(os.getenv("VLLM_TP_SIZE", "8"))
VLLM_EP_SIZE = int(os.getenv("VLLM_EP_SIZE", VLLM_TP_SIZE))
NUM_EXPERTS_PER_EP_RANK = DEEPSEEK_EXPERTS // VLLM_EP_SIZE # 32
NUM_EXPERTS_GROUPS = 8
NUM_EXPERTS_PER_GROUP_PER_RANK = NUM_EXPERTS_PER_EP_RANK // NUM_EXPERTS_GROUPS # 4
VLLM_MOE_N_SLICE = int(os.getenv("VLLM_MOE_N_SLICE", 8))
NUM_EXPERTS_PER_GROUP_PER_RANK = NUM_EXPERTS_PER_EP_RANK // VLLM_MOE_N_SLICE # 4
FUSED_MOE_EXPERTS = NUM_EXPERTS_PER_GROUP_PER_RANK # 4

logger.warning_once(
(
f"INC uses VLLM_TP_SIZE={VLLM_TP_SIZE},\n"
f"VLLM_EP_SIZE={VLLM_EP_SIZE},\n"
f"NUM_EXPERTS_PER_EP_RANK={NUM_EXPERTS_PER_EP_RANK},\n"
f"NUM_EXPERTS_GROUPS={NUM_EXPERTS_GROUPS},\n"
f"VLLM_MOE_N_SLICE={VLLM_MOE_N_SLICE},\n"
f"NUM_EXPERTS_PER_GROUP_PER_RANK={NUM_EXPERTS_PER_GROUP_PER_RANK},\n"
f"FUSED_MOE_EXPERTS={FUSED_MOE_EXPERTS}"
)
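For orientation, here is a minimal, self-contained sketch of how these environment-driven constants compose. The `DEEPSEEK_EXPERTS = 256` value is an assumption for illustration; the real constant is defined elsewhere in this module.

```python
import os

DEEPSEEK_EXPERTS = 256  # assumption for illustration; defined elsewhere in the module
VLLM_TP_SIZE = int(os.getenv("VLLM_TP_SIZE", "8"))
VLLM_EP_SIZE = int(os.getenv("VLLM_EP_SIZE", str(VLLM_TP_SIZE)))
VLLM_MOE_N_SLICE = int(os.getenv("VLLM_MOE_N_SLICE", "8"))

NUM_EXPERTS_PER_EP_RANK = DEEPSEEK_EXPERTS // VLLM_EP_SIZE                     # 256 // 8 = 32
NUM_EXPERTS_PER_GROUP_PER_RANK = NUM_EXPERTS_PER_EP_RANK // VLLM_MOE_N_SLICE   # 32 // 8 = 4
FUSED_MOE_EXPERTS = NUM_EXPERTS_PER_GROUP_PER_RANK                             # experts handled per fused MoE call
```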
22 changes: 16 additions & 6 deletions neural_compressor/torch/algorithms/fp8_quant/_core/common.py
@@ -42,6 +42,14 @@

INFO_INTERVAL = 30 # seconds

def maybe_dequant_original_fp8_weight(mod: torch.nn.Module, param: torch.Tensor):
if param.dtype in [torch.float8_e4m3fn]:
if hasattr(mod, "get_dequant_weights_func"):
dequant_weights_func = mod.get_dequant_weights_func()
if dequant_weights_func is not None:
param = dequant_weights_func(mod)
return param
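The helper above only dequantizes when the owning module opts in through `get_dequant_weights_func`. A hypothetical module satisfying that contract might look like this; the class name and scale layout are illustrative, not taken from the PR.

```python
import torch

class FP8WeightHolder(torch.nn.Module):  # hypothetical name, for illustration only
    def __init__(self, weight_fp8: torch.Tensor, scale_inv: torch.Tensor):
        super().__init__()
        # Assumed storage: an FP8 (float8_e4m3fn) weight plus its inverse scale.
        self.weight = torch.nn.Parameter(weight_fp8, requires_grad=False)
        self.scale_inv = scale_inv

    def get_dequant_weights_func(self):
        # Return None to signal "nothing to dequantize"; otherwise a callable(mod) -> Tensor.
        def _dequant(mod):
            return mod.weight.to(torch.bfloat16) * mod.scale_inv
        return _dequant
```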

_mod_types = {
"linear": ModuleType(1, ["weight"], 1, False),
"matmul": ModuleType(2, [], 1, False),
@@ -222,15 +230,17 @@ def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype, dev
"Softmax": ModuleInfo("softmax", PatchedSoftmax),
"ModuleFusedSDPA": ModuleInfo("fused_sdpa", PatchedModuleFusedSDPA),
"MoeMatmul": ModuleInfo("linear", PatchedMoeMatmul),
"MoeFP8Matmul": ModuleInfo("linear", PatchedMoeFP8Matmul),
"ReplicatedLinear": ModuleInfo("linear", PatchedReplicatedLinear),
"VllmMixtureOfExpertsOpFP8": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpFP8),
# FIXME (Yi) revert change
"FusedMoE": ModuleInfo("linear", PatchedMixtralMoE, False),
"GaudiMixtralSparseMoeBlock": ModuleInfo("dynamic_moe", PatchedGaudiMixtralSparseMoeBlock),
"VllmMixtureOfExpertsOp": (
ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV2)
if os.getenv("LOW_CPU_MEM", "0") == "1"
else ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV1)
),
# "GaudiMixtralSparseMoeBlock": ModuleInfo("dynamic_moe", PatchedGaudiMixtralSparseMoeBlock),
# "VllmMixtureOfExpertsOp": (
# ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV2)
# if os.getenv("LOW_CPU_MEM", "0") == "1"
# else ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV1)
# ),
}


@@ -22,7 +22,7 @@
import time
from .._quant_common.quant_config import MeasureExclude, QuantMode, ScaleMethod, get_hqt_config, set_hqt_config
# from ..utils.logger import logger
from neural_compressor.torch.algorithms.fp8_quant._core.common import INFO_INTERVAL
from neural_compressor.torch.algorithms.fp8_quant._core.common import INFO_INTERVAL, maybe_dequant_original_fp8_weight
from .common import *
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
from neural_compressor.torch.algorithms.fp8_quant.model_configs import (
@@ -149,6 +149,10 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
logger.info(f"Patching measure module {name} {mod.__class__}")
num_info += 1
set_hqt_config(mod, top_level_config) # set config in the module, as it consumed by the patched module
if mod_type == "dynamic_moe" and hasattr(mod, "num_experts"):
# override default number of outputs for dynamic moe
mod_types[mod_type].num_outputs = mod.num_experts+1
logger.warning(f"Dynamic moe num_outputs set to {mod.num_experts+1}")
mod_extra_config = (
init_measure_object(
mod,
@@ -167,7 +171,10 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
# logger.info(f"Pacthed module pmod: {pmod}")
if pmod._mod_extra_config:
for param_name in pmod._mod_extra_config.params:
# if torch.distributed.get_rank() == 0:
# import pdb; pdb.set_trace()
param = getattr(pmod, param_name)
param = maybe_dequant_original_fp8_weight(pmod.orig_mod, param)
if config["measure_on_hpu"]:
param = param.to(cur_accelerator.name())
pmod._mod_extra_config.params[param_name].measure(param)
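Condensed sketch of the per-parameter measurement path added above, assuming a patched module that exposes `_mod_extra_config.params` (a dict of parameter name to observer with a `measure()` method) and keeps its original module as `orig_mod`.

```python
from neural_compressor.torch.algorithms.fp8_quant._core.common import maybe_dequant_original_fp8_weight

def measure_module_params(pmod, device: str, measure_on_hpu: bool = True):
    for param_name, observer in pmod._mod_extra_config.params.items():
        param = getattr(pmod, param_name)
        # Dequantize first if the checkpoint already stores this weight in FP8.
        param = maybe_dequant_original_fp8_weight(pmod.orig_mod, param)
        if measure_on_hpu:
            param = param.to(device)
        observer.measure(param)
```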
27 changes: 21 additions & 6 deletions neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py
@@ -33,7 +33,7 @@
import time
cur_accelerator = auto_detect_accelerator()

from neural_compressor.torch.algorithms.fp8_quant._core.common import INFO_INTERVAL
from neural_compressor.torch.algorithms.fp8_quant._core.common import INFO_INTERVAL, maybe_dequant_original_fp8_weight


@torch.no_grad()
@@ -78,9 +78,13 @@ def quantize_params(mod, mod_extra_config):
param = getattr(mod, param_name)
if param.dtype == torch.float16:
param = param.to(torch.bfloat16)
logger.debug(f"Quantizing parameter {param_name} of module {mod.__class__.__name__}")
param = maybe_dequant_original_fp8_weight(mod, param)
quantized_param = quantizer(param.to(cur_accelerator.name()))
delattr(mod, param_name)
setattr(mod, param_name, nn.Parameter(quantized_param))
# Note: in case of re-quantize the fp8 weights, we need to set `updated_fp8_weight` to True
mod.updated_fp8_weight = True
quantized_param = getattr(mod, param_name)
quantized_param.requires_grad_(False)
cur_accelerator.synchronize()
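A hedged, standalone restatement of the re-quantization step above for a single parameter, assuming `quantizer` maps a high-precision tensor to the target FP8 representation.

```python
import torch
import torch.nn as nn
from neural_compressor.torch.algorithms.fp8_quant._core.common import maybe_dequant_original_fp8_weight

def requantize_param(mod: nn.Module, param_name: str, quantizer, device: str):
    param = getattr(mod, param_name)
    if param.dtype == torch.float16:
        param = param.to(torch.bfloat16)                   # fp16 is promoted before quantization
    param = maybe_dequant_original_fp8_weight(mod, param)  # undo pre-existing FP8 packing, if any
    quantized = quantizer(param.to(device))
    delattr(mod, param_name)
    setattr(mod, param_name, nn.Parameter(quantized, requires_grad=False))
    mod.updated_fp8_weight = True                          # flag that the fp8 weight was re-quantized
```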
@@ -165,27 +169,38 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name,
scale_config, save_file)
if not config.cfg["fake_quant"] and mod_default_dict[mod_type_str].should_measure_and_quant:
quantize_params(mod, mod_extra_config)
logger.debug(f"patching module {name}")
# logger.debug(f"patching module {name}")
patch_module(mod, mod_extra_config, mod_default_dict)
name = origin_name
patched_modules.append(name)
patched_module_types.add(type(mod))
htcore.mark_step()
logger.debug("Patched module name: %s", name)
cur_accelerator.synchronize()
if save_file: # cache calculated scales
save_scales(model, scales_obj, scales_file_format, scale_file + ".npz")
save_scales(model, scales_obj, scales_file_format, scale_file + ".json")
logger.debug("Patched module types: %s", patched_module_types)
logger.debug("Patched modules: %s", patched_modules)
logger.debug("Total patched modules: %d", len(patched_modules))

show_mem_info("before move all")
model = model.to(cur_accelerator.name())
for _, mod in model.named_modules():
if hasattr(mod, "post_process"):
mod.post_process()
torch.distributed.barrier()
show_mem_info("after move all")
postporcess_after_convert_(model)
show_mem_info("after post process")
convert_fp16_to_bf16(model)
show_mem_info("after convert_fp16_to_bf16")
cur_accelerator.synchronize()
show_mem_info("after synchronize")
torch.distributed.barrier()

def postporcess_after_convert_(model):
for _, mod in model.named_modules():
if hasattr(mod, "post_process"):
mod.post_process()
# Note: It is very important to synchronize after each post_process to avoid OoM.
cur_accelerator.synchronize()

def prepare_model_with_dummy_measurement(model, mod_list, scaling_method_name, scale_config):
"""Aim for loading, replace module with patched module for model on meta device.
@@ -20,7 +20,8 @@
from .scales_method import QuantTensorType
from ..quant_dequant import DequantOutput, QuantDequant, QuantDequantNone, QuantInput
from neural_compressor.common import utils as inc_utils

# from neural_compressor.torch.algorithms.fp8_quant.utils import
from neural_compressor.torch.algorithms.fp8_quant._core.common import maybe_dequant_original_fp8_weight
class BaseOpQuantizer:

def __init__(self, config, mod, measurement, params, op_type):
@@ -94,9 +95,11 @@ def get_scales_module_config(self):
input_scales = self.calc_input_scales(num_of_inputs=1)
output_measurement = self.measurement.outputs[0] if self.measurement is not None else []
rescaled_weight = self.mod.weight if hasattr(self.mod, 'weight') else None
if rescaled_weight is not None:
rescaled_weight = maybe_dequant_original_fp8_weight(self.mod, rescaled_weight)
if self.weight_ich_scale_calc is not None:
weight_scales_in_ch = self.weight_ich_scale_calc.calc_scales(input_scales[0], QuantTensorType.CONST)
rescaled_weight = torch.div(self.mod.weight, weight_scales_in_ch.reshape([1, -1]))
rescaled_weight = torch.div(rescaled_weight, weight_scales_in_ch.reshape([1, -1]))
weights_scales_out_ch = self.weight_och_scale_calc.calc_scales(rescaled_weight, QuantTensorType.CONST)
params_config = (
{"weight": weights_scales_out_ch}
@@ -418,7 +418,8 @@ def forward_quant(self, input):
def forward_measure(self, input):
resolved_input = self.resolve_input(input)
measure_input((resolved_input,), observer=self._mod_extra_config.inputs)
output = torch.matmul(resolved_input, self.weight.transpose(-1, -2))
# output = torch.matmul(resolved_input, self.weight.transpose(-1, -2))
output = self.orig_mod.quant_method.apply(self.orig_mod, resolved_input)
measure_output((output,), self._mod_extra_config.outputs)
if self.reduce_results:
output = self.collective_func(output)
@@ -474,11 +475,20 @@ def forward_quant(self, input):

def forward_measure(self, input):
measure_input((input,), observer=self._mod_extra_config.inputs)
output = torch.matmul(input, self.weight.transpose(-1, -2))
output = self.orig_mod.quant_method.apply(self.orig_mod, input)
measure_output((output,), self._mod_extra_config.outputs)
output, output_bias = self.add_bias(output)
if self.gather_output:
output = self.collective_func(output)
return self.post_all_reduce(output)
return output, output_bias

def add_bias(self, output):
if not self.skip_bias_add:
output = output + self.bias if self.bias is not None else output
output_bias = None
else:
output_bias = self.bias
return output, output_bias

def post_all_reduce(self, output):
if not self.skip_bias_add:
@@ -632,7 +642,16 @@ def extra_repr(self) -> str:
get_current_repr(self, "scale_input", "scale_weight"),
)


class PatchedMoeFP8Matmul(PatchedMoeMatmul):
def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
# if torch.distributed.get_rank() == 0:
# import pdb; pdb.set_trace()
# torch.distributed.barrier()
# self.block_size = self.orig_mod.block_size
# self.scale_inv_fp8 = self.orig_mod.scale_inv_fp8
self.get_dequant_weight = self.orig_mod.get_dequant_weight

class PatchedGaudiMixtralSparseMoeBlock(PatchedModuleBase):
def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
@@ -724,8 +743,8 @@ def extra_repr(self) -> str:
class PatchedVllmMixtureOfExpertsOpV1(PatchedModuleBase):
def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
self.experts_min = self.orig_mod.experts_min
self.experts_max = self.orig_mod.experts_max
self.experts_min = self.orig_mod.experts_min if hasattr(self.orig_mod, "experts_min") else 0
self.experts_max = self.orig_mod.experts_max if hasattr(self.orig_mod, "experts_max") else 7
if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
self.forward = self.forward_quant
self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE_FUSED_WEIGHTS, self.scale_format)
@@ -737,11 +756,18 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
[mod_extra_config.scale.inputs[x] for x in range(1, self.num_experts+1)],
self.scale_format,
)
for i in range(self.num_experts):
self.w13_list[i].weight = self.w13_list[i].weight.squeeze().t().contiguous()
self.w2_list[i].weight = self.w2_list[i].weight.squeeze().t().contiguous()
# if torch.distributed.get_rank() == 0:
# import pdb; pdb.set_trace()
# torch.distributed.barrier()
self._post_init_for_quant()

elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
self.forward = self.forward_measure

def _post_init_for_quant(self):
for i in range(self.num_experts):
self.w13_list[i].weight = self.w13_list[i].weight.squeeze().t().contiguous()
self.w2_list[i].weight = self.w2_list[i].weight.squeeze().t().contiguous()

def forward_quant(self,
hidden_states,
@@ -813,6 +839,100 @@ def extra_repr(self) -> str:
f"quant_mode:{quant_mode}, {get_current_repr(self, *member_names)}",
)

class PatchedVllmMixtureOfExpertsOpFP8(PatchedVllmMixtureOfExpertsOpV1):
def _post_init_for_quant(self):
pass

def post_process(self):
# if torch.distributed.get_rank() == 0:
# import pdb; pdb.set_trace()
# torch.distributed.barrier()
for i in range(self.num_experts):
self.w13_list[i].weight = torch.nn.Parameter(self.w13_list[i].weight.squeeze().t().contiguous())
self.w2_list[i].weight = torch.nn.Parameter(self.w2_list[i].weight.squeeze().t().contiguous())
htcore.mark_step()
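The squeeze/transpose above finalizes the per-expert weight layout only after quantization, instead of at init time as in the V1 class. A small shape check under assumed dimensions (7168 and 2048 are illustrative, not taken from the PR):

```python
import torch

w = torch.randn(1, 7168, 2048)          # assumed stored layout: [1, in_features, out_features]
w_fixed = w.squeeze().t().contiguous()  # -> [out_features, in_features], contiguous for the fused op
assert w_fixed.shape == (2048, 7168) and w_fixed.is_contiguous()
```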

def forward_measure(
self,
x,
topk_ids,
topk_weights,
moe_n_slice=None,
n_expert_slice=None,
ep_shift=None,
):
hidden_states = x
measure_input((hidden_states,), observer=self._mod_extra_config.inputs)
# FIXME: (Yi) Assume moe_n_slice is 1, remove it?
# assert moe_n_slice == 1, f"moe_n_slice is {moe_n_slice}, expected 1"
min_expert = self.experts_min
max_expert = self.experts_max
w13_list_slice = []
w2_list_slice = []
for j in range(self.num_experts):
w13_list_slice.append(self.w13_list[j].get_dequant_weight())
w2_list_slice.append(self.w2_list[j].get_dequant_weight())

output, intermidiate_amax = torch.ops.hpu.mixture_of_experts.fp8_measurement_fused_weights(
hidden_states=x,
expert_routing_table=topk_ids.to(torch.int64),
router_weights=topk_weights.to(x.dtype),
w12=w13_list_slice,
w3=w2_list_slice,
permuted_weights=True,
activation="silu",
experts_min=min_expert,
experts_max=max_expert,
measurement_mode=True, # <=============
)
output_measure_list = [output]
# if torch.distributed.get_rank() == 0:
# import pdb; pdb.set_trace()
# torch.distributed.barrier()
for i in range(self.num_experts):
output_measure_list.append(intermidiate_amax[i])
measure_output(output_measure_list, self._mod_extra_config.outputs)
return output
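The measured outputs are packed as the routed result followed by one intermediate amax per expert, which is why the measurement patching earlier overrides `num_outputs` to `num_experts + 1`. A hedged restatement of that packing:

```python
def pack_measure_outputs(output, intermediate_amax, num_experts):
    # One final routed output plus num_experts per-expert amax tensors.
    outputs = [output]
    outputs.extend(intermediate_amax[i] for i in range(num_experts))
    return outputs  # len(outputs) == num_experts + 1
```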

def forward_quant(
self,
x,
topk_ids,
topk_weights,
moe_n_slice=None,
n_expert_slice=None,
ep_shift=None,
):
hidden_states = x
expert_routing_table = topk_ids.to(torch.int64)
router_weights = topk_weights.to(x.dtype)
permuted_weights = True
activation = "silu"
# if torch.distributed.get_rank() == 0:
# import pdb; pdb.set_trace()
# torch.distributed.barrier()
experts_range = range(self.num_experts)
w1_list = [self.w13_list[i].weight for i in experts_range]
w2_list = [self.w2_list[i].weight for i in experts_range]
scale_w1 = [self.w13_list[i].scale_weight for i in experts_range]
scale_w2 = [self.w2_list[i].scale_weight for i in experts_range]
qinput = self.quant_input(hidden_states)
output = self.dynamic_moe_op(
hidden_states=qinput,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
d_scale_w12=scale_w1,
d_scale_w3=scale_w2,
d_scale_hidden_states=self.scale_input,
d_scale_intermediate_hidden_states=self.scale_intermediate,
permuted_weights=permuted_weights,
activation=activation,
experts_min=self.experts_min,
experts_max=self.experts_max,
)
return output

class PatchedVllmMixtureOfExpertsOpV2(PatchedVllmMixtureOfExpertsOpV1):
def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
@@ -954,41 +1074,31 @@ def forward_qdq(self, input, *args, **kwargs):
output_cache = self.orig_mod(qinput, *args, **kwargs)
return output_cache

# def forward_quant(self, input, *args, **kwargs):
# qinput = self.quant_input(input)
# output_cache = self.orig_mod(qinput, *args, **kwargs)
# return self.dequant_output(output_cache)
def forward_quant(self, input, *args, **kwargs):
qinput = self.quant_input(input)
output_cache = self.orig_mod(qinput, *args, **kwargs)
return self.dequant_output(output_cache)

def forward_measure(self, input, *args, **kwargs):
measure_input((input, ), self._mod_extra_config.inputs)
output_cache = self.orig_mod(input, *args, **kwargs)
measure_output((output_cache, ), self._mod_extra_config.outputs)
return output_cache

# def fetch_from_cache(self, cache, blocks, permutations=None):
# # quant_cache = self.quant_input(cache)
# quant_cache = cache
# if permutations:
# output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
# for i in range(len(output_cache)):
# output_cache[i] = self.dequant_output(output_cache[i])
# return output_cache
# output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks)
# return self.dequant_output(output_cache)

def forward_quant(self, input, *args, **kwargs):
qinput = self.quant_input(input)
return self.orig_mod(qinput, *args, **kwargs)

def fetch_from_cache(self, quant_cache, blocks, permutations=None):
def fetch_from_cache(self, cache, blocks, permutations=None):
# TODO: Remove this workaround in next release [SW-221595]
if cache.dtype != self.lp_dtype:
quant_cache = self.quant_input(cache)
else:
quant_cache = cache
if permutations:
output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
for i in range(len(output_cache)):
output_cache[i] = self.dequant_output(output_cache[i])
return output_cache
output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks)
return self.dequant_output(output_cache)
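The dtype guard above (the SW-221595 workaround) quantizes the incoming cache only when it is not already stored in the low-precision dtype. A minimal standalone sketch:

```python
import torch

def maybe_quant_cache(cache: torch.Tensor, lp_dtype: torch.dtype, quant_input):
    # Skip re-quantization if the KV cache is already in lp_dtype (e.g. float8_e4m3fn).
    return quant_input(cache) if cache.dtype != lp_dtype else cache
```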

def extra_repr(self) -> str:
return f"PatchedVLLMKVCache"

11 changes: 7 additions & 4 deletions neural_compressor/torch/utils/environ.py
@@ -235,15 +235,18 @@ def is_tbb_available(): # pragma: no cover
return False
return True

def show_mem_info(loglevel="info"):
def show_mem_info(msg="", loglevel="info"):
hpu_mem_mb = get_used_hpu_mem_MB()
from neural_compressor.common.utils import logger
show_fn = getattr(logger, loglevel)
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else -1
show_fn(f"[Rank {rank}] Used HPU memory: {hpu_mem_mb // 1000} GB {hpu_mem_mb % 1000} MB")
# show_fn(f"[Rank {rank}] Used HPU memory: {hpu_mem_mb // 1000} GB {hpu_mem_mb % 1000} MB")
cpu_mem_mb = get_used_cpu_mem_MB()
show_fn(f"[Rank {rank}] Used CPU memory: {cpu_mem_mb // 1000} GB {cpu_mem_mb % 1000} MB")

# show_fn(f"[Rank {rank}] Used CPU memory: {cpu_mem_mb // 1000} GB {cpu_mem_mb % 1000} MB")
show_fn(
f"[Rank {rank}] {msg}, HPU: {hpu_mem_mb // 1000} GB {hpu_mem_mb % 1000:.2f} MB; CPU: {cpu_mem_mb // 1000} GB {cpu_mem_mb % 1000:.2f} MB"
)
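Example usage of the extended helper, as called from the quantization flow above; the log format follows the f-string in the new implementation.

```python
from neural_compressor.torch.utils.environ import show_mem_info

show_mem_info("before move all")                       # [Rank r] before move all, HPU: <GB> GB <MB> MB; CPU: ...
show_mem_info("after post process", loglevel="debug")  # routed through logger.debug instead of logger.info
```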


def get_used_hpu_mem_MB():
"""Get HPU used memory: MiB."""
15 changes: 12 additions & 3 deletions setup.py
@@ -75,7 +75,12 @@ def get_build_version():
],
),
"package_data": {"": ["*.json"]},
"install_requires": fetch_requirements("requirements_pt.txt"),
# FIXME: (Yi) force install neural_compressor_pt
# "install_requires": fetch_requirements("requirements_pt.txt"),
"install_requires": fetch_requirements("requirements.txt"),
"extras_require": {
"pt": fetch_requirements("requirements_pt.txt"),
}
},
# 3.x tf binary build config, pip install neural-compressor-tf, install 3.x TensorFlow API.
"neural_compressor_tf": {
@@ -102,15 +107,19 @@ def get_build_version():
# https://github.com/pytorch/pytorch/pull/114662
ext_modules = []
cmdclass = {}




if "pt" in sys.argv:
sys.argv.remove("pt")
cfg_key = "neural_compressor_pt"

if "tf" in sys.argv:
sys.argv.remove("tf")
cfg_key = "neural_compressor_tf"

# FIXME: (Yi) force install neural_compressor_pt
print(f"Forcing install neural_compressor_pt")
cfg_key = "neural_compressor_pt"
project_name = PKG_INSTALL_CFG[cfg_key].get("project_name")
include_packages = PKG_INSTALL_CFG[cfg_key].get("include_packages") or {}
package_data = PKG_INSTALL_CFG[cfg_key].get("package_data") or {}
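With this change, a plain `pip install .` builds the `neural_compressor_pt` configuration: the base dependencies come from `requirements.txt`, while the PyTorch-specific pins in `requirements_pt.txt` become an optional `pt` extra (`pip install .[pt]`). The FIXME markers indicate the forced `neural_compressor_pt` selection is intended as a temporary measure for this draft.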