From e4aadeee20b248dd68b5aa7447b8eb1f127a259b Mon Sep 17 00:00:00 2001
From: botbw
Date: Fri, 9 Aug 2024 15:51:06 +0800
Subject: [PATCH] [fp8] use torch compile (torch >= 2.3.0) (#5979)

* [fp8] use torch compile (torch >= 2.4.0)

* [fp8] set use_fast_accum in linear

* [chore] formal version check

* [chore] fix sig
---
 colossalai/quantization/fp8.py | 27 ++++++++++++++++++++++-----
 1 file changed, 22 insertions(+), 5 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 53febd16c8f6..cfbf1fcf7e40 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -1,13 +1,14 @@
-from typing import Any, Optional
+from typing import Any, Optional, Tuple
 
 import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+from packaging.version import Version
 from torch.distributed import ReduceOp
 
 
-def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor):
+def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
@@ -624,7 +625,13 @@ def forward(
         ctx.inv_scale_x = inv_scale_x
         ctx.inv_scale_w = inv_scale_w
         out = torch._scaled_mm(
-            x_fp8, ctx.w_fp8_t, bias=bias, out_dtype=ctx.out_dtype, scale_a=inv_scale_x, scale_b=inv_scale_w
+            x_fp8,
+            ctx.w_fp8_t,
+            bias=bias,
+            out_dtype=ctx.out_dtype,
+            scale_a=inv_scale_x,
+            scale_b=inv_scale_w,
+            use_fast_accum=True,
         )[0]
         return out.reshape(*ctx.x_shape[:-1], w.shape[0])
 
@@ -638,6 +645,7 @@ def backward(ctx: Any, out_grad) -> Any:
             out_dtype=ctx.out_dtype,
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_w,
+            use_fast_accum=True,
         )[0]
         w_grad = torch._scaled_mm(
             out_grad_fp8.t().contiguous(),
@@ -645,6 +653,7 @@ def backward(ctx: Any, out_grad) -> Any:
             out_dtype=ctx.out_dtype,
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_x,
+            use_fast_accum=True,
        )[0]
         bias_grad = None
         if ctx.has_bias:
@@ -652,5 +661,13 @@ def backward(ctx: Any, out_grad) -> Any:
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    return _LinearFp8.apply(input, weight, bias)
+if Version(torch.__version__) >= Version("2.3.0"):  # TODO failed on torch < 2.3.0
+
+    @torch.compile(mode="reduce-overhead", fullgraph=True)
+    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return _LinearFp8.apply(input, weight, bias)
+
+else:
+
+    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return _LinearFp8.apply(input, weight, bias)
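
Usage note (not part of the patch): a minimal sketch of calling the compiled FP8 linear introduced above. The import path comes from the patch; the capability check, shapes, and dtypes are illustrative assumptions, since torch._scaled_mm needs an FP8-capable GPU (compute capability >= 8.9) and dimensions that satisfy its alignment rules.

import torch

from colossalai.quantization.fp8 import linear_fp8

# Assumption: FP8 matmul requires an FP8-capable GPU (e.g. H100/Ada, compute capability >= 8.9).
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9):
    # Illustrative shapes; all dimensions are multiples of 16 to satisfy
    # torch._scaled_mm's alignment requirements.
    x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)  # (out_features, in_features)
    b = torch.randn(128, device="cuda", dtype=torch.bfloat16)

    # On torch >= 2.3.0 the first call triggers torch.compile tracing
    # (reduce-overhead mode); later calls with the same shapes reuse the
    # compiled graph. On older torch the eager fallback is used.
    y = linear_fp8(x, w, b)
    print(y.shape)  # expected: torch.Size([16, 128])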