Skip to content

Commit

Permalink
[fp8] support gemini plugin (hpcaitech#5978)
Browse files Browse the repository at this point in the history
* [fp8] refactor hook

* [fp8] support gemini plugin

* [example] add fp8 option for llama benchmark
  • Loading branch information
ver217 authored Aug 9, 2024
1 parent 4b9bec8 commit 8241c0c
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 7 deletions.
2 changes: 2 additions & 0 deletions colossalai/booster/plugin/gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def __init__(
enable_jit_fused: bool = False,
enable_sequence_overlap: bool = False,
enable_async_reduce: bool = True,
use_fp8: bool = False,
verbose: bool = False,
fp8_communication: bool = False,
) -> None:
Expand Down Expand Up @@ -397,6 +398,7 @@ def __init__(
max_prefetch=max_prefetch,
enable_async_reduce=enable_async_reduce,
fp8_communication=fp8_communication,
use_fp8=use_fp8,
)
self.zero_optim_config = dict(
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
Expand Down
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
Expand All @@ -40,7 +41,6 @@
from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle

from .fp8_hook import FP8Hook
from .pp_plugin_base import PipelinePluginBase

SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
Expand Down
4 changes: 2 additions & 2 deletions colossalai/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,5 +652,5 @@ def backward(ctx: Any, out_grad) -> Any:
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad


def linear_fp8(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(x, w, bias)
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)
File renamed without changes.
11 changes: 8 additions & 3 deletions colossalai/zero/gemini/gemini_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.d_tensor import (
distribute_tensor,
Expand Down Expand Up @@ -99,6 +100,7 @@ def __init__(
verbose: bool = False,
enable_async_reduce: bool = True,
fp8_communication: bool = False,
use_fp8: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
Expand Down Expand Up @@ -138,6 +140,9 @@ def __init__(
)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.hooks = [self.param_op_hook]
if use_fp8:
self.hooks.append(FP8Hook())
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
Expand Down Expand Up @@ -310,7 +315,7 @@ def forward(self, *args, **kwargs):
outputs = self._inference_forward(*args, **kwargs)
else:
self.gemini_manager.pre_iter(*args)
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
with ColoParamOpHookManager.use_hooks(*self.hooks):
outputs = self.module(*args, **kwargs)

if self.force_outputs_fp32:
Expand All @@ -319,7 +324,7 @@ def forward(self, *args, **kwargs):

def _inference_forward(self, *args, **kwargs):
"""This function is only triggered for inference."""
fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
fwd_ctx = ColoParamOpHookManager.use_hooks(*self.hooks)
if not self.scatter_after_inference:
# gather all chunks
for chunk in self.chunk_manager.get_chunks(self.fp16_params):
Expand Down Expand Up @@ -372,7 +377,7 @@ def _post_backward(self):

def backward(self, loss: torch.Tensor):
self._pre_backward()
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(*self.hooks):
loss.backward()
self._post_backward()

Expand Down
7 changes: 7 additions & 0 deletions examples/language/llama/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def main():
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
parser.add_argument("--overlap_allgather", action="store_true")
parser.add_argument("--use_fp8", action="store_true")
args = parser.parse_args()

colossalai.launch_from_torch()
Expand Down Expand Up @@ -136,6 +138,7 @@ def empty_init():
enable_flash_attention=args.xformers,
max_prefetch=args.prefetch_num,
enable_async_reduce=not args.disable_async_reduce,
use_fp8=args.use_fp8,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
Expand All @@ -148,6 +151,7 @@ def empty_init():
max_prefetch=args.prefetch_num,
enable_async_reduce=not args.disable_async_reduce,
enable_flash_attention=args.xformers,
use_fp8=args.use_fp8,
)
elif args.plugin == "fsdp":
if use_empty_init:
Expand Down Expand Up @@ -207,6 +211,8 @@ def empty_init():
dp_outside=False,
overlap_p2p=args.overlap,
enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
Expand All @@ -223,6 +229,7 @@ def empty_init():
initial_scale=2**8,
precision="bf16",
overlap_p2p=args.overlap,
use_fp8=args.use_fp8,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fp8/test_fp8_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import torch.nn.functional as F

from colossalai.accelerator import get_accelerator
from colossalai.booster.plugin.fp8_hook import FP8Hook
from colossalai.quantization.fp8 import linear_fp8
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device
Expand Down

0 comments on commit 8241c0c

Please sign in to comment.