diff --git a/eora_no_bug.py b/eora_no_bug.py
deleted file mode 100644
index 22fa708a3..000000000
--- a/eora_no_bug.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import torch
-from datasets import load_dataset
-from gptqmodel import GPTQModel, QuantizeConfig
-
-# from gptqmodel.eora import get_eora, get_eora_optimize
-
-
-bit = 4
-model_id = "meta-llama/Llama-3.2-1B"
-model = None
-
-quant_path = "Llama-3.2-1B-gptqmodel-4bit"
-fake_quant_path = "Llama-3.2-1B-gptqmodel-4bit-fakequantized/qw.pt"
-eora_path = "Llama-3.2-1B-gptqmodel-4bit-eora-rank-128-v2/eora.pt"
-quant_config = QuantizeConfig(bits=bit, group_size=128)
-
-
-calibration_dataset = load_dataset(
-    "allenai/c4",
-    data_files="en/c4-train.00001-of-01024.json.gz",
-    split="train"
-).select(range(1024))["text"]
-
-print(f"{type(calibration_dataset)}")
-
-### 3-bit group_size = 128 leads to out: IndexError: index 192 is out of bounds when packing
-model = GPTQModel.load(model_id, quant_config)
-
-# increase `batch_size` to match gpu/vram specs to speed up quantization
-quant_log, quantized_weights = model.quantize(calibration_dataset, batch_size=2)
-
-model.save(quant_path)
-
-torch.save(quantized_weights, fake_quant_path)
-quantized_weights = torch.load(fake_quant_path, map_location='cpu')
-
-## 4-bit gs=128 Acc: 0.2850
-
-batch_size = 2
-from test_prepare_dataset import construct_ARC
-
-calibration_dataset = construct_ARC(nsamples=1024)
-eora_rank = 128
-model = GPTQModel.load(model_id, quant_config)
-
-eora_weight = model.get_eora(calibration_dataset, batch_size, quantized_weights, eora_rank)
-
-torch.save(eora_weight, eora_path)
-
-eora_weight = torch.load(eora_path, map_location='cpu')
-print(eora_weight)
diff --git a/eora_load_and_infer.py b/gptqmodel/eora/eora_load_and_infer.py
similarity index 100%
rename from eora_load_and_infer.py
rename to gptqmodel/eora/eora_load_and_infer.py
diff --git a/gptqmodel/nn_modules/qlinear/exllama_eora.py b/gptqmodel/nn_modules/qlinear/exllama_eora.py
index e957df188..ef770dcfa 100644
--- a/gptqmodel/nn_modules/qlinear/exllama_eora.py
+++ b/gptqmodel/nn_modules/qlinear/exllama_eora.py
@@ -45,8 +45,8 @@ def gptq_gemm(x, qweight, qzeros, scales, g_idx, bit):
     return gptqmodel_exllama_eora.gptq_gemm(x, qweight, qzeros, scales, g_idx, True, bit)


-def gptq_gemm_lora(x, qweight, qzeros, scales, g_idx, bit, A, B):
-    return gptqmodel_exllama_eora.gptq_gemm_lora(x, qweight, qzeros, scales, g_idx, True, bit, A, B)
+def gptq_gemm_eora(x, qweight, qzeros, scales, g_idx, bit, A, B):
+    return gptqmodel_exllama_eora.gptq_gemm_eora(x, qweight, qzeros, scales, g_idx, True, bit, A, B)


 def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
@@ -184,7 +184,7 @@ def forward(self, x):
         if self.adapter:
             # only 4 bits fused eora kernel has been validated
             if self.bits == 4:
-                output = gptq_gemm_lora(x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits, x @ self.adapter.lora_A, self.adapter.lora_B) # fused
+                output = gptq_gemm_eora(x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits, x @ self.adapter.lora_A, self.adapter.lora_B) # fused
             else:
                 output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits).add_((reshaped_x @ self.adapter.lora_A) @ self.adapter.lora_B) # normal
         else:
diff --git a/gptqmodel_ext/exllama_eora/README.md b/gptqmodel_ext/exllama_eora/README.md
index a46910731..be9d3279e 100644
--- a/gptqmodel_ext/exllama_eora/README.md
+++ b/gptqmodel_ext/exllama_eora/README.md
@@ -19,83 +19,8 @@ To see the delta between the proposed and the original implementation one can di
 - `python3 benchmark.py` # benchmarking

 ### Benchmarking results:
-Speedup ranging between 2.05x and 1.09x is observed for batch sizes ranging from 1 to 8 on a single RTX 3090 GPU.
+Speedup ranging between ~2.3x and 1.2x is observed for batch sizes ranging from 1 to 16 on a single RTX 3090 GPU.
+Speedup ranging between ~2.5x and 1.2x is observed for batch sizes ranging from 1 to 16 on a single H100.
 The baseline is `gptq kernel + pytorch for LORA` is compared with `gptq eora kernel`.

-```bash
-gptq-eora ➜ python3 ./benchmark.py t 1
-pytorch baseline: 0.10021328926086426 msec
-pytorch LORA baseline: 0.11120986938476562 msec
-pytorch baseline: 0.07351875305175781 msec
-pytorch LORA baseline: 0.0958395004272461 msec
-gptq: 0.018501758575439453 msec
-gptq + pytorch for LORA: 0.04210519790649414 msec
-gptq eora kernel: 0.020452022552490234 msec
-gptq+pytorch/fused_kernel ratio for batch size 1: 2.0587302697535614
-pytorch_lora/fused_kernel ratio for batch size 1: 4.686064675572964
-
-pytorch baseline: 0.09366106986999512 msec
-pytorch LORA baseline: 0.12542033195495605 msec
-gptq: 0.019073963165283203 msec
-gptq + pytorch for LORA: 0.043236494064331055 msec
-gptq eora kernel: 0.02179884910583496 msec
-gptq+pytorch/fused_kernel ratio for batch size 2: 1.9834301276372346
-pytorch_lora/fused_kernel ratio for batch size 2: 5.7535299843597905
-
-pytorch baseline: 0.09362173080444336 msec
-pytorch LORA baseline: 0.12170100212097168 msec
-gptq: 0.019705533981323242 msec
-gptq + pytorch for LORA: 0.0429532527923584 msec
-gptq eora kernel: 0.023361921310424805 msec
-gptq+pytorch/fused_kernel ratio for batch size 3: 1.8386010389133252
-pytorch_lora/fused_kernel ratio for batch size 3: 5.209374712972129
-
-pytorch baseline: 0.09506535530090332 msec
-pytorch LORA baseline: 0.1078331470489502 msec
-gptq: 0.020968198776245117 msec
-gptq + pytorch for LORA: 0.04309487342834473 msec
-gptq eora kernel: 0.025162220001220703 msec
-gptq+pytorch/fused_kernel ratio for batch size 4: 1.7126816881123388
-pytorch_lora/fused_kernel ratio for batch size 4: 4.285518012469442
-
-pytorch baseline: 0.09542036056518555 msec
-pytorch LORA baseline: 0.1076815128326416 msec
-gptq: 0.022510766983032227 msec
-gptq + pytorch for LORA: 0.052427053451538086 msec
-gptq eora kernel: 0.028439998626708984 msec
-gptq+pytorch/fused_kernel ratio for batch size 5: 1.843426722331204
-pytorch_lora/fused_kernel ratio for batch size 5: 3.7862699730060525
-
-pytorch baseline: 0.09557318687438965 msec
-pytorch LORA baseline: 0.10774064064025879 msec
-gptq: 0.025467395782470703 msec
-gptq + pytorch for LORA: 0.04637646675109863 msec
-gptq eora kernel: 0.033232927322387695 msec
-gptq+pytorch/fused_kernel ratio for batch size 6: 1.395497492628543
-pytorch_lora/fused_kernel ratio for batch size 6: 3.241984661630401
-
-pytorch baseline: 0.09484624862670898 msec
-pytorch LORA baseline: 0.10790395736694336 msec
-gptq: 0.02785944938659668 msec
-gptq + pytorch for LORA: 0.04564833641052246 msec
-gptq eora kernel: 0.03971362113952637 msec
-gptq+pytorch/fused_kernel ratio for batch size 7: 1.149437777284161
-pytorch_lora/fused_kernel ratio for batch size 7: 2.717051587611289
-
-pytorch baseline: 0.0950167179107666 msec
-pytorch LORA baseline: 0.10870051383972168 msec
-gptq: 0.029795169830322266 msec
-gptq + pytorch for LORA: 0.044673919677734375 msec
-gptq eora kernel: 0.04362607002258301 msec
-gptq+pytorch/fused_kernel ratio for batch size 8: 1.0240188872068685
-pytorch_lora/fused_kernel ratio for batch size 8: 2.4916412086500785
-
-pytorch baseline: 0.09513998031616211 msec
-pytorch LORA baseline: 0.10854911804199219 msec
-gptq: 0.04927778244018555 msec
-gptq + pytorch for LORA: 0.05824875831604004 msec
-gptq eora kernel: 0.06363630294799805 msec
-gptq+pytorch/fused_kernel ratio for batch size 9: 0.9153385036154509
-pytorch_lora/fused_kernel ratio for batch size 9: 1.7057734816979506
-```
diff --git a/gptqmodel_ext/exllama_eora/benchmark_results.png b/gptqmodel_ext/exllama_eora/benchmark_results.png
new file mode 100644
index 000000000..4509e6e3b
Binary files /dev/null and b/gptqmodel_ext/exllama_eora/benchmark_results.png differ
diff --git a/gptqmodel_ext/exllama_eora/eora/ops.h b/gptqmodel_ext/exllama_eora/eora/ops.h
index be28d9745..45f137449 100644
--- a/gptqmodel_ext/exllama_eora/eora/ops.h
+++ b/gptqmodel_ext/exllama_eora/eora/ops.h
@@ -8,7 +8,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                         bool use_exllama, int64_t bit);

-torch::Tensor gptq_gemm_lora(torch::Tensor a, torch::Tensor b_q_weight,
+torch::Tensor gptq_gemm_eora(torch::Tensor a, torch::Tensor b_q_weight,
                              torch::Tensor b_gptq_qzeros,
                              torch::Tensor b_gptq_scales,
                              torch::Tensor b_g_idx, bool use_exllama, int64_t bit,
diff --git a/gptqmodel_ext/exllama_eora/eora/pybind.cu b/gptqmodel_ext/exllama_eora/eora/pybind.cu
index a4fe68907..a136adf9b 100644
--- a/gptqmodel_ext/exllama_eora/eora/pybind.cu
+++ b/gptqmodel_ext/exllama_eora/eora/pybind.cu
@@ -3,7 +3,7 @@

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("gptq_gemm", &gptq_gemm, "gptq_gemm")
-      .def("gptq_gemm_lora", &gptq_gemm_lora, "gptq_gemm_lora")
+      .def("gptq_gemm_eora", &gptq_gemm_eora, "gptq_gemm_eora")
       .def("gptq_shuffle", &gptq_shuffle, "gptq_shuffle")
       ;
 }
\ No newline at end of file
diff --git a/gptqmodel_ext/exllama_eora/eora/q_gemm.cu b/gptqmodel_ext/exllama_eora/eora/q_gemm.cu
index fc6060373..ba276cbbc 100644
--- a/gptqmodel_ext/exllama_eora/eora/q_gemm.cu
+++ b/gptqmodel_ext/exllama_eora/eora/q_gemm.cu
@@ -2099,7 +2099,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
   return c;
 }

-torch::Tensor gptq_gemm_lora(torch::Tensor a, torch::Tensor b_q_weight,
+torch::Tensor gptq_gemm_eora(torch::Tensor a, torch::Tensor b_q_weight,
                              torch::Tensor b_gptq_qzeros,
                              torch::Tensor b_gptq_scales,
                              torch::Tensor b_g_idx, bool use_exllama, int64_t bit,
diff --git a/gptqmodel_ext/exllama_eora/test_actual_value.py b/gptqmodel_ext/exllama_eora/test_actual_value.py
index a7998ff55..d720c711c 100644
--- a/gptqmodel_ext/exllama_eora/test_actual_value.py
+++ b/gptqmodel_ext/exllama_eora/test_actual_value.py
@@ -4,7 +4,7 @@
 import gptqmodel_exllama_eora
 import torch
 # from eora_test import fused_concurrent, fused_sequential, cublas_reference, gptq_gemm_eora, gptq_gemm
-from gptqmodel_exllama_eora import gptq_gemm, gptq_gemm_lora
+from gptqmodel_exllama_eora import gptq_gemm, gptq_gemm_eora
 from gptqmodel_exllama_kernels import make_q4, q4_matmul
 from safetensors import safe_open

@@ -255,7 +255,7 @@ def test_eora_kernel():
     gptq_pytorch_out = gptq_gemm(reshaped_x, weight, zeros, scales, idx, use_exllama, bit) + (ax @ eora_b)

-    gptq_eora_fused_out = gptq_gemm_lora(reshaped_x, weight, zeros, scales, idx, use_exllama, bit, ax, eora_b)
+    gptq_eora_fused_out = gptq_gemm_eora(reshaped_x, weight, zeros, scales, idx, use_exllama, bit, ax, eora_b)
     torch.set_printoptions(precision=6)
     # print("gptq exllama kernel out: ")
     # print(exllama_out[0][:10])
diff --git a/gptqmodel_ext/exllama_eora/test_eora.py b/gptqmodel_ext/exllama_eora/test_eora.py
index 63e91113f..a747381d3 100644
--- a/gptqmodel_ext/exllama_eora/test_eora.py
+++ b/gptqmodel_ext/exllama_eora/test_eora.py
@@ -2,7 +2,7 @@
 import torch
 # from eora import fused_concurrent, fused_sequential, cublas_reference, gptq_gemm_eora, gptq_gemm
-from gptqmodel_exllama_eora import gptq_gemm, gptq_gemm_lora
+from gptqmodel_exllama_eora import gptq_gemm, gptq_gemm_eora

 m = 1
 k = 4096
@@ -34,7 +34,7 @@ def test_eora_kernel():
     gptq_pytorch_out = gptq_gemm(x, weight, zeros, scales, idx, use_exllama, bit) + (ax @ eora_b)
     print("gptq_pytorch_out: ")
     print(gptq_pytorch_out[0][:10])
-    gptq_eora_fused_out = gptq_gemm_lora(x, weight, zeros, scales, idx, use_exllama, bit, ax, eora_b)
+    gptq_eora_fused_out = gptq_gemm_eora(x, weight, zeros, scales, idx, use_exllama, bit, ax, eora_b)
     print("gptq_eora_fused_out: ")
     print(gptq_eora_fused_out[0][:10])
     torch.testing.assert_close(gptq_pytorch_out, gptq_eora_fused_out, rtol=0.05, atol=0.5) # 5 % relative tolerance, 0.5 absolute tolerance
diff --git a/tests/test_kernel_output.py b/tests/test_kernel_output.py
index be94531ce..677f95a5a 100644
--- a/tests/test_kernel_output.py
+++ b/tests/test_kernel_output.py
@@ -21,7 +21,6 @@ class TestKernelOutput(unittest.TestCase):
     model_path = "sliuau/llama3.2-1b-4bit-group128" # hf "sliuau/llama3.2-1b-4bit-group128"
-
     target_qliner_map = {
         BACKEND.EXLLAMA_V1: ExllamaQuantLinear,
         BACKEND.EXLLAMA_EORA: ExllamaEoraQuantLinear,
@@ -41,7 +40,6 @@ class TestKernelOutput(unittest.TestCase):
     def setUpClass(cls):
         lora_path = "sliuau/llama3.2-1b-4bit-group128-eora-rank128-arc" # adapter_model.safetensors
         # hf "sliuau-llama3.2-1b-4bit-group128/llama3.2-1b-4bit-group128-eora-rank128-arc/"
-
         cls.m = 1
         cls.k = -1
         cls.x = None # random X input of shape (m, k)
@@ -108,7 +106,7 @@ def test_kernel_output(self, backend: BACKEND, r_tolerance: float, a_tolerance:
         (BACKEND.EXLLAMA_V2, 0.16, 0.0003),
         (BACKEND.MARLIN, 0.00001, 0.00003),
         (BACKEND.MARLIN_FP16, 0.0001, 0.0035),
-        (BACKEND.EXLLAMA_EORA, 0.00001, 0.00147)
+        (BACKEND.EXLLAMA_EORA, 0.0001, 0.0035)
     ])
     def test_kernel_output_with_lora(self, backend: BACKEND, r_tolerance: float, a_tolerance: float):
         out = self.forward(backend=backend, adapter=self.adapter)