
Commit

Merge branch 'master' into fix-6848-forbid-repeated-init
loadams authored Jan 31, 2025
2 parents f84cca6 + 029e0a3 commit 13dbe56
Showing 21 changed files with 367 additions and 101 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/cpu-torch-latest.yml
@@ -59,5 +59,5 @@ jobs:
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.5"
HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.5"
HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.6"
HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.6"
3 changes: 2 additions & 1 deletion .github/workflows/nv-ds-chat.yml
@@ -37,7 +37,7 @@ jobs:

- name: Install pytorch
run: |
pip3 install -U --cache-dir $TORCH_CACHE torch --index-url https://download.pytorch.org/whl/cu121
pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu121
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
@@ -67,6 +67,7 @@ jobs:
run: |
cd DeepSpeedExamples/applications/DeepSpeed-Chat
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
unset NCCL_DEBUG
cd tests
pytest $PYTEST_OPTS ./
10 changes: 5 additions & 5 deletions README.md
@@ -192,11 +192,11 @@ of JIT compiling) or install pre-compiled ops via PyPI please see our [advanced
installation instructions](https://www.deepspeed.ai/tutorials/advanced-install/).

## Windows
Windows support is partially supported with DeepSpeed. On Windows you can build wheel with following steps, currently only inference mode is supported.
1. Install pytorch, such as pytorch 1.8 + cuda 11.1
2. Install visual cpp build tools, such as VS2019 C++ x64/x86 build tools
3. Launch cmd console with Administrator privilege for creating required symlink folders
4. Run `python setup.py bdist_wheel` to build wheel in `dist` folder
Many DeepSpeed features are supported on Windows for both training and inference. You can read more about this in the original blog post [here](https://github.com/microsoft/DeepSpeed/tree/master/blogs/windows/08-2024/README.md). Features that are currently not supported include asynchronous I/O (AIO) and GDS (which does not support Windows).
1. Install PyTorch, such as pytorch 2.3+cu121.
2. Install Visual C++ build tools, such as VS2022 C++ x64/x86 build tools.
3. Launch a Cmd console with Administrator permissions (needed to create the required symlink folders) and ensure the MSVC tools are added to your PATH, or launch the Developer Command Prompt for Visual Studio 2022 with Administrator permissions.
4. Run `build_win.bat` to build the wheel in the `dist` folder (see the sketch below).
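As a rough illustration of the four steps above, here is a minimal pre-flight sketch in Python; the specific checks and the `subprocess` invocation are assumptions for illustration, not part of the official instructions.

```python
# Hypothetical pre-flight check before building the DeepSpeed wheel on Windows.
# Assumes a CUDA build of PyTorch (step 1) and the MSVC tools on PATH (steps 2-3).
import shutil
import subprocess
import sys

import torch


def build_windows_wheel():
    assert sys.platform == "win32", "these build steps apply to Windows only"
    assert torch.version.cuda is not None, "install a CUDA-enabled PyTorch first (step 1)"
    assert shutil.which("cl") is not None, "MSVC compiler not found on PATH (steps 2-3)"
    # Step 4: the wheel is written to the dist/ folder.
    subprocess.check_call("build_win.bat", shell=True)


if __name__ == "__main__":
    build_windows_wheel()
```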

# Features

2 changes: 2 additions & 0 deletions csrc/fp_quantizer/fp_quantize.cpp
@@ -24,6 +24,7 @@

at::Tensor quantize(torch::Tensor& out,
torch::Tensor& val,
torch::Tensor& scale,
int group_size,
int stochastic_rounding,
int q_bits,
@@ -59,6 +60,7 @@ at::Tensor quantize(torch::Tensor& out,

void dequantize(torch::Tensor& val,
torch::Tensor& val_q,
torch::Tensor& scale,
int group_size,
int q_mantisa_bits,
int q_exponent_bits)
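To make the new C++ signature concrete, a sketch of the corresponding Python-side call is shown below; it mirrors the call added later in this diff in `deepspeed/ops/fp_quantizer/quantize.py`, while the import path, shapes, and values are assumptions for illustration.

```python
# Illustrative call into the updated fp_quantizer binding: the per-group `scale`
# tensor is now passed explicitly. Shapes and values are assumptions.
import torch
from deepspeed.ops.op_builder import FPQuantizerBuilder  # assumed import path

fp_quant_module = FPQuantizerBuilder().load()

q_bits, q_mantisa_bits, group_size = 8, 3, 512
val = torch.randn(4 * group_size, dtype=torch.bfloat16, device="cuda")
num_groups = val.numel() // group_size
bytes_per_group = group_size * q_bits // 8 + 4  # +4 bytes: per-group fp32 scale (CUDA path)
out = torch.ones(num_groups, bytes_per_group, dtype=torch.uint8, device="cuda")
scale = torch.ones(num_groups, 1, device="cuda")

# quantize(out, val, scale, group_size, stochastic_rounding, q_bits, q_mantisa_bits)
fp_quant_module.quantize(out, val, scale, bytes_per_group, 0, q_bits, q_mantisa_bits)
```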
19 changes: 19 additions & 0 deletions deepspeed/launcher/multinode_runner.py
@@ -134,12 +134,31 @@ def name(self):

def validate_args(self):
super().validate_args()

# Validate and set MPI environment variables
self._setup_mpi_environment()

#TODO: Allow for include/exclude at node-level but not gpu-level
if self.args.include != "" or self.args.exclude != "":
raise ValueError(f"{self.name} backend does not support worker include/exclusion")
if self.args.num_nodes != -1 or self.args.num_gpus != -1:
raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")

def _setup_mpi_environment(self):
"""Sets up MPI-related environment variables or raises an error if they're missing."""

required_vars = ['OMPI_COMM_WORLD_LOCAL_RANK', 'OMPI_COMM_WORLD_RANK', 'OMPI_COMM_WORLD_SIZE']

# Check if all these are present
if not all(var in os.environ for var in required_vars):
raise EnvironmentError("MPI environment variables are not set. "
"Ensure you are running the script with an MPI-compatible launcher.")

# Now safe to read all
os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']

def get_cmd(self, environment, active_resources):
total_process_count = sum(self.resource_pool.values())

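For context, the sketch below reproduces the environment mapping that `_setup_mpi_environment` performs, as a standalone script; it assumes the process is started by an MPI-compatible launcher (e.g. `mpirun`) that sets the `OMPI_*` variables.

```python
# Standalone sketch of the OMPI_* -> LOCAL_RANK/RANK/WORLD_SIZE mapping above.
# Assumed to run under an MPI launcher such as `mpirun -np 4 python this_script.py`.
import os

required = ("OMPI_COMM_WORLD_LOCAL_RANK", "OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE")
if not all(var in os.environ for var in required):
    raise EnvironmentError("MPI environment variables are not set; "
                           "run under an MPI-compatible launcher.")

os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
print(f"rank {os.environ['RANK']} of {os.environ['WORLD_SIZE']} "
      f"(local rank {os.environ['LOCAL_RANK']})")
```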
9 changes: 7 additions & 2 deletions deepspeed/linear/config.py
@@ -6,14 +6,16 @@
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class LoRAConfig:
"""
Configuration settings for LoRAOptimizedLinear.
Attributes:
lora_r (int): LoRA attention dimension, also know as the rank. Defaults is 64.
lora_r (int): LoRA attention dimension, also known as the rank. Default is 64.
lora_alpha (float): LoRA scaling factor, default is 16.
base_weight_sharding (int): The degree to which the base weights are sharded,
should typically be set to the data-parallel world size to maximize the memory
@@ -42,8 +44,11 @@ class QuantizationConfig:
Attributes:
q_bits (int): The number of bits used for quantization. Default is 8.
mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3.
group_size (int): The size of the group used for quantization. Default is 512.
group_size (int): The number of elements used for quantization. Default is 512.
q_dtype (torch.dtype): The data type to quantize to. Default is uint8. (in CUDA, buffers are allocated as
uint8, but inside the kernels the quantization is done to fp8)
"""
q_bits: int = 8
mantissa_bits: int = 3
group_size: int = 512
q_dtype: torch.dtype = torch.uint8
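A minimal usage sketch of the extended config (assumed usage, not taken from this diff): the new `q_dtype` field travels with the rest of the quantization settings, which later parts of this commit pass to the quantizer as a whole.

```python
# Hypothetical construction of a QuantizationConfig with the new q_dtype field.
import torch

from deepspeed.linear.config import QuantizationConfig

q_config = QuantizationConfig(q_bits=8, mantissa_bits=3, group_size=512, q_dtype=torch.uint8)
# Elsewhere in this commit the quantizer receives the whole object:
#   FP_Quantize(quantization_config=q_config)
```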
8 changes: 4 additions & 4 deletions deepspeed/linear/quantization.py
@@ -51,24 +51,24 @@ def __new__(
self.quantizer = quantizer
else:
# if FPQuantizerBuilder is not compatible in this env this init will fail
self.quantizer = FP_Quantize(group_size=self.quantization_config.group_size)
self.quantizer = FP_Quantize(quantization_config=self.quantization_config)
self._ensure_quantized(self)
return self

def _ensure_quantized(self, tensor: torch.Tensor):
# If the tensor is on the accelerator and is not quantized, then quantize it in-place.
if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.uint8:
if get_accelerator().on_accelerator(tensor) and tensor.dtype != self.quantization_config.q_dtype:
with get_accelerator().stream(get_accelerator().current_stream(tensor.device)):
tensor.data = self.quantizer.quantize(tensor.data,
q_bits=self.quantization_config.q_bits,
q_mantisa_bits=self.quantization_config.mantissa_bits)
assert tensor.dtype == torch.uint8
assert tensor.dtype == self.quantization_config.q_dtype

def dequantized(self) -> torch.Tensor:
"""
Return a tensor containing the dequantized weights of this parameter.
"""
if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.uint8:
if get_accelerator().on_accelerator(self.data) and self.data.dtype == self.quantization_config.q_dtype:
with get_accelerator().stream(get_accelerator().current_stream(self.data.device)):
return self.quantizer.dequantize(self.data,
q_bits=self.quantization_config.q_bits,
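The dtype check above is what this change generalizes: instead of hard-coding `torch.uint8`, the parameter compares against the configured `q_dtype`. A simplified, standalone restatement of that pattern (with `cfg` and `quantizer` as assumed stand-ins) might look like this:

```python
# Simplified sketch of the quantize-in-place guard; `cfg` and `quantizer` are
# assumed stand-ins for a QuantizationConfig and an FP_Quantize instance.
import torch


def ensure_quantized(tensor: torch.Tensor, cfg, quantizer) -> torch.Tensor:
    # Quantize only if the tensor is not already stored in the configured dtype.
    if tensor.is_cuda and tensor.dtype != cfg.q_dtype:
        tensor = quantizer.quantize(tensor, q_bits=cfg.q_bits, q_mantisa_bits=cfg.mantissa_bits)
        assert tensor.dtype == cfg.q_dtype
    return tensor
```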
61 changes: 42 additions & 19 deletions deepspeed/ops/fp_quantizer/quantize.py
@@ -16,7 +16,7 @@

class Quantizer(ABC):
"""
Abstract Quantizer class that implmenents quantize/dequantize methods.
Abstract Quantizer class that implements quantize/dequantize methods.
Arguments:
group_size (int, optional): number of values or elements that are grouped
@@ -42,12 +42,18 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non

class FP_Quantize(Quantizer):

def __init__(self, group_size=512) -> None:
def __init__(self, quantization_config) -> None:
global fp_quant_module
super().__init__(group_size=group_size)
super().__init__(group_size=quantization_config.group_size)
if fp_quant_module is None:
fp_quant_module = FPQuantizerBuilder().load()
self.cuda_impl = getattr(fp_quant_module, "CUDA_IMPL", True)
self.q_config = quantization_config

self.orig_dtype = None
self.num_groups = None
self.input_q = None
self.scale = None

def quantize(self,
input,
@@ -73,15 +79,27 @@ def quantize(self,
else:
assert (0), \
f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"
self.num_groups = input.numel() // self.group_size
self.input_q = torch.ones(self.num_groups,
int(self.group_size * q_bits) // 8 + 4,
dtype=torch.uint8,
device=input.device)
out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)

# Adding (group_size - 1) is for padding
self.num_groups = (input.numel() + self.q_config.group_size - 1) // self.q_config.group_size
# group_size should be the minimal number between the defined group size and number of elements in tensor.
group_size = int(min(self.q_config.group_size, input.numel()) * q_bits) // 8
# CUDA quantization kernel saves the scale as (fp32) inside the quantized tensor for each group
if self.cuda_impl:
group_size += 4
# CUDA quantization kernel allocates tensors as uint8, but handles them as fp8 inside the kernel.
self.input_q = torch.ones(self.num_groups, group_size, dtype=self.q_config.q_dtype, device=input.device)
# CUDA quantization kernel attaches scales to quantized result, in python implementation it can't be done
# because they are of different types.
self.scale = torch.ones(self.num_groups, 1, device=input.device)
out = fp_quant_module.quantize(self.input_q, input, self.scale, group_size, stochastic_mode, q_bits,
q_mantisa_bits)
if return_meta_tensor:
data, self.scale = out.split(self.group_size, dim=-1)
data = data.contiguous().reshape(input.shape)
if self.cuda_impl:
data, self.scale = out.split(group_size, dim=-1)
data = data.contiguous().reshape(input.shape)
else:
data = out.contiguous().reshape(input.shape)
self.scale = self.scale.contiguous()
del self.input_q
del out
@@ -93,9 +111,9 @@ def quantize(self,

def to(self, *args, **kwargs):
# Intermediate tensors may need to be moved to different devices
if hasattr(self, 'input_q'):
if hasattr(self, 'input_q') and self.input_q is not None:
self.input_q = self.input_q.to(*args, **kwargs)
if hasattr(self, 'scale'):
if hasattr(self, 'scale') and self.scale is not None:
self.scale = self.scale.to(*args, **kwargs)

def get_scales(self):
@@ -118,11 +136,16 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non
assert (0), \
f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

if scale is not None:
if scale is not None and self.cuda_impl:
assert input_q.numel() == fp_out.numel(), \
f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()
elif scale is not None and not self.cuda_impl:
group_size = int(min(self.q_config.group_size, input_q.numel()) * q_bits) // 8
input_q = input_q.reshape(-1, group_size)

fp_quant_module.dequantize(fp_out, input_q, self.scale, self.q_config.group_size, q_mantisa_bits,
q_bits - q_mantisa_bits - 1)
return fp_out

def selective_dequantize(self,
@@ -151,11 +174,11 @@ def selective_dequantize(self,
assert (0), \
f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

if scale is not None:
if scale is not None and self.cuda_impl:
assert input_q.numel() == fp_out.numel(), \
f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()

fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.q_config.group_size, q_mantisa_bits,
q_bits - q_mantisa_bits - 1)
return fp_out
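The buffer-size arithmetic introduced above is easiest to see with concrete numbers; the values below are assumptions chosen to show the padding behaviour.

```python
# Worked example of the group/padding arithmetic in FP_Quantize.quantize (values assumed).
import torch

group_size = 512   # q_config.group_size
q_bits = 8
numel = 1000       # tensor smaller than two full groups

# Padded group count: ceil(numel / group_size)
num_groups = (numel + group_size - 1) // group_size          # -> 2
# Bytes per group, capped at the tensor size and converted from bits to bytes
bytes_per_group = int(min(group_size, numel) * q_bits) // 8  # -> 512
cuda_impl = True
if cuda_impl:
    bytes_per_group += 4  # the CUDA kernel stores a 4-byte fp32 scale per group

input_q = torch.ones(num_groups, bytes_per_group, dtype=torch.uint8)
print(input_q.shape)  # torch.Size([2, 516])
```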
24 changes: 13 additions & 11 deletions deepspeed/runtime/zero/parameter_offload.py
@@ -243,7 +243,7 @@ def _start_of_forward_hook(module, *args):
self.module.register_forward_pre_hook(_start_of_forward_hook)

#likely one of them should be enough but just to be safe
self._register_hooks_recursively(self.module)
self._register_deepspeed_module(self.module)

# Add top module to stack trace
global FWD_MODULE_STACK
@@ -269,19 +269,19 @@ def mark_persistent_parameters(self, param_threshold, model_threshold):

return persistent_params

def _register_hooks_recursively(self, module, count=[0]):
def _register_deepspeed_module(self, module, count=[0]):
my_count = count[0]
module.id = my_count
module.ds_id = my_count

#print(f"{module.__class__} : {module.id}")
#print(f"{module.__class__} : {module.ds_id}")

if z3_leaf_module(module):
for param in module.parameters():
param.ds_z3_leaf_module = module
else:
for child in module.children():
count[0] = count[0] + 1
self._register_hooks_recursively(child, count=count)
self._register_deepspeed_module(child, count=count)

@instrument_w_nvtx
def _pre_forward_module_hook(module, *args):
@@ -466,14 +466,16 @@ def pre_sub_module_forward_function(self, sub_module):

@torch.no_grad()
def post_sub_module_forward_function(self, sub_module):
see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
force=False)
see_memory_usage(
f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
force=False)

param_coordinator = self.get_param_coordinator()
param_coordinator.release_sub_module(sub_module)

see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
force=False)
see_memory_usage(
f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
force=False)

@torch.no_grad()
def pre_sub_module_backward_function(self, sub_module):
@@ -488,13 +490,13 @@ def pre_sub_module_backward_function(self, sub_module):
def post_sub_module_backward_function(self, sub_module):
# assert sub_module.training, "backward pass is invalid for module in evaluation mode"
see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
force=False)

self.get_param_coordinator().release_sub_module(sub_module)

see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
force=False)

def _set_z3_leaf_modules_by_threshold(self, module, zero_module_granularity_threshold):
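The renamed `_register_deepspeed_module` keeps the same depth-first numbering as before, now stored on `module.ds_id` instead of `module.id`. A toy, standalone restatement of that numbering (an assumption-level sketch, not DeepSpeed's code):

```python
# Toy sketch of the depth-first ds_id numbering applied by _register_deepspeed_module.
import torch.nn as nn


def assign_ds_ids(module: nn.Module, count=[0]):  # mutable default mirrors the original code
    module.ds_id = count[0]
    for child in module.children():
        count[0] += 1
        assign_ds_ids(child, count=count)


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
assign_ds_ids(model)
print([(m.__class__.__name__, m.ds_id) for m in model.modules()])
# [('Sequential', 0), ('Linear', 1), ('ReLU', 2), ('Linear', 3)]
```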
(The remaining changed files in this commit are not shown.)
