Commit 13dbe56

Merge branch 'master' into fix-6848-forbid-repeated-init
2 parents f84cca6 + 029e0a3 commit 13dbe56

21 files changed: +367 -101 lines changed

.github/workflows/cpu-torch-latest.yml

Lines changed: 2 additions & 2 deletions
@@ -59,5 +59,5 @@ jobs:
       run: |
         unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
         cd tests
-        HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.5"
-        HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.5"
+        HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.6"
+        HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.6"

.github/workflows/nv-ds-chat.yml

Lines changed: 2 additions & 1 deletion
@@ -37,7 +37,7 @@ jobs:

       - name: Install pytorch
         run: |
-          pip3 install -U --cache-dir $TORCH_CACHE torch --index-url https://download.pytorch.org/whl/cu121
+          pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu121
           python -c "import torch; print('torch:', torch.__version__, torch)"
           python -c "import torch; print('CUDA available:', torch.cuda.is_available())"

@@ -67,6 +67,7 @@ jobs:
         run: |
           cd DeepSpeedExamples/applications/DeepSpeed-Chat
           unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
+          unset NCCL_DEBUG
           cd tests
           pytest $PYTEST_OPTS ./

README.md

Lines changed: 5 additions & 5 deletions
@@ -192,11 +192,11 @@ of JIT compiling) or install pre-compiled ops via PyPI please see our [advanced
 installation instructions](https://www.deepspeed.ai/tutorials/advanced-install/).

 ## Windows
-Windows support is partially supported with DeepSpeed. On Windows you can build wheel with following steps, currently only inference mode is supported.
-1. Install pytorch, such as pytorch 1.8 + cuda 11.1
-2. Install visual cpp build tools, such as VS2019 C++ x64/x86 build tools
-3. Launch cmd console with Administrator privilege for creating required symlink folders
-4. Run `python setup.py bdist_wheel` to build wheel in `dist` folder
+Many DeepSpeed features are supported on Windows for both training and inference. You can read more about this in the original blog post [here](https://github.com/microsoft/DeepSpeed/tree/master/blogs/windows/08-2024/README.md). Among features that are currently not supported are async io (AIO) and GDS (which does not support Windows).
+1. Install PyTorch, such as pytorch 2.3+cu121.
+2. Install Visual C++ build tools, such as VS2022 C++ x64/x86 build tools.
+3. Launch Cmd console with Administrator permissions for creating required symlink folders and ensure MSVC tools are added to your PATH or launch the Developer Command Prompt for Visual Studio 2022 with administrator permissions.
+4. Run `build_win.bat` to build wheel in `dist` folder.

 # Features

csrc/fp_quantizer/fp_quantize.cpp

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@

 at::Tensor quantize(torch::Tensor& out,
                     torch::Tensor& val,
+                    torch::Tensor& scale,
                     int group_size,
                     int stochastic_rounding,
                     int q_bits,
@@ -59,6 +60,7 @@ at::Tensor quantize(torch::Tensor& out,

 void dequantize(torch::Tensor& val,
                 torch::Tensor& val_q,
+                torch::Tensor& scale,
                 int group_size,
                 int q_mantisa_bits,
                 int q_exponent_bits)

deepspeed/launcher/multinode_runner.py

Lines changed: 19 additions & 0 deletions
@@ -134,12 +134,31 @@ def name(self):

     def validate_args(self):
         super().validate_args()
+
+        # Validate and set MPI environment variables
+        self._setup_mpi_environment()
+
         #TODO: Allow for include/exclude at node-level but not gpu-level
         if self.args.include != "" or self.args.exclude != "":
             raise ValueError(f"{self.name} backend does not support worker include/exclusion")
         if self.args.num_nodes != -1 or self.args.num_gpus != -1:
             raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")

+    def _setup_mpi_environment(self):
+        """Sets up MPI-related environment variables or raises an error if they're missing."""
+
+        required_vars = ['OMPI_COMM_WORLD_LOCAL_RANK', 'OMPI_COMM_WORLD_RANK', 'OMPI_COMM_WORLD_SIZE']
+
+        # Check if all these are present
+        if not all(var in os.environ for var in required_vars):
+            raise EnvironmentError("MPI environment variables are not set. "
+                                   "Ensure you are running the script with an MPI-compatible launcher.")
+
+        # Now safe to read all
+        os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
+        os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
+        os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
+
     def get_cmd(self, environment, active_resources):
         total_process_count = sum(self.resource_pool.values())
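
For context on the validation added above: it requires the Open MPI rank variables to be present before mirroring them into the generic LOCAL_RANK/RANK/WORLD_SIZE names. A minimal standalone sketch of the same check and mapping outside the launcher class (the variable names come from the diff; the function name is hypothetical):

import os

# Open MPI variables that mpirun exports for every rank (as listed in the diff above).
REQUIRED_OMPI_VARS = ('OMPI_COMM_WORLD_LOCAL_RANK', 'OMPI_COMM_WORLD_RANK', 'OMPI_COMM_WORLD_SIZE')

def map_ompi_env_to_generic():
    """Fail fast when not launched by an MPI-compatible launcher, then mirror the values."""
    missing = [var for var in REQUIRED_OMPI_VARS if var not in os.environ]
    if missing:
        raise EnvironmentError(f"MPI environment variables are not set: {missing}. "
                               "Ensure you are running the script with an MPI-compatible launcher.")
    os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
    os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']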

deepspeed/linear/config.py

Lines changed: 7 additions & 2 deletions
@@ -6,14 +6,16 @@
 from dataclasses import dataclass, field
 from typing import List

+import torch
+

 @dataclass
 class LoRAConfig:
     """
     Configuration settings for LoRAOptimizedLinear.

     Attributes:
-        lora_r (int): LoRA attention dimension, also know as the rank. Defaults is 64.
+        lora_r (int): LoRA attention dimension, also known as the rank. Defaults is 64.
         lora_alpha (float): LoRA scaling factor, default is 16.
         base_weight_sharding (int): The degree to which the base weights are sharded,
             should typically be set to the data-parallel world size to maximize the memory
@@ -42,8 +44,11 @@ class QuantizationConfig:
     Attributes:
         q_bits (int): The number of bits used for quantization. Default is 8.
         mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3.
-        group_size (int): The size of the group used for quantization. Default is 512.
+        group_size (int): The number of elements used for quantization. Default is 512.
+        q_dtype (torch.dtype): The data type to quantize to. Default is uint8. (in CUDA, buffers are allocated as
+            uint8, but inside the kernels the quantization is done to fp8)
     """
     q_bits: int = 8
     mantissa_bits: int = 3
     group_size: int = 512
+    q_dtype: torch.dtype = torch.uint8
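
As a usage sketch for the new q_dtype field (illustrative only; it assumes QuantizationConfig is importable from deepspeed.linear.config as the file path suggests):

import torch
from deepspeed.linear.config import QuantizationConfig

# Defaults from the dataclass above: 8 bits, 3 mantissa bits, groups of 512 elements,
# and buffers allocated as uint8 (the CUDA kernel treats them as fp8 internally).
default_cfg = QuantizationConfig()

# Override only the fields you care about; q_dtype stays a torch.dtype.
small_group_cfg = QuantizationConfig(group_size=256, q_dtype=torch.uint8)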

deepspeed/linear/quantization.py

Lines changed: 4 additions & 4 deletions
@@ -51,24 +51,24 @@ def __new__(
             self.quantizer = quantizer
         else:
             # if FPQuantizerBuilder is not compatible in this env this init will fail
-            self.quantizer = FP_Quantize(group_size=self.quantization_config.group_size)
+            self.quantizer = FP_Quantize(quantization_config=self.quantization_config)
         self._ensure_quantized(self)
         return self

     def _ensure_quantized(self, tensor: torch.Tensor):
         # If the tensor is on the accelerator and is not quantized, then quantize it in-place.
-        if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.uint8:
+        if get_accelerator().on_accelerator(tensor) and tensor.dtype != self.quantization_config.q_dtype:
             with get_accelerator().stream(get_accelerator().current_stream(tensor.device)):
                 tensor.data = self.quantizer.quantize(tensor.data,
                                                       q_bits=self.quantization_config.q_bits,
                                                       q_mantisa_bits=self.quantization_config.mantissa_bits)
-            assert tensor.dtype == torch.uint8
+            assert tensor.dtype == self.quantization_config.q_dtype

     def dequantized(self) -> torch.Tensor:
         """
         Return a tensor containing the dequantized weights of this parameter.
         """
-        if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.uint8:
+        if get_accelerator().on_accelerator(self.data) and self.data.dtype == self.quantization_config.q_dtype:
             with get_accelerator().stream(get_accelerator().current_stream(self.data.device)):
                 return self.quantizer.dequantize(self.data,
                                                  q_bits=self.quantization_config.q_bits,
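
The constructor change above means callers now hand FP_Quantize the whole quantization config rather than a bare group_size. A hedged before/after sketch (import paths assumed from the file layout in this commit):

from deepspeed.linear.config import QuantizationConfig
from deepspeed.ops.fp_quantizer import FP_Quantize

quantization_config = QuantizationConfig(group_size=512)

# Before this commit (the removed line): FP_Quantize(group_size=quantization_config.group_size)
# After this commit, the full config object is passed through:
quantizer = FP_Quantize(quantization_config=quantization_config)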

deepspeed/ops/fp_quantizer/quantize.py

Lines changed: 42 additions & 19 deletions
@@ -16,7 +16,7 @@

 class Quantizer(ABC):
     """
-    Abstract Quantizer class that implmenents quantize/dequantize methods.
+    Abstract Quantizer class that implements quantize/dequantize methods.

     Arguments:
         group_size (int, optional): number of values or elements that are grouped
@@ -42,12 +42,18 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non

 class FP_Quantize(Quantizer):

-    def __init__(self, group_size=512) -> None:
+    def __init__(self, quantization_config) -> None:
         global fp_quant_module
-        super().__init__(group_size=group_size)
+        super().__init__(group_size=quantization_config.group_size)
         if fp_quant_module is None:
             fp_quant_module = FPQuantizerBuilder().load()
+        self.cuda_impl = getattr(fp_quant_module, "CUDA_IMPL", True)
+        self.q_config = quantization_config
+
         self.orig_dtype = None
+        self.num_groups = None
+        self.input_q = None
+        self.scale = None

     def quantize(self,
                  input,
@@ -73,15 +79,27 @@ def quantize(self,
         else:
             assert (0), \
                 f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"
-        self.num_groups = input.numel() // self.group_size
-        self.input_q = torch.ones(self.num_groups,
-                                  int(self.group_size * q_bits) // 8 + 4,
-                                  dtype=torch.uint8,
-                                  device=input.device)
-        out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)
+
+        # Adding (group_size - 1) is for padding
+        self.num_groups = (input.numel() + self.q_config.group_size - 1) // self.q_config.group_size
+        # group_size should be the minimal number between the defined group size and number of elements in tensor.
+        group_size = int(min(self.q_config.group_size, input.numel()) * q_bits) // 8
+        # CUDA quantization kernel saves the scale as (fp32) inside the quantized tensor for each group
+        if self.cuda_impl:
+            group_size += 4
+        # CUDA quantization kernel allocates tensors as uint8, but handles them as fp8 inside the kernel.
+        self.input_q = torch.ones(self.num_groups, group_size, dtype=self.q_config.q_dtype, device=input.device)
+        # CUDA quantization kernel attaches scales to quantized result, in python implementation it can't be done
+        # because they are of different types.
+        self.scale = torch.ones(self.num_groups, 1, device=input.device)
+        out = fp_quant_module.quantize(self.input_q, input, self.scale, group_size, stochastic_mode, q_bits,
+                                       q_mantisa_bits)
         if return_meta_tensor:
-            data, self.scale = out.split(self.group_size, dim=-1)
-            data = data.contiguous().reshape(input.shape)
+            if self.cuda_impl:
+                data, self.scale = out.split(group_size, dim=-1)
+                data = data.contiguous().reshape(input.shape)
+            else:
+                data = out.contiguous().reshape(input.shape)
             self.scale = self.scale.contiguous()
             del self.input_q
             del out
@@ -93,9 +111,9 @@ def quantize(self,

     def to(self, *args, **kwargs):
         # Intermediate tensors may need to be moved to different devices
-        if hasattr(self, 'input_q'):
+        if hasattr(self, 'input_q') and self.input_q is not None:
             self.input_q = self.input_q.to(*args, **kwargs)
-        if hasattr(self, 'scale'):
+        if hasattr(self, 'scale') and self.scale is not None:
             self.scale = self.scale.to(*args, **kwargs)

     def get_scales(self):
@@ -118,11 +136,16 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non
             assert (0), \
                 f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

-        if scale is not None:
+        if scale is not None and self.cuda_impl:
             assert input_q.numel() == fp_out.numel(), \
                 f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
-            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
-        fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
+            input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()
+        elif scale is not None and not self.cuda_impl:
+            group_size = int(min(self.q_config.group_size, input_q.numel()) * q_bits) // 8
+            input_q = input_q.reshape(-1, group_size)
+
+        fp_quant_module.dequantize(fp_out, input_q, self.scale, self.q_config.group_size, q_mantisa_bits,
+                                   q_bits - q_mantisa_bits - 1)
         return fp_out

     def selective_dequantize(self,
@@ -151,11 +174,11 @@ def selective_dequantize(self,
             assert (0), \
                 f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

-        if scale is not None:
+        if scale is not None and self.cuda_impl:
             assert input_q.numel() == fp_out.numel(), \
                 f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
-            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
+            input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()

-        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
+        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.q_config.group_size, q_mantisa_bits,
                                              q_bits - q_mantisa_bits - 1)
         return fp_out
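
To make the new buffer-size arithmetic above concrete, here is a small worked example with the default config values (plain Python mirroring the expressions in the diff; the numbers are illustrative):

# Assumed defaults from QuantizationConfig: group_size=512 elements, q_bits=8.
numel = 1000                 # elements in the tensor being quantized
cfg_group_size = 512
q_bits = 8
cuda_impl = True

# Adding (group_size - 1) before integer division rounds the group count up (padding).
num_groups = (numel + cfg_group_size - 1) // cfg_group_size    # -> 2

# Per-group byte width: capped by the tensor size, converted from bits to bytes.
group_bytes = int(min(cfg_group_size, numel) * q_bits) // 8     # -> 512

# The CUDA kernel stores a 4-byte fp32 scale per group inside the quantized buffer.
if cuda_impl:
    group_bytes += 4                                            # -> 516

# input_q is then allocated as (num_groups, group_bytes) with dtype q_dtype (uint8 by default),
# while the non-CUDA path keeps the scales in the separate (num_groups, 1) scale tensor.
print(num_groups, group_bytes)                                  # 2 516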

deepspeed/runtime/zero/parameter_offload.py

Lines changed: 13 additions & 11 deletions
@@ -243,7 +243,7 @@ def _start_of_forward_hook(module, *args):
         self.module.register_forward_pre_hook(_start_of_forward_hook)

         #likely one of them should be enough but just to be safe
-        self._register_hooks_recursively(self.module)
+        self._register_deepspeed_module(self.module)

         # Add top module to stack trace
         global FWD_MODULE_STACK
@@ -269,19 +269,19 @@ def mark_persistent_parameters(self, param_threshold, model_threshold):

         return persistent_params

-    def _register_hooks_recursively(self, module, count=[0]):
+    def _register_deepspeed_module(self, module, count=[0]):
         my_count = count[0]
-        module.id = my_count
+        module.ds_id = my_count

-        #print(f"{module.__class__} : {module.id}")
+        #print(f"{module.__class__} : {module.ds_id}")

         if z3_leaf_module(module):
             for param in module.parameters():
                 param.ds_z3_leaf_module = module
         else:
             for child in module.children():
                 count[0] = count[0] + 1
-                self._register_hooks_recursively(child, count=count)
+                self._register_deepspeed_module(child, count=count)

     @instrument_w_nvtx
     def _pre_forward_module_hook(module, *args):
@@ -466,14 +466,16 @@ def pre_sub_module_forward_function(self, sub_module):

     @torch.no_grad()
     def post_sub_module_forward_function(self, sub_module):
-        see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
-                         force=False)
+        see_memory_usage(
+            f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
+            force=False)

         param_coordinator = self.get_param_coordinator()
         param_coordinator.release_sub_module(sub_module)

-        see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
-                         force=False)
+        see_memory_usage(
+            f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
+            force=False)

     @torch.no_grad()
     def pre_sub_module_backward_function(self, sub_module):
@@ -488,13 +490,13 @@ def pre_sub_module_backward_function(self, sub_module):
     def post_sub_module_backward_function(self, sub_module):
         # assert sub_module.training, "backward pass is invalid for module in evaluation mode"
         see_memory_usage(
-            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
+            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
             force=False)

         self.get_param_coordinator().release_sub_module(sub_module)

         see_memory_usage(
-            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
+            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
             force=False)

     def _set_z3_leaf_modules_by_threshold(self, module, zero_module_granularity_threshold):
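
A side note on the renamed _register_deepspeed_module above: the count=[0] default is a shared mutable list, so every recursive call advances one running counter and each submodule ends up with a unique ds_id. A minimal standalone sketch of that counter pattern on a toy model (not DeepSpeed code, and without the z3 leaf-module branch):

import torch.nn as nn

def assign_ds_ids(module: nn.Module, count=[0]):
    """Tag each submodule with a depth-first id, mirroring the counter in _register_deepspeed_module."""
    module.ds_id = count[0]
    for child in module.children():
        count[0] = count[0] + 1
        assign_ds_ids(child, count=count)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
assign_ds_ids(model, count=[0])                 # pass a fresh list so repeated calls restart from 0
print([m.ds_id for m in model.modules()])       # [0, 1, 2]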

0 commit comments