[fp8] use torch compile (torch >= 2.3.0) (hpcaitech#5979)
* [fp8] use torch compile (torch >= 2.4.0)

* [fp8] set use_fast_accum in linear

* [chore] formal version check

* [chore] fix sig
botbw authored Aug 9, 2024
1 parent 8241c0c commit e4aadee
Showing 1 changed file with 22 additions and 5 deletions.
colossalai/quantization/fp8.py (27 changes: 22 additions & 5 deletions)
@@ -1,13 +1,14 @@
-from typing import Any, Optional
+from typing import Any, Optional, Tuple
 
 import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+from packaging.version import Version
 from torch.distributed import ReduceOp
 
 
-def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor):
+def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
@@ -624,7 +625,13 @@ def forward(
         ctx.inv_scale_x = inv_scale_x
         ctx.inv_scale_w = inv_scale_w
         out = torch._scaled_mm(
-            x_fp8, ctx.w_fp8_t, bias=bias, out_dtype=ctx.out_dtype, scale_a=inv_scale_x, scale_b=inv_scale_w
+            x_fp8,
+            ctx.w_fp8_t,
+            bias=bias,
+            out_dtype=ctx.out_dtype,
+            scale_a=inv_scale_x,
+            scale_b=inv_scale_w,
+            use_fast_accum=True,
         )[0]
         return out.reshape(*ctx.x_shape[:-1], w.shape[0])
 
@@ -638,19 +645,29 @@ def backward(ctx: Any, out_grad) -> Any:
             out_dtype=ctx.out_dtype,
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_w,
+            use_fast_accum=True,
         )[0]
         w_grad = torch._scaled_mm(
             out_grad_fp8.t().contiguous(),
             ctx.x_fp8.t().contiguous().t(),
             out_dtype=ctx.out_dtype,
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_x,
+            use_fast_accum=True,
         )[0]
         bias_grad = None
         if ctx.has_bias:
             bias_grad = out_grad.sum(0)
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    return _LinearFp8.apply(input, weight, bias)
+if Version(torch.__version__) >= Version("2.3.0"):  # TODO failed on torch < 2.3.0
+
+    @torch.compile(mode="reduce-overhead", fullgraph=True)
+    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return _LinearFp8.apply(input, weight, bias)
+
+else:
+
+    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return _LinearFp8.apply(input, weight, bias)
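
For reference, a minimal usage sketch (not part of this commit) of the compiled linear_fp8 path. It assumes a CUDA device with fp8 support (e.g. H100) and shapes that satisfy torch._scaled_mm's alignment requirements; the tensor names and sizes below are illustrative only:

import torch

from colossalai.quantization.fp8 import linear_fp8

# Illustrative shapes; fp8 matmuls generally need dimensions that are multiples of 16.
x = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(128, device="cuda", dtype=torch.bfloat16)

out = linear_fp8(x, w, b)  # first call triggers torch.compile on torch >= 2.3.0
out.sum().backward()       # backward pass also runs the fp8 _scaled_mm kernels

With mode="reduce-overhead", torch.compile can use CUDA graphs to cut kernel launch overhead, and fullgraph=True asserts the function compiles without graph breaks; on torch < 2.3.0 the uncompiled fallback keeps the same signature.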
