
Commit

Remove assumption that padding only occurs on last rank (#6974)

As discussed in
[PR-6918](#6918), padding can
occur on multiple ranks with large DP degrees.

For example, with:
- Flattened tensor size: 266240
- DP degree: 768
- Alignment: 1536 
- Required padding: 1024 (1536 * 174 - 266240)
- Per-rank partition size: 348 (1536 * 174 / 768)
- The padding therefore falls on the last three ranks (328 + 348 + 348 = 1024).

This PR removes the single-rank padding assumption so that padding is recorded
correctly in such cases; a short sketch of the per-rank arithmetic follows below.
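
A minimal sketch (not part of the commit; all variable names are illustrative) that reproduces the arithmetic above and shows why the padding lands on several ranks:

```python
# Sketch only: reproduce the example numbers from this commit message.
orig_numel = 266240   # flattened tensor size
dp_degree = 768
alignment = 1536

# Round the flattened size up to the next multiple of the alignment.
padded_numel = ((orig_numel + alignment - 1) // alignment) * alignment  # 267264 = 1536 * 174
total_padding = padded_numel - orig_numel                               # 1024
partition_size = padded_numel // dp_degree                              # 348

# Rank r owns elements [r * partition_size, (r + 1) * partition_size);
# everything at or beyond orig_numel is padding.
per_rank_padding = [
    max(0, min(partition_size, (r + 1) * partition_size - orig_numel))
    for r in range(dp_degree)
]
assert sum(per_rank_padding) == total_padding
print([p for p in per_rank_padding if p > 0])  # [328, 348, 348] -> last three ranks
```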

---------

Co-authored-by: Sam Foreman <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
4 people authored Jan 31, 2025
1 parent c963c21 commit 4fea41f
Showing 1 changed file with 12 additions and 7 deletions.
deepspeed/runtime/zero/stage_1_and_2.py (12 additions, 7 deletions)

@@ -366,13 +366,6 @@ def __init__(self,

             see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False)

-            # Record padding required for alignment
-            if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
-                padding = self.bit16_groups_flat[i].numel() - orig_group_numel
-            else:
-                padding = 0
-            self.groups_padding.append(padding)
-
             if dist.get_rank(group=self.real_dp_process_group[i]) == 0:
                 see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False)

@@ -384,6 +377,18 @@ def __init__(self,
             data_parallel_partitions = self.get_data_parallel_partitions(self.bit16_groups_flat[i], i)
             self.parallel_partitioned_bit16_groups.append(data_parallel_partitions)

+            # Record padding required for alignment
+            left_boundary = sum([t.numel() for t in data_parallel_partitions[:partition_id]])
+            curr_partition_size = data_parallel_partitions[partition_id].numel()
+
+            if orig_group_numel <= left_boundary:
+                padding = curr_partition_size
+            elif orig_group_numel < left_boundary + curr_partition_size:
+                padding = left_boundary + curr_partition_size - orig_group_numel
+            else:
+                padding = 0
+            self.groups_padding.append(padding)
+
             # verify that data partition start locations are 4-byte aligned
             for partitioned_data in data_parallel_partitions:
                 assert (partitioned_data.data_ptr() % (2 * self.nccl_start_alignment_factor) == 0)
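
For reference, the new branch structure can be read as a standalone helper. This is a hypothetical rewrite for illustration only (the helper name and signature are not part of the commit); the example below uses equal-sized partitions, as in the commit message:

```python
# Hypothetical helper mirroring the new logic above: a rank's padding depends on
# where its partition sits relative to the original (unpadded) element count.
def rank_padding(orig_group_numel, partition_sizes, partition_id):
    left_boundary = sum(partition_sizes[:partition_id])
    curr_partition_size = partition_sizes[partition_id]

    if orig_group_numel <= left_boundary:
        # The partition starts at or beyond the real data: it is all padding.
        return curr_partition_size
    elif orig_group_numel < left_boundary + curr_partition_size:
        # The real data ends inside this partition: only the tail is padding.
        return left_boundary + curr_partition_size - orig_group_numel
    # The partition lies entirely within the real data: no padding.
    return 0

# The example from the commit message: 768 ranks, 348 elements per rank.
sizes = [348] * 768
pads = [rank_padding(266240, sizes, r) for r in range(768)]
assert pads[-3:] == [328, 348, 348] and sum(pads) == 1024
```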
