Commit: Fix lint issues

Xia-Weiwen committed Apr 25, 2024
1 parent 717245d commit 93e04b5

Showing 7 changed files with 42 additions and 56 deletions.
5 changes: 3 additions & 2 deletions bitsandbytes/__init__.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
+
 from . import research, utils
 from .autograd._functions import (
     MatmulLtState,
@@ -13,11 +14,11 @@
     matmul_cublas,
     mm_cublas,
 )
+from .backends import register_backend
+from .backends.cpu import CPUBackend
 from .cextension import lib
 from .nn import modules
-from .backends import register_backend

-from .backends.cpu import CPUBackend
 register_backend("cpu", CPUBackend)

 if lib and lib.compiled_with_cuda:
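Note on the reordered imports above: the commit groups the backend imports together and registers the CPU backend at import time. A minimal sketch of that registration pattern, with illustrative names rather than the exact bitsandbytes internals:

# Sketch of a device-keyed backend registry; `_backends` and the final
# assert are illustrative, not bitsandbytes' exact code.
_backends = {}

def register_backend(device_type: str, backend_class) -> None:
    # Later dispatch keys off tensor.device.type ("cpu", "cuda", ...).
    _backends[device_type.lower()] = backend_class

class CPUBackend:  # stand-in for bitsandbytes.backends.cpu.CPUBackend
    pass

register_backend("cpu", CPUBackend)
assert _backends["cpu"] is CPUBackend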
4 changes: 2 additions & 2 deletions bitsandbytes/autograd/_functions.py
@@ -217,7 +217,7 @@ def backward(ctx, grad_output):

 def supports_igemmlt(device: torch.device) -> bool:
     """check if this device supports the optimized int8 kernel"""
-    if device == torch.device('cpu'):
+    if device == torch.device("cpu"):
         return True
     if torch.cuda.get_device_capability(device=device) < (7, 5):
         return False
@@ -315,7 +315,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):

         # Cast A to fp16
         A_dtype = torch.float16
-        if A.device == torch.device('cpu'):
+        if A.device == torch.device("cpu"):
             A_dtype = torch.bfloat16
         if A.dtype != A_dtype:
             warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization")
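Context for the quote fix above: MatMul8bitLt casts inputs to fp16 on CUDA but to bf16 on CPU, where fp16 kernels are poorly supported. A standalone sketch of that branch (the helper name is hypothetical):

import torch

def _matmul_cast_dtype(A: torch.Tensor) -> torch.dtype:
    # Hypothetical helper mirroring the branch in MatMul8bitLt.forward:
    # bf16 for CPU tensors, fp16 otherwise.
    if A.device == torch.device("cpu"):
        return torch.bfloat16
    return torch.float16

assert _matmul_cast_dtype(torch.randn(2, 2)) is torch.bfloat16  # CPU tensor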
31 changes: 11 additions & 20 deletions bitsandbytes/backends/cpu.py
@@ -1,23 +1,24 @@
 import torch

 from .cpu_xpu_common import (
     double_quant_impl,
     igemmlt_impl,
     mm_dequant_impl,
 )


 Tensor = torch.Tensor


 def assert_on_cpu(tensors):
     on_cpu = True
     for t in tensors:
-        if t is None: continue # NULL pointers are fine
-        on_cpu &= (t.device.type == 'cpu')
+        if t is None:
+            continue  # NULL pointers are fine
+        on_cpu &= t.device.type == "cpu"
     if not on_cpu:
         raise TypeError(
-            'All input tensors need to be on CPU, but found some tensors to not be on CPU:\n' \
-            f' {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}'
+            "All input tensors need to be on CPU, but found some tensors to not be on CPU:\n"
+            f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}"
         )
     return on_cpu
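A quick usage sketch of the reformatted assert_on_cpu helper, assuming the definition as shown above (the meta-device tensor is just a convenient non-CPU example):

import torch

x = torch.randn(2, 2)              # CPU tensor
assert_on_cpu([x, None])           # passes; None entries are skipped
try:
    assert_on_cpu([x.to("meta")])  # any non-CPU device type raises
except TypeError as err:
    print(err)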

@@ -27,14 +28,12 @@ class CPUBackend:
     mm_dequant_output_dtype = torch.bfloat16

     @classmethod
-    def double_quant(
-        cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
-    ):
+    def double_quant(cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
         assert_on_cpu([A, col_stats, row_stats, out_col, out_row])
         return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)

     @classmethod
-    def transform(cls, A, to_order=None, from_order='row', out=None, transpose=False, state=None, ld=None):
+    def transform(cls, A, to_order=None, from_order="row", out=None, transpose=False, state=None, ld=None):
         """
         Transform tensor A to to_order. It is originally designed for CUDA.
         For CPU, it returns the original tensor if transpose=False.
@@ -60,15 +59,7 @@ def igemmlt(cls, A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32)

     @classmethod
     def mm_dequant(
-        cls,
-        A,
-        quant_state,
-        row_stats,
-        col_stats,
-        out=None,
-        new_row_stats=None,
-        new_col_stats=None,
-        bias=None
+        cls, A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None
     ):
         assert_on_cpu([A, row_stats, col_stats, out, bias])
         return mm_dequant_impl(
@@ -81,7 +72,7 @@ def mm_dequant(
             new_col_stats,
             bias,
             cls.mm_dequant_compute_dtype,
-            cls.mm_dequant_output_dtype
+            cls.mm_dequant_output_dtype,
         )

     @classmethod
@@ -108,7 +99,7 @@ def quantize_4bit(
     def dequantize_4bit(
         cls,
         A: Tensor,
-        quant_state = None,
+        quant_state=None,
         absmax: Tensor = None,
         out: Tensor = None,
         blocksize: int = 64,
43 changes: 20 additions & 23 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -1,13 +1,13 @@
-import torch
 import warnings

+import torch
+
 Tensor = torch.Tensor


 def _torch_version_prereq(major, minor):
-    ver_major = int(torch.__version__.split('.')[0])
-    ver_minor = int(torch.__version__.split('.')[1])
+    ver_major = int(torch.__version__.split(".")[0])
+    ver_minor = int(torch.__version__.split(".")[1])
     return ver_major * 32 + ver_minor >= major * 32 + minor
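The comparison in _torch_version_prereq packs (major, minor) into a single integer, major * 32 + minor, so one arithmetic comparison replaces a tuple comparison; this holds as long as minor versions stay below 32. An equivalent standalone check:

import torch

def torch_version_at_least(major: int, minor: int) -> bool:
    # Same packing trick as _torch_version_prereq above.
    ver_major, ver_minor = (int(v) for v in torch.__version__.split(".")[:2])
    return ver_major * 32 + ver_minor >= major * 32 + minor

print(torch_version_at_least(2, 0))  # True on any PyTorch 2.x install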


@@ -23,14 +23,12 @@ def _maybe_torch_compile(func):


 # Don't use torch.compile for now due to PyTorch issue https://github.com/pytorch/pytorch/issues/124382
-def double_quant_impl(
-    A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
-):
+def double_quant_impl(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
     """
     Find absolute max values of each row/column of a tensor, and symmetrically quantize it to int8.
     If threshold > 0.0, only values <= threshold are counted. All outliers are zeroed out in
     the original tensor and they are kept in COO format: (rows, cols, values)
-    If threashold == 0.0, there are no outliers.
+    If threshold == 0.0, there are no outliers.
     Args:
     A The tensor to be analyzed and quantized.
     col_stats Absolute max values of each column of A. If it is not None, use the values directly.
@@ -45,6 +43,7 @@ def double_quant_impl(
         each row of A, absolute max values of each column of A, outliers in COO format
     """
     from ..functional import COOSparseTensor
+
     cols = A.shape[-1]
     if len(A.shape) == 3:
         rows = A.shape[0] * A.shape[1]
@@ -56,8 +55,8 @@
     coo_tensor = None

     def get_row_col_stats(A):
-        row_stats = torch.max(torch.abs(A), 1).values # absolute max of each row
-        col_stats = torch.max(torch.abs(A), 0).values # absolute max of each col
+        row_stats = torch.max(torch.abs(A), 1).values  # absolute max of each row
+        col_stats = torch.max(torch.abs(A), 0).values  # absolute max of each col
         return row_stats, col_stats

     def quant_to_int8(A, stats):
@@ -67,23 +66,23 @@ def quant_to_int8(A, stats):
         if row_stats is None or col_stats is None:
             row_stats, col_stats = get_row_col_stats(A)
     else:
-        outlier_indices = torch.abs(A) >= threshold # find outliers
-        outlier_coord = outlier_indices.nonzero() # get outlier coordinates
-        outlier_rows = outlier_coord[:, 0] # outlier row for COO sparse tensor
-        outlier_cols = outlier_coord[:, 1] # outlier column for COO sparse tensor
-        outlier_values = A[outlier_indices] # outlier values for COO sparse tensor
+        outlier_indices = torch.abs(A) >= threshold  # find outliers
+        outlier_coord = outlier_indices.nonzero()  # get outlier coordinates
+        outlier_rows = outlier_coord[:, 0]  # outlier row for COO sparse tensor
+        outlier_cols = outlier_coord[:, 1]  # outlier column for COO sparse tensor
+        outlier_values = A[outlier_indices]  # outlier values for COO sparse tensor
         coo_tensor = COOSparseTensor(
             A.shape[0], A.shape[1], outlier_values.numel(), outlier_rows.int(), outlier_cols.int(), outlier_values
         )
         if row_stats is None or col_stats is None:
-            A[outlier_indices] = 0 # zero out outliers
+            A[outlier_indices] = 0  # zero out outliers
             row_stats, col_stats = get_row_col_stats(A)

     quant_by_row = quant_to_int8(A, row_stats.unsqueeze(-1))
     quant_by_col = quant_to_int8(A, col_stats.unsqueeze(0))

     if coo_tensor is not None:
-        A[outlier_indices] = outlier_values # restore outliers for later use
+        A[outlier_indices] = outlier_values  # restore outliers for later use

     if out_row is not None:
         out_row.copy_(quant_by_row)
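The flow above is symmetric row-wise and column-wise int8 quantization with optional outlier extraction. The core scaling fits in a few lines (a conceptual sketch, not the bitsandbytes API; randn makes all-zero rows vanishingly unlikely, so no zero-division guard is shown):

import torch

A = torch.randn(4, 8)
row_stats = A.abs().amax(dim=1, keepdim=True)  # per-row absmax
col_stats = A.abs().amax(dim=0, keepdim=True)  # per-column absmax
quant_by_row = torch.round(A * 127.0 / row_stats).to(torch.int8)
quant_by_col = torch.round(A * 127.0 / col_stats).to(torch.int8)
assert quant_by_row.abs().max() <= 127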
@@ -97,9 +96,7 @@ def quant_to_int8(A, stats):
     return out_row, out_col, row_stats.float(), col_stats.float(), coo_tensor


-def igemmlt_impl(
-    A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32
-):
+def igemmlt_impl(A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32):
     """
     Do GEMM computation. Data type: int8 * int8 -> int32.
     Args:
@@ -122,16 +119,16 @@ def igemmlt_impl(
     dimsB = B.ndim
     shapeA = A.shape
     shapeB = B.shape
-    assert dimsA in [2, 3], 'Only two or three dimensional matrices are supported for argument A'
-    assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
+    assert dimsA in [2, 3], "Only two or three dimensional matrices are supported for argument A"
+    assert dimsB == 2, "Only two dimensional matrices are supported for argument B"

     if dimsA == 2:
         m = shapeA[0]
     elif dimsA == 3:
         m = shapeA[0] * shapeA[1]
     n = shapeB[0]
     k = shapeA[-1]
-    assert shapeA[-1] == shapeB[-1], f'Shapes of A and B do not match, got {shapeA} and {shapeB}'
+    assert shapeA[-1] == shapeB[-1], f"Shapes of A and B do not match, got {shapeA} and {shapeB}"

     # if the tensor is empty, return a transformed empty tensor with the right dimensions
     if shapeA[0] == 0 and dimsA == 2:
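The int8 x int8 -> int32 contract that igemmlt_impl implements can be reproduced with plain PyTorch for small shapes (a sketch of the semantics only; the real path uses an optimized kernel):

import torch

A = torch.randint(-128, 128, (3, 4), dtype=torch.int8)  # (m, k)
B = torch.randint(-128, 128, (5, 4), dtype=torch.int8)  # (n, k); shared last dim, as asserted above
C = torch.matmul(A.to(torch.int32), B.to(torch.int32).t())  # (m, n), int32 accumulation
assert C.dtype == torch.int32 and C.shape == (3, 5)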
@@ -169,7 +166,7 @@ def mm_dequant_impl(
     new_col_stats=None,
     bias=None,
     compute_dtype=torch.float32,
-    output_dtype=torch.float32
+    output_dtype=torch.float32,
 ):
     """
     Dequant and add bias
2 changes: 1 addition & 1 deletion bitsandbytes/functional.py
@@ -1945,7 +1945,7 @@ class COOSparseTensor:
     def __init__(self, rows, cols, nnz, rowidx, colidx, values):
         assert rowidx.dtype == torch.int32
         assert colidx.dtype == torch.int32
-        if values.device == torch.device('cpu'):
+        if values.device == torch.device("cpu"):
             assert values.dtype in [torch.bfloat16, torch.float]
         else:
             assert values.dtype == torch.float16
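The dtype asserts above guard the container for outliers in coordinate (COO) format: parallel row-index, column-index, and value arrays. The layout can be illustrated with torch's built-in sparse COO tensors (illustrative only; bitsandbytes' COOSparseTensor is its own lightweight class built on the same idea):

import torch

rowidx = torch.tensor([0, 1], dtype=torch.int32)
colidx = torch.tensor([2, 0], dtype=torch.int32)
values = torch.tensor([3.5, -1.25])  # float values on CPU, per the assert above
dense = torch.sparse_coo_tensor(
    torch.stack([rowidx.long(), colidx.long()]), values, size=(2, 3)
).to_dense()
print(dense)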
9 changes: 3 additions & 6 deletions bitsandbytes/nn/modules.py
@@ -590,8 +590,8 @@ def cpu(self):
         if SCBt is not None:
             del SCBt
         self.data = CB
-        setattr(self, "CB", CB)
-        setattr(self, "SCB", SCB)
+        self.CB = CB
+        self.SCB = SCB
         return self

     @overload
@@ -613,10 +613,7 @@ def to(self, *args, **kwargs):

         if device.type == "cuda" and self.data.device.type == "cpu":
             return self.cuda(device)
-        elif (
-            device.type == "cpu"
-            and self.data.dtype != torch.int8
-        ):
+        elif device.type == "cpu" and self.data.dtype != torch.int8:
             return self.cpu()
         else:
             new_param = Int8Params(
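The flattened elif makes the three-way routing in Int8Params.to() easier to follow. A schematic of the control flow (a standalone function returning labels, not the class's actual behavior):

import torch

def route_to(data: torch.Tensor, target: torch.device) -> str:
    # Schematic of the branches in Int8Params.to() above.
    if target.type == "cuda" and data.device.type == "cpu":
        return "self.cuda(device): quantize while moving to GPU"
    elif target.type == "cpu" and data.dtype != torch.int8:
        return "self.cpu(): quantize for the CPU backend"
    else:
        return "fallback: construct a new Int8Params"

print(route_to(torch.randn(2, 2), torch.device("cpu")))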
4 changes: 2 additions & 2 deletions tests/test_functional.py
@@ -928,7 +928,7 @@ def test_colrow_absmax(dim1, dim2, dims):

 @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2"))
-@pytest.mark.parametrize("device", ["cuda","cpu"], ids=id_formatter("device"))
+@pytest.mark.parametrize("device", ["cuda", "cpu"], ids=id_formatter("device"))
 @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16], ids=id_formatter("dtype"))
 def test_double_quant(dim1, dim2, device, dtype):
     if device == "cuda" and dtype == torch.bfloat16:
@@ -1208,7 +1208,7 @@ def test_overflow():

 @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2"))
-@pytest.mark.parametrize("device", ["cuda","cpu"], ids=id_formatter("device"))
+@pytest.mark.parametrize("device", ["cuda", "cpu"], ids=id_formatter("device"))
 @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16], ids=id_formatter("dtype"))
 def test_coo_double_quant(dim1, dim2, device, dtype):
     if device == "cuda" and dtype == torch.bfloat16: