[shardformer]: support gpt-j, falcon and add interleaved pipeline for bert #5088

Merged
merged 27 commits into from
Nov 28, 2023
Changes from all commits
Commits
27 commits
7e81ba2
[shardformer] GPT-J policy dev 5 Sep
ppt0011 Sep 5, 2023
0204ef3
[shardformer] implement shard policy for base gpt-j model
ppt0011 Sep 6, 2023
472faa5
[shardformer] implement shard policy for base gpt-j model 06 Sep
ppt0011 Sep 6, 2023
a879ccb
[shardformer] implement policy for all GPT-J models and test
ppt0011 Sep 7, 2023
d0e6939
[shardformer] test GPT-J sharding policy
ppt0011 Sep 8, 2023
1f237cc
[shardformer] finished testing pp gpt-j sharding
ppt0011 Sep 9, 2023
c013a0d
[shardformer] clean up for pr
ppt0011 Sep 11, 2023
ec9cba2
[shardformer] support all GPT-J shard former techniques except lazy init
ppt0011 Sep 11, 2023
dde39a4
[shardformer] support interleaved pipeline parallel for bert finetune…
ppt0011 Sep 26, 2023
eb496e0
[shardformer] generlize interleave solution for non bert model 26 Sep
ppt0011 Sep 26, 2023
f631c3e
[shardformer] interleaved pipeline parallel for bert fine tune example
ppt0011 Sep 27, 2023
ad33868
[shardformer] refactor interleave implementation so that replaced for…
ppt0011 Sep 29, 2023
de53d0e
[shardformer] move layer attr to stage manager, and style changes
ppt0011 Oct 11, 2023
341effc
[shardformer] sync gptj config with hf due to flash attn head dim req…
ppt0011 Oct 12, 2023
a0b7877
Merge pull request #4825 from ppt0011/feature/shardformer
ppt0011 Oct 13, 2023
29ebe11
[shardformer] shardformer support falcon (#4883)
flybird11111 Oct 17, 2023
013fb01
[shardformer] increase micro batch size due to convergence issue
ppt0011 Oct 31, 2023
53fe53c
increase ci timeout time after discussion
ppt0011 Nov 1, 2023
0995bba
Merge pull request #4834 from ppt0011/bert-finetune
ppt0011 Nov 2, 2023
ae187e0
[shardformer]: fix interleaved pipeline for bert model (#5048)
CWHer Nov 17, 2023
9d5e04d
[hotfix]: disable seq parallel for gptj and falcon, and polish code (…
CWHer Nov 22, 2023
2e04af1
Add Mistral support for Shardformer (#5103)
eric8607242 Nov 24, 2023
c8420cd
[shardformer] add tests to mistral (#5105)
flybird11111 Nov 26, 2023
2cca8b7
[hotfix]: polish code and use packaging.version (#5112)
CWHer Nov 27, 2023
08e868e
[hotfix]: remove unused code and deprecate colo_init_context (#5117)
CWHer Nov 28, 2023
694adb3
[hotfix]: add memory utils to avoid import legacy module (#5123)
CWHer Nov 28, 2023
a12418c
[hotfix]: add NPU TODO mark (#5124)
CWHer Nov 28, 2023
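
The commits above add an interleaved ("virtual") pipeline schedule alongside the existing 1F1B schedule, plus shard policies for GPT-J, Falcon, and Mistral. As a purely illustrative sketch of what num_model_chunks controls in an interleaved schedule (this is not code from the PR, and the function name and layout are assumptions), one common round-robin layout assigns each pipeline stage several non-contiguous groups of layers:

# Illustrative only: a round-robin chunk layout often used for interleaved
# pipeline parallelism; not ColossalAI's actual implementation.
def interleaved_layer_groups(num_layers: int, pp_size: int, num_model_chunks: int, stage: int) -> list:
    """Return the layer index ranges owned by `stage` when the model is split
    into pp_size * num_model_chunks chunks and dealt out round-robin."""
    assert num_layers % (pp_size * num_model_chunks) == 0, "sketch assumes an even split"
    layers_per_chunk = num_layers // (pp_size * num_model_chunks)
    groups = []
    for chunk in range(num_model_chunks):
        start = (chunk * pp_size + stage) * layers_per_chunk
        groups.append(range(start, start + layers_per_chunk))
    return groups

# 16 layers, 4 pipeline stages, 2 model chunks:
# stage 0 owns layers 0-1 (chunk 0) and layers 8-9 (chunk 1).
print(interleaved_layer_groups(16, pp_size=4, num_model_chunks=2, stage=0))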
2 changes: 1 addition & 1 deletion .github/workflows/example_check_on_pr.yml
@@ -79,7 +79,7 @@ jobs:
    container:
      image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
      options: --gpus all --rm -v /data/scratch/examples-data:/data/
-   timeout-minutes: 10
+   timeout-minutes: 20
    concurrency:
      group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }}
      cancel-in-progress: true
33 changes: 29 additions & 4 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -20,7 +20,7 @@
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
@@ -317,6 +317,8 @@ class HybridParallelPlugin(PipelinePluginBase):
        communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
        overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
        custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
+       pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
+       num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
    """

    def __init__(
@@ -352,6 +354,8 @@ def __init__(
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = True,
        custom_policy: Policy = None,
+       pp_style: str = "1f1b",
+       num_model_chunks: int = 1,
    ) -> None:
        super().__init__()
        assert (
@@ -378,17 +382,38 @@ def __init__(
        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:
+           assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
+           assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
            assert (
                num_microbatches is not None or microbatch_size is not None
            ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
            assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
-           self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
-           self.schedule = OneForwardOneBackwardSchedule(
-               self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+           self.stage_manager = PipelineStageManager(
+               self.pg_mesh,
+               pipeline_axis=PP_AXIS,
+               enable_interleave=pp_style == "interleaved",
+               num_model_chunks=num_model_chunks,
            )

+           if pp_style == "interleaved":
+               assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
+               self.schedule = InterleavedSchedule(
+                   stage_manager=self.stage_manager,
+                   num_model_chunks=num_model_chunks,
+                   num_microbatch=num_microbatches,
+                   microbatch_size=microbatch_size,
+               )
+           elif pp_style == "1f1b":
+               self.schedule = OneForwardOneBackwardSchedule(
+                   self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+               )
+           else:
+               raise NotImplementedError()

        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
        self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)

        self.shard_config = ShardConfig(
            tensor_parallel_process_group=self.tp_group,
            pipeline_stage_manager=self.stage_manager,
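
Taken together, these changes let the plugin build either schedule. A minimal usage sketch of the new arguments follows; the tp_size/pp_size values, the Booster wiring, and the commented boost() call are illustrative assumptions and not part of this diff:

# Hypothetical configuration showing the new pp_style / num_model_chunks knobs.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,                 # tensor parallel degree (pre-existing argument)
    pp_size=2,                 # pipeline parallel degree (pre-existing argument)
    pp_style="interleaved",    # new in this PR; "1f1b" remains the default
    num_model_chunks=2,        # must be > 1 when pp_style == "interleaved"
    num_microbatches=4,        # required whenever pp_size > 1
    zero_stage=1,              # pipeline parallelism allows zero stage 0 or 1 only
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
#     model, optimizer, criterion=criterion, dataloader=dataloader, lr_scheduler=lr_scheduler
# )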
3 changes: 3 additions & 0 deletions colossalai/legacy/zero/gemini/__init__.py
@@ -1,3 +1,4 @@
+from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
from .ophooks import BaseOpHook, register_ophooks_recursively
from .stateful_tensor import StatefulTensor
from .stateful_tensor_mgr import StatefulTensorMgr
@@ -11,4 +12,6 @@
    "AutoTensorPlacementPolicy",
    "register_ophooks_recursively",
    "BaseOpHook",
+   "ColoInitContext",
+   "post_process_colo_init_ctx",
]
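
This re-export keeps the now-deprecated ColoInitContext importable from its legacy location. A minimal sketch of the updated import path; the context-manager usage is an assumption carried over from older ColossalAI examples rather than something shown in this diff:

# Assumed usage: ColoInitContext as a context manager, per earlier ColossalAI
# examples; this PR only moves and deprecates it, so prefer the non-legacy
# initialization path in new code.
from colossalai.legacy.zero.gemini import ColoInitContext, post_process_colo_init_ctx

with ColoInitContext():
    pass  # construct the model here so its parameters are created under the context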