
Commit b1d4330

Merge branch 'master' into fix-6848-forbid-repeated-init
2 parents a9837f9 + 8d1bc0a commit b1d4330

File tree

6 files changed, +77 -65 lines


deepspeed/runtime/comm/compressed.py

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_error

         compensated_server_m.add_(server_error)

-        server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
+        server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())

         server_error.set_(compensated_server_m -
                           server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
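
The only functional change here, mirrored one-for-one in deepspeed/runtime/comm/hccl.py below, is the move from the deprecated torch.norm to torch.linalg.norm. With the default arguments both return the 2-norm of the tensor treated as a flat vector, so the computed server_scale is unchanged. A minimal standalone sketch of that equivalence (toy tensor, not DeepSpeed code):

# Standalone sketch: the old and new calls give the same value for the default norm.
import numpy as np
import torch

compensated_server_m = torch.randn(1024)

old_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())         # deprecated API
new_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())  # replacement
assert torch.allclose(old_scale, new_scale)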

deepspeed/runtime/comm/hccl.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_error

         compensated_server_m.add_(server_error)

-        server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
+        server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())

         server_error.set_(compensated_server_m -
                           server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

deepspeed/runtime/fp16/onebit/lamb.py

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ def step(self, closure=None, grads=None):
             # This is used to reduce compression error during compression stage.
             momentum_scales = []
             for group in self.param_groups:
-                momentum_scales.append([(torch.linalg.norm(self.state[p]['exp_avg']) /
+                momentum_scales.append([(torch.linalg.vector_norm(self.state[p]['exp_avg']) /
                                          np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
                                         for p in group['params']])
             united_scale = sum([sum(x) for x in momentum_scales]) / sum([len(x) for x in momentum_scales])
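
exp_avg has the same shape as the parameter it tracks, so it is typically 2-D or larger. With the default ord, torch.linalg.norm on a matrix returns the Frobenius norm, which equals the 2-norm of the flattened tensor, so the momentum scale computed here is numerically unchanged; torch.linalg.vector_norm simply states the flatten-to-a-vector intent explicitly and keeps it for every ord. A hedged sketch on a toy buffer (not real optimizer state):

import numpy as np
import torch

exp_avg = torch.randn(128, 64)   # stand-in for self.state[p]['exp_avg']

old = (torch.linalg.norm(exp_avg) / np.sqrt(torch.numel(exp_avg))).item()
new = (torch.linalg.vector_norm(exp_avg) / np.sqrt(torch.numel(exp_avg))).item()
assert abs(old - new) < 1e-6

# For non-default orders the two calls diverge on matrices:
# torch.linalg.norm(exp_avg, ord=float('inf'))        -> matrix inf-norm (max absolute row sum)
# torch.linalg.vector_norm(exp_avg, ord=float('inf')) -> max absolute element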

deepspeed/runtime/zero/stage3.py

Lines changed: 1 addition & 1 deletion
@@ -2101,7 +2101,7 @@ def step(self, closure=None):
             return

         norm_groups = self._get_norm_groups()
-        scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))
+        scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups))

         # Stash unscaled gradient norm
         self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
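
The same rename as in the 1-bit LAMB change above. norm_groups holds one scalar norm per parameter group, and taking the 2-norm of the stacked scalars gives sqrt(sum_i norm_i**2), i.e. the norm of all gradients taken together. A small sanity check with toy gradients rather than the ZeRO internals:

import torch

grads = [torch.randn(10), torch.randn(20), torch.randn(5)]    # toy per-group gradients
norm_groups = [torch.linalg.vector_norm(g) for g in grads]    # one scalar norm per group

scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups))
reference = torch.linalg.vector_norm(torch.cat(grads))        # norm of everything at once
assert torch.allclose(scaled_global_grad_norm, reference)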

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 3 additions & 2 deletions
@@ -1691,7 +1691,8 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
                     continue
                 if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                     all_norms.append(
-                        torch.norm(g.data.double().detach(), norm_type).to(get_accelerator().current_device_name()))
+                        torch.linalg.vector_norm(g.data.double().detach(),
+                                                 ord=norm_type).to(get_accelerator().current_device_name()))
             if len(all_norms) > 0:
                 total_norm = torch.stack(all_norms).square().sum().float()
             else:

@@ -1795,7 +1796,7 @@ def scaled_global_norm(self, norm_type=2):
             self._average_expert_grad_norms(norm_groups)

         # calculating L2 norm
-        return torch.norm(torch.stack(norm_groups), p=norm_type)
+        return torch.linalg.vector_norm(torch.stack(norm_groups), ord=norm_type)

     def get_bit16_param_group(self, group_no):
         bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
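
Because these call sites forward a norm order, the migration also renames the argument: torch.norm accepted it as a positional p (or p=), while torch.linalg.vector_norm takes ord=. With no dim argument both flatten the tensor first, so the values agree for the numeric orders used here. A hedged sketch on a toy gradient:

import torch

g = torch.randn(32, 16).double()     # stand-in for g.data.double().detach()

for norm_type in (1.0, 2.0):
    old = torch.norm(g, norm_type)                    # deprecated API, positional p
    new = torch.linalg.vector_norm(g, ord=norm_type)  # replacement, keyword ord
    assert torch.allclose(old, new)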

deepspeed/sequence/layer.py

Lines changed: 70 additions & 59 deletions
@@ -16,6 +16,71 @@
 from deepspeed.utils import groups


+def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
+    """
+    This function generates the parameters required for `permute` and `reshape` operations,
+    which are used to process data before and after `all2all` communication.
+    """
+    if batch_dim_idx == 0:
+        if scatter_idx < 2:
+            bs, global_seq_len, num_local_head, head_dim = input.shape
+            pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
+            pre_all2all_permute_idx = (1, 0, 2, 3, 4)
+
+            post_all2all_permute_idx = (1, 2, 0, 3, 4)
+            post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
+        else:
+            bs, local_seq_len, num_total_head, head_dim = input.shape
+            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
+            pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
+            pre_all2all_permute_idx = (2, 0, 1, 3, 4)
+
+            post_all2all_permute_idx = (1, 0, 2, 3, 4)
+            post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
+    else:
+        if scatter_idx < 2:
+            global_seq_len, bs, num_local_head, head_dim = input.shape
+            pre_all2all_inp_shape = [seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim]
+            pre_all2all_permute_idx = None
+
+            post_all2all_permute_idx = (1, 2, 0, 3, 4)
+            post_all2all_res_shape = [bs, seq_world_size * global_seq_len, num_local_head // seq_world_size, head_dim]
+        else:
+            local_seq_len, bs, num_total_head, head_dim = input.shape
+            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
+            pre_all2all_inp_shape = [local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim]
+            pre_all2all_permute_idx = (2, 0, 1, 3, 4)
+            post_all2all_permute_idx = None
+            post_all2all_res_shape = [local_seq_len * seq_world_size, bs, num_total_head // seq_world_size, head_dim]
+
+    return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape
+
+
+def post_all2all(permute_idx, res_shape):
+    """
+    Post-processing function for `all2all` communication.
+    """
+
+    def post_func(input):
+        if permute_idx is not None:
+            input = input.permute(permute_idx).contiguous()
+        output = input.reshape(res_shape).contiguous()
+
+        return output
+
+    return post_func
+
+
+def pre_all2all_fun(permute_idx, inp_shape, input):
+    """
+    Pre-processing function for `all2all` communication.
+    """
+    input_t = input.reshape(inp_shape).contiguous()
+    if permute_idx is not None:
+        input_t = input_t.permute(permute_idx).contiguous()
+    return input_t
+
+
 def _rotate_half(x):
     """
     change sign so the last dimension becomes [-odd, +even]

@@ -43,32 +108,6 @@ def apply_rotary_pos_emb(t, freqs_cos, freqs_sin):
     return res


-def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim):
-
-    def post_func(input):
-        if batch_dim_idx == 0:
-            # b, s, n, h
-            if scatter_idx < 2:
-                output = input.permute(1, 2, 0, 3, 4).contiguous()
-                output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head,
-                                        head_dim).contiguous()
-            else:
-                output = input.permute(1, 0, 2, 3, 4).contiguous()
-                output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size,
-                                        head_dim).contiguous()
-        else:
-            # s, b, n, h
-            if scatter_idx < 2:
-                output = input.permute(1, 2, 0, 3, 4).contiguous()
-                output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head,
-                                        head_dim).contiguous()
-            else:
-                output = input.reshape(seq_len * seq_world_size, bs, num_head // seq_world_size, head_dim).contiguous()
-        return output
-
-    return post_func
-
-
 def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group):
     seq_world_size = dist.get_world_size(group)
     inp_shape = list(input.shape)

@@ -195,39 +234,12 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op
         assert async_op == False, "uneven head sp does not support async op"
         return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)

-    if batch_dim_idx == 0:
-        # b, s, n, h
-        if scatter_idx < 2:
-            bs, global_seq_len, num_local_head, head_dim = input.shape
-            input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head,
-                                     head_dim]).contiguous()
-            input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()
-        else:
-            bs, local_seq_len, num_total_head, head_dim = input.shape
-            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
-            input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size,
-                                     head_dim]).contiguous()
-            input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
-    else:
-        # s, b, n, h
-        if scatter_idx < 2:
-            global_seq_len, bs, num_local_head, head_dim = input.shape
-            input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head,
-                                     head_dim]).contiguous()
-        else:
-            local_seq_len, bs, num_total_head, head_dim = input.shape
-            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
-            input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size,
-                                     head_dim]).contiguous()
-            input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
+    pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = _generate_layout_params(
+        scatter_idx, batch_dim_idx, seq_world_size, input)

-    if scatter_idx < 2:
-        post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head,
-                                        head_dim)
-    else:
-        post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head,
-                                        head_dim)
+    input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)

+    post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
     output = torch.empty_like(input_t)
     work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)

@@ -236,7 +248,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op
             handle[type + '_work'] = work
             handle[type + '_grad'] = output
             handle[type + '_post_all2all_func'] = post_all2all_fun
-            return output
+            return output.view(post_all2all_res_shape)

     res = post_all2all_fun(output)
     return res

@@ -271,7 +283,6 @@ def forward(ctx: Any,
                 assert ctx.stream != None
                 res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
                 get_accelerator().current_stream().wait_stream(ctx.stream)
-                del ctx.stream.activation_buffer_list
                 # The computation of d o_weight can overlap with the communication of d o_input

             elif not is_fwd and type in ('q', 'k'):
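
This refactor moves the layout bookkeeping that single_all_to_all previously did inline into three module-level helpers: _generate_layout_params selects the reshape/permute parameters for a given scatter_idx and batch_dim_idx, pre_all2all_fun applies them before the collective, and post_all2all builds the function applied to the collective's output (the async path now returns output.view(post_all2all_res_shape) instead of the raw buffer). Since dist.all_to_all_single keeps the tensor shape, the helpers can be exercised without a process group. A hedged, single-process sketch, assuming the merged deepspeed.sequence.layer exposes these helpers and using a clone as a stand-in for the collective:

import torch
from deepspeed.sequence.layer import _generate_layout_params, post_all2all, pre_all2all_fun

seq_world_size = 4
bs, local_seq_len, num_total_head, head_dim = 2, 8, 12, 16
x = torch.randn(bs, local_seq_len, num_total_head, head_dim)   # [b, s, n, h] layout

# batch_dim_idx=0 with scatter_idx=2: scatter heads across ranks, gather the sequence.
perm, inp_shape, post_perm, res_shape = _generate_layout_params(scatter_idx=2,
                                                                batch_dim_idx=0,
                                                                seq_world_size=seq_world_size,
                                                                input=x)

x_t = pre_all2all_fun(perm, inp_shape, x)     # -> [sp, b, s_local, n // sp, h]
y = x_t.clone()                               # stand-in for dist.all_to_all_single(output, x_t, ...)
out = post_all2all(post_perm, res_shape)(y)   # -> [b, sp * s_local, n // sp, h]

assert list(out.shape) == [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]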
