fix the contiguous kernel error and prep for eora release
nbasyl committed Mar 3, 2025
1 parent 1ef13b1 commit bf9aa2f
Showing 11 changed files with 13 additions and 141 deletions.
51 changes: 0 additions & 51 deletions eora_no_bug.py

This file was deleted.

File renamed without changes.
6 changes: 3 additions & 3 deletions gptqmodel/nn_modules/qlinear/exllama_eora.py
@@ -45,8 +45,8 @@ def gptq_gemm(x, qweight, qzeros, scales, g_idx, bit):
    return gptqmodel_exllama_eora.gptq_gemm(x, qweight, qzeros, scales, g_idx, True, bit)


-def gptq_gemm_lora(x, qweight, qzeros, scales, g_idx, bit, A, B):
-    return gptqmodel_exllama_eora.gptq_gemm_lora(x, qweight, qzeros, scales, g_idx, True, bit, A, B)
+def gptq_gemm_eora(x, qweight, qzeros, scales, g_idx, bit, A, B):
+    return gptqmodel_exllama_eora.gptq_gemm_eora(x, qweight, qzeros, scales, g_idx, True, bit, A, B)

def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
@@ -184,7 +184,7 @@ def forward(self, x):
        if self.adapter:
            # only 4 bits fused eora kernel has been validated
            if self.bits == 4:
-               output = gptq_gemm_lora(x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits, x @ self.adapter.lora_A, self.adapter.lora_B)  # fused
+               output = gptq_gemm_eora(x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits, x @ self.adapter.lora_A, self.adapter.lora_B)  # fused
            else:
                output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits).add_((reshaped_x @ self.adapter.lora_A) @ self.adapter.lora_B)  # normal
        else:
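For context, the fused `gptq_gemm_eora` call in the 4-bit branch above computes the same result as the unfused `gptq_gemm(...) + LoRA` branch below it. A minimal pure-PyTorch sketch of that equivalence follows; `w` is a hypothetical dense stand-in for the packed GPTQ weight (the real kernel consumes `qweight`/`qzeros`/`scales`/`g_idx`), and the shapes are illustrative only.

```python
import torch

def gptq_gemm_eora_reference(x, w_dense, ax, eora_b):
    # What the fused kernel computes in a single launch: the base GEMM plus
    # the low-rank correction ax @ eora_b, where ax = x @ lora_A is
    # precomputed by the caller (as in the forward() above).
    return x @ w_dense + ax @ eora_b

# Illustrative shapes; rank 128 matches the eora-rank128 adapter used in the tests.
m, k, n, r = 1, 4096, 4096, 128
x = torch.randn(m, k)
w = torch.randn(k, n)        # dense stand-in for the dequantized 4-bit weight
lora_a = torch.randn(k, r)   # adapter.lora_A
lora_b = torch.randn(r, n)   # adapter.lora_B

fused_like = gptq_gemm_eora_reference(x, w, x @ lora_a, lora_b)
unfused = (x @ w) + (x @ lora_a) @ lora_b   # the "normal" branch above
torch.testing.assert_close(fused_like, unfused)
```

The point of the fusion is to add the low-rank correction inside the same kernel launch instead of a separate PyTorch matmul plus an in-place add, which is where the speedups reported in the README below come from.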
79 changes: 2 additions & 77 deletions gptqmodel_ext/exllama_eora/README.md
@@ -19,83 +19,8 @@ To see the delta between the proposed and the original implementation one can di
- `python3 benchmark.py` # benchmarking

### Benchmarking results:
-Speedup ranging between 2.05x and 1.09x is observed for batch sizes ranging from 1 to 8 on a single RTX 3090 GPU.
+Speedup ranging between ~2.3x and 1.2x is observed for batch sizes ranging from 1 to 16 on a single RTX 3090 GPU.
+Speedup ranging between ~2.5x and 1.2x is observed for batch sizes ranging from 1 to 16 on a single H100.
The baseline, `gptq kernel + pytorch for LORA`, is compared with the `gptq eora kernel`.
```bash
gptq-eora ➜ python3 ./benchmark.py t 1
pytorch baseline: 0.10021328926086426 msec
pytorch LORA baseline: 0.11120986938476562 msec
pytorch baseline: 0.07351875305175781 msec
pytorch LORA baseline: 0.0958395004272461 msec
gptq: 0.018501758575439453 msec
gptq + pytorch for LORA: 0.04210519790649414 msec
gptq eora kernel: 0.020452022552490234 msec
gptq+pytorch/fused_kernel ratio for batch size 1: 2.0587302697535614
pytorch_lora/fused_kernel ratio for batch size 1: 4.686064675572964

pytorch baseline: 0.09366106986999512 msec
pytorch LORA baseline: 0.12542033195495605 msec
gptq: 0.019073963165283203 msec
gptq + pytorch for LORA: 0.043236494064331055 msec
gptq eora kernel: 0.02179884910583496 msec
gptq+pytorch/fused_kernel ratio for batch size 2: 1.9834301276372346
pytorch_lora/fused_kernel ratio for batch size 2: 5.7535299843597905

pytorch baseline: 0.09362173080444336 msec
pytorch LORA baseline: 0.12170100212097168 msec
gptq: 0.019705533981323242 msec
gptq + pytorch for LORA: 0.0429532527923584 msec
gptq eora kernel: 0.023361921310424805 msec
gptq+pytorch/fused_kernel ratio for batch size 3: 1.8386010389133252
pytorch_lora/fused_kernel ratio for batch size 3: 5.209374712972129

pytorch baseline: 0.09506535530090332 msec
pytorch LORA baseline: 0.1078331470489502 msec
gptq: 0.020968198776245117 msec
gptq + pytorch for LORA: 0.04309487342834473 msec
gptq eora kernel: 0.025162220001220703 msec
gptq+pytorch/fused_kernel ratio for batch size 4: 1.7126816881123388
pytorch_lora/fused_kernel ratio for batch size 4: 4.285518012469442

pytorch baseline: 0.09542036056518555 msec
pytorch LORA baseline: 0.1076815128326416 msec
gptq: 0.022510766983032227 msec
gptq + pytorch for LORA: 0.052427053451538086 msec
gptq eora kernel: 0.028439998626708984 msec
gptq+pytorch/fused_kernel ratio for batch size 5: 1.843426722331204
pytorch_lora/fused_kernel ratio for batch size 5: 3.7862699730060525

pytorch baseline: 0.09557318687438965 msec
pytorch LORA baseline: 0.10774064064025879 msec
gptq: 0.025467395782470703 msec
gptq + pytorch for LORA: 0.04637646675109863 msec
gptq eora kernel: 0.033232927322387695 msec
gptq+pytorch/fused_kernel ratio for batch size 6: 1.395497492628543
pytorch_lora/fused_kernel ratio for batch size 6: 3.241984661630401

pytorch baseline: 0.09484624862670898 msec
pytorch LORA baseline: 0.10790395736694336 msec
gptq: 0.02785944938659668 msec
gptq + pytorch for LORA: 0.04564833641052246 msec
gptq eora kernel: 0.03971362113952637 msec
gptq+pytorch/fused_kernel ratio for batch size 7: 1.149437777284161
pytorch_lora/fused_kernel ratio for batch size 7: 2.717051587611289

pytorch baseline: 0.0950167179107666 msec
pytorch LORA baseline: 0.10870051383972168 msec
gptq: 0.029795169830322266 msec
gptq + pytorch for LORA: 0.044673919677734375 msec
gptq eora kernel: 0.04362607002258301 msec
gptq+pytorch/fused_kernel ratio for batch size 8: 1.0240188872068685
pytorch_lora/fused_kernel ratio for batch size 8: 2.4916412086500785

pytorch baseline: 0.09513998031616211 msec
pytorch LORA baseline: 0.10854911804199219 msec
gptq: 0.04927778244018555 msec
gptq + pytorch for LORA: 0.05824875831604004 msec
gptq eora kernel: 0.06363630294799805 msec
gptq+pytorch/fused_kernel ratio for batch size 9: 0.9153385036154509
pytorch_lora/fused_kernel ratio for batch size 9: 1.7057734816979506
```
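For reference, per-call timings like the `msec` figures above are typically gathered with CUDA events, and each ratio line is simply one timing divided by another (e.g. `gptq + pytorch for LORA` over `gptq eora kernel`). The sketch below is a generic illustration of such a harness, not the repository's actual `benchmark.py`; the helper name, shapes, and dense stand-in tensors are assumptions.

```python
import torch

def time_call(fn, warmup=10, iters=100):
    # Average milliseconds per call, measured with CUDA events.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

if torch.cuda.is_available():
    m, k, n, r = 8, 4096, 4096, 128                 # batch, in-features, out-features, LoRA rank
    x = torch.randn(m, k, device="cuda", dtype=torch.float16)
    w = torch.randn(k, n, device="cuda", dtype=torch.float16)   # dense stand-in for the packed weight
    a = torch.randn(k, r, device="cuda", dtype=torch.float16)
    b = torch.randn(r, n, device="cuda", dtype=torch.float16)

    lora_ms = time_call(lambda: x @ w + (x @ a) @ b)   # "pytorch LORA baseline"-style path
    base_ms = time_call(lambda: x @ w)                 # "pytorch baseline"-style path
    print(f"pytorch baseline: {base_ms:.6f} msec")
    print(f"pytorch LORA baseline: {lora_ms:.6f} msec")
    print(f"lora/base ratio for batch size {m}: {lora_ms / base_ms}")
```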


Binary file added gptqmodel_ext/exllama_eora/benchmark_results.png
2 changes: 1 addition & 1 deletion gptqmodel_ext/exllama_eora/eora/ops.h
@@ -8,7 +8,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                        bool use_exllama, int64_t bit);

-torch::Tensor gptq_gemm_lora(torch::Tensor a, torch::Tensor b_q_weight,
+torch::Tensor gptq_gemm_eora(torch::Tensor a, torch::Tensor b_q_weight,
                             torch::Tensor b_gptq_qzeros,
                             torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                             bool use_exllama, int64_t bit,
2 changes: 1 addition & 1 deletion gptqmodel_ext/exllama_eora/eora/pybind.cu
@@ -3,7 +3,7 @@

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("gptq_gemm", &gptq_gemm, "gptq_gemm")
-        .def("gptq_gemm_lora", &gptq_gemm_lora, "gptq_gemm_lora")
+        .def("gptq_gemm_eora", &gptq_gemm_eora, "gptq_gemm_eora")
        .def("gptq_shuffle", &gptq_shuffle, "gptq_shuffle")
    ;
}
2 changes: 1 addition & 1 deletion gptqmodel_ext/exllama_eora/eora/q_gemm.cu
@@ -2099,7 +2099,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
  return c;
}

-torch::Tensor gptq_gemm_lora(torch::Tensor a, torch::Tensor b_q_weight,
+torch::Tensor gptq_gemm_eora(torch::Tensor a, torch::Tensor b_q_weight,
                             torch::Tensor b_gptq_qzeros,
                             torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                             bool use_exllama, int64_t bit,
4 changes: 2 additions & 2 deletions gptqmodel_ext/exllama_eora/test_actual_value.py
@@ -4,7 +4,7 @@
import gptqmodel_exllama_eora
import torch
# from eora_test import fused_concurrent, fused_sequential, cublas_reference, gptq_gemm_eora, gptq_gemm
-from gptqmodel_exllama_eora import gptq_gemm, gptq_gemm_lora
+from gptqmodel_exllama_eora import gptq_gemm, gptq_gemm_eora
from gptqmodel_exllama_kernels import make_q4, q4_matmul
from safetensors import safe_open

@@ -255,7 +255,7 @@ def test_eora_kernel():

    gptq_pytorch_out = gptq_gemm(reshaped_x, weight, zeros, scales, idx, use_exllama, bit) + (ax @ eora_b)

-   gptq_eora_fused_out = gptq_gemm_lora(reshaped_x, weight, zeros, scales, idx, use_exllama, bit, ax, eora_b)
+   gptq_eora_fused_out = gptq_gemm_eora(reshaped_x, weight, zeros, scales, idx, use_exllama, bit, ax, eora_b)
    torch.set_printoptions(precision=6)
    # print("gptq exllama kernel out: ")
    # print(exllama_out[0][:10])
4 changes: 2 additions & 2 deletions gptqmodel_ext/exllama_eora/test_eora.py
@@ -2,7 +2,7 @@

import torch
# from eora import fused_concurrent, fused_sequential, cublas_reference, gptq_gemm_eora, gptq_gemm
-from gptqmodel_exllama_eora import gptq_gemm, gptq_gemm_lora
+from gptqmodel_exllama_eora import gptq_gemm, gptq_gemm_eora

m = 1
k = 4096
@@ -34,7 +34,7 @@ def test_eora_kernel():
    gptq_pytorch_out = gptq_gemm(x, weight, zeros, scales, idx, use_exllama, bit) + (ax @ eora_b)
    print("gptq_pytorch_out: ")
    print(gptq_pytorch_out[0][:10])
-   gptq_eora_fused_out = gptq_gemm_lora(x, weight, zeros, scales, idx, use_exllama, bit, ax, eora_b)
+   gptq_eora_fused_out = gptq_gemm_eora(x, weight, zeros, scales, idx, use_exllama, bit, ax, eora_b)
    print("gptq_eora_fused_out: ")
    print(gptq_eora_fused_out[0][:10])
    torch.testing.assert_close(gptq_pytorch_out, gptq_eora_fused_out, rtol=0.05, atol=0.5)  # 5% relative tolerance, 0.5 absolute tolerance
4 changes: 1 addition & 3 deletions tests/test_kernel_output.py
@@ -21,7 +21,6 @@

class TestKernelOutput(unittest.TestCase):
    model_path = "sliuau/llama3.2-1b-4bit-group128"  # hf "sliuau/llama3.2-1b-4bit-group128"
-
    target_qliner_map = {
        BACKEND.EXLLAMA_V1: ExllamaQuantLinear,
        BACKEND.EXLLAMA_EORA: ExllamaEoraQuantLinear,
@@ -41,7 +40,6 @@ class TestKernelOutput(unittest.TestCase):
    def setUpClass(cls):
        lora_path = "sliuau/llama3.2-1b-4bit-group128-eora-rank128-arc"  # adapter_model.safetensors
        # hf "sliuau-llama3.2-1b-4bit-group128/llama3.2-1b-4bit-group128-eora-rank128-arc/"
-
        cls.m = 1
        cls.k = -1
        cls.x = None  # random X input of shape (m, k)
@@ -108,7 +106,7 @@ def test_kernel_output(self, backend: BACKEND, r_tolerance: float, a_tolerance:
        (BACKEND.EXLLAMA_V2, 0.16, 0.0003),
        (BACKEND.MARLIN, 0.00001, 0.00003),
        (BACKEND.MARLIN_FP16, 0.0001, 0.0035),
-       (BACKEND.EXLLAMA_EORA, 0.00001, 0.00147)
+       (BACKEND.EXLLAMA_EORA, 0.0001, 0.0035)
    ])
    def test_kernel_output_with_lora(self, backend: BACKEND, r_tolerance: float, a_tolerance: float):
        out = self.forward(backend=backend, adapter=self.adapter)
