save/load peft lora (#1358)
* save eora to hf format

Signed-off-by: Qubitium <[email protected]>

* test needs to store self.x in cls to stay consistent

Signed-off-by: Qubitium <[email protected]>

* temp disable torch kernel auto compile that is causing dynamo errors

Signed-off-by: Qubitium <[email protected]>

* fix shape error

Signed-off-by: ZX-ModelCloud <[email protected]>

* cleanup debug logs

Signed-off-by: Qubitium <[email protected]>

* re-enable auto torch compile code

Signed-off-by: Qubitium <[email protected]>

* add lora config validation

Signed-off-by: Qubitium <[email protected]>

* refactor loading cache into AdapterCache cls

Signed-off-by: Qubitium <[email protected]>

* add lora rank override code from LoraConfig

Signed-off-by: Qubitium <[email protected]>

* remove `peft` dependency

Signed-off-by: Qubitium <[email protected]>

* comment on original HF repo path for test files

Signed-off-by: Qubitium <[email protected]>

* clean up HF download logic

Signed-off-by: Qubitium <[email protected]>

* save to PEFT-compatible format (see the layout sketch just before the file diffs)

Signed-off-by: Qubitium <[email protected]>

* add test_quant_and_eora_transformers.py

Signed-off-by: ZX-ModelCloud <[email protected]>

* fix missing task_type in adapter_config.json

Signed-off-by: ZX-ModelCloud <[email protected]>

* fix regex rule prefix not stripped

Signed-off-by: Qubitium <[email protected]>

* push peft compat changes

Signed-off-by: Qubitium <[email protected]>

* prevent peft from doing alpha / r scaling: set alpha equal to r so the scaling factor is just 1, i.e. no scaling (see the sketch after the sign-off trailer)

Signed-off-by: Qubitium <[email protected]>

* fix lora load with transformers

Signed-off-by: ZX-ModelCloud <[email protected]>

* fix device

Signed-off-by: ZX-ModelCloud <[email protected]>

* format

Signed-off-by: Qubitium <[email protected]>

* assert lora weight

Signed-off-by: ZX-ModelCloud <[email protected]>

* fix empty base_model_name_or_path

Signed-off-by: ZX-ModelCloud <[email protected]>

* assert dynamic rank

Signed-off-by: ZX-ModelCloud <[email protected]>

* remove dynamic adapter config when saving quantize_config

Signed-off-by: ZX-ModelCloud <[email protected]>

* fix dynamic is none

* [CI] install bitblas for test_inference_speed

---------

Signed-off-by: Qubitium <[email protected]>
Signed-off-by: ZX-ModelCloud <[email protected]>
Co-authored-by: ZX-ModelCloud <[email protected]>
Co-authored-by: CSY-ModelCloud <[email protected]>
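
On the alpha / r point in the commit message above: PEFT's standard LoRA scaling multiplies the adapter delta by lora_alpha / r, so setting lora_alpha equal to r makes that factor exactly 1 and the adapter is applied unscaled. A minimal sketch of the arithmetic (shapes and values are illustrative, not taken from this commit):

import torch

# LoRA forward: out = base(x) + (x @ lora_A.T) @ lora_B.T * (lora_alpha / r)
# With lora_alpha == r the scaling factor collapses to 1.0, i.e. no rescaling.
r, lora_alpha = 128, 128       # alpha set equal to rank, as this commit enforces
scaling = lora_alpha / r       # -> 1.0

x = torch.randn(4, 512)        # illustrative activation batch
lora_A = torch.randn(r, 512)   # down-projection (r x in_features)
lora_B = torch.randn(1024, r)  # up-projection (out_features x r)

delta = (x @ lora_A.T) @ lora_B.T * scaling
assert scaling == 1.0          # the math is just 1, no scaling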
3 people authored Mar 1, 2025
1 parent a6a8e82 commit f7b86a5
Showing 17 changed files with 936 additions and 150 deletions.
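
For orientation before the diffs: the PEFT-compatible format saved here consists of an adapter_config.json plus an adapter_model.safetensors whose tensor keys carry the base_model.model. prefix (see the HF_ADAPTER_* constants added in gptqmodel/adapter/adapter.py below). A rough sketch of inspecting such an adapter, assuming only the safetensors library; the directory path is illustrative:

from safetensors.torch import load_file

# Hypothetical local adapter directory saved in the PEFT-compatible layout this commit introduces.
adapter_dir = "./eora-adapter"  # illustrative path, not from the commit

weights = load_file(f"{adapter_dir}/adapter_model.safetensors")

# Keys follow the PEFT convention reflected by the new constants, e.g.:
#   base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
#   base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight
for key in sorted(weights)[:4]:
    print(key, tuple(weights[key].shape))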
1 change: 1 addition & 0 deletions .github/workflows/unit_tests.yml
@@ -553,6 +553,7 @@ jobs:
uv pip install -U transformers
uv pip install -U logbar==0.0.3
if [ "${{ matrix.test_script }}" == "test_perplexity" ] || \
[ "${{ matrix.test_script }}" == "test_inference_speed" ] || \
[ "${{ matrix.test_script }}" == "test_q4_bitblas" ] || \
[ "${{ matrix.test_script }}" == "test_save_loaded_quantized_model" ]; then
echo "===== install bitblas==0.0.1.dev13 ====="
172 changes: 112 additions & 60 deletions gptqmodel/adapter/adapter.py
@@ -1,32 +1,66 @@
import os
from dataclasses import dataclass, field
from typing import Dict, List, Union
from urllib.parse import urlparse

import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import safetensors
import torch

from ..utils.logger import setup_logger
from .peft import LoraConfig
from .remote import resolve_path

logger = setup_logger()
LORA_MERGED_WEIGHT_PATHS = [None, ""]
HF_ADAPTER_FILE_NAME = "adapter_model.safetensors"
HF_ADAPTER_CONFIG_FILE_NAME = "adapter_config.json"
HF_ADAPTER_WEIGHT_KEY_PREFIX = "base_model.model."


class AdapterCache():
cache: Dict[str, Dict[str, Union[LoraConfig, torch.Tensor]]] = {} # first-level key is `path`; second-level keys: `config` (LoraConfig) and `weights` (Dict[str, torch.Tensor])

@classmethod
def get(cls, path: str) -> Optional[Tuple[LoraConfig, Dict[str, torch.Tensor]]]:
data = cls.cache.get(path)
if not data:
return None
else:
return data["config"], data["weights"]

@classmethod
def reset(cls):
logger.info("Adapter Cache: Resetting cache")
cls.cache = {}

@classmethod
def add(cls, path: str, config: LoraConfig, weights: Dict[str, torch.Tensor]):
cls.cache[path] = {"config": config, "weights": weights}

@classmethod
def remove(cls, path):
cls.cache.pop(path, None)

# TODO FIX ME: cache of adapter tensors loaded from disk
adapter_load_cache = None

class Adapter():
def __init__(self, rank: int, path: str = None):
self.rank = rank
def __init__(self, rank: int = None, path: str = None):
self.rank = rank # rank may be unset (None) when loading; it is re-populated from the saved LoraConfig file
self.path = path.lower().strip() if isinstance(path, str) else path

def validate_path(self, local_only=False):
def validate_path(self, local=False):
if not self.path or not isinstance(self.path, str):
raise ValueError("Adapter: `path` str is required.")

if local_only:
# path should not be a file but a directory
if self.path.endswith(".safetensors"):
raise ValueError(
f"Adapter: `path` must be a directory path or repo depending if you are saving (directory path) or loading (repo): actual = `{self.path}`")

if local:
if self.path.startswith("http"):
raise ValueError(f"Adapter: `path` str in this context must be a local os path: actual = `{self.path}`.")


# override me
def apply(self, x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
pass
@@ -97,52 +131,69 @@ def post_init(self, weight_key: str, device:torch.device, lora_A: torch.Tensor=N
self.lora_A, self.lora_B = lora_A, lora_B
return

global adapter_load_cache
if adapter_load_cache is None:
if os.path.isfile(self.path):
lora_path = self.path
logger.info(f"Adapter: Loading `{self.path}` tensors from disk") # {adapter_load_cache}
elif self.path.startswith("http"):
from huggingface_hub import hf_hub_download
result = self.parse_url(self.path)
if len(result) == 3:
logger.info(f"Adapter: Downloading adapter weights from hf repo: `{result[0]}` revision: `{result[1]}` file: `{result[2]}`")
lora_path = hf_hub_download(repo_id=result[0], revision=result[1], filename=result[2])
elif len(result) == 1:
logger.info(f"Adapter: Downloading adapter weights from uri = `{self.path}`")
import requests
response = requests.get(self.path, stream=True)
lora_path = "lora.safetensors"
with open(lora_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
else:
raise Exception(f"Adapter: Lora path is invalid: `{self.path}`")
lora_cache = AdapterCache.get(self.path)
if lora_cache is None:
# get lora config
lora_cfg = LoraConfig.from_pretrained(path=self.path, filename=HF_ADAPTER_CONFIG_FILE_NAME)
lora_cfg.gptqmodel_path = self.path # hack: save this

if not isinstance(lora_cfg, LoraConfig):
raise ValueError(f"Adapter: Expected `LoraConfig` in `{self.path}`, actual = `{lora_cfg}`")

if self.rank is None:
self.rank = lora_cfg.r
else:
from huggingface_hub import HfApi, hf_hub_download
files = [f for f in HfApi().list_repo_files(self.path) if f in ["lora.safetensors", "eora_test.safetensors"]]
if self.rank != lora_cfg.r:
raise ValueError(f"Adapter: `rank` must match `LoraConfig.r`, expected `{self.rank}`, actual = `{lora_cfg.r}`")

lora_path = resolve_path(self.path, HF_ADAPTER_FILE_NAME)

# save to adapter cache
AdapterCache.add(self.path, lora_cfg, safetensors.torch.load_file(lora_path))

if files:
lora_path = hf_hub_download(repo_id=self.path, filename=files[0])
# print(f"Adapter tensors loaded from `{self.path}`")
else:
raise Exception(f"Adapter: There's no lora.safetensors or eora_test.safetensors on repo `{self.path}`")
lora_cache = AdapterCache.get(self.path)
assert lora_cache is not None

adapter_load_cache = safetensors.torch.load_file(lora_path)
# lora_cache result is a tuple
lora_cfg, lora_weights = lora_cache

weight_key = weight_key.lower()

# hack for HF Auto compat
if not f"{weight_key}.lora_A.weight" in adapter_load_cache:
weight_key = weight_key.removeprefix("model.")
lora_A_weight_key = f"{weight_key}.lora_A.weight"
lora_B_weight_key = f"{weight_key}.lora_B.weight"

#print(f"loaded lora weight keys: {adapter_load_cache.keys()}")
lora_A = adapter_load_cache.pop(f"{weight_key}.lora_A.weight").T
lora_B = adapter_load_cache.pop(f"{weight_key}.lora_B.weight").T
# print(f"lora_A_weight_key = {lora_A_weight_key}, lora_B_weight_key = {lora_B_weight_key}")
pop_keys = []
for k, v in lora_weights.items():
if k.endswith(lora_A_weight_key):
lora_A = v.T
pop_keys.append(k)
elif k.endswith(lora_B_weight_key):
lora_B = v.T
pop_keys.append(k)

# since the loader cache is a singleton, we need to reset it to None so CI loop tests can pass
if len(adapter_load_cache) == 0:
adapter_load_cache = None

if pop_keys:
for k in pop_keys:
lora_weights.pop(k) # release lora weights from cache memory

# we have consumed all modules
if len(lora_weights) == 0:
AdapterCache.remove(self.path)
logger.info("Adapter: Consumed all Lora weights")

else:
logger.warn(f"Adapter: Lora weights not found for `{weight_key}`")

assert lora_A is not None and lora_B is not None, f"Adapter: `lora_A` and `lora_B` must both be present in the weights: actual = `{lora_A}` and `{lora_B}`"

# check for rank override from base config
self.dynamic_rank_override(lora_cfg=lora_cfg, weight_key=weight_key)

# # since the loader cache is a singleton, we need to reset it to None so CI loop tests can pass
# if len(lora_weights) == 0:
# adapter_load_cache = None

# print(f"Adapter: {self.name()}, loaded lora_A shape: {lora_A.shape}")
# print(f"Adapter: {self.name()}, loaded lora_B shape: {lora_B.shape}")
@@ -155,21 +206,22 @@ def post_init(self, weight_key: str, device:torch.device, lora_A: torch.Tensor=N
#print(f"Adapter: lora_A {lora_A.shape}: `{lora_B}`")
#print(f"Adapter: lora_B {lora_B.shape}: `{lora_B}`")

def parse_url(self, url: str):
parsed_url = urlparse(url)
def dynamic_rank_override(self, lora_cfg: LoraConfig, weight_key: str) -> bool:
assert lora_cfg.rank_pattern is not None and weight_key is not None
if lora_cfg.rank_pattern:
for k, v in lora_cfg.rank_pattern.items():
assert isinstance(k, str) and isinstance(v, int)
k = k.lower()
assert v > 0 # check for invalid rank range
# first do string full match, then suffix match, then regex match
if weight_key == k or k.endswith(weight_key) or re.match(k, weight_key):
self.rank = v
logger.info(f"Adapter: Base Lora `rank` = `{self.rank}` has been overridden by `{k}` due to dynamic `LoraConfig.rank_pattern` control.")
return True

return False

if parsed_url.netloc.endswith("huggingface.co") or parsed_url.netloc.endswith("hf.co"):
parts = parsed_url.path.strip("/").split("/")

if "blob" in parts:
idx = parts.index("blob")
repo_id = "/".join(parts[:idx])
rev = parts[idx + 1]
filename = parts[idx + 2].split("?")[0] # remove ?download=true
return [repo_id, rev, filename]
else:
return [url]
return []

def to_dict(self):
return {

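As a footnote on the dynamic rank override added in gptqmodel/adapter/adapter.py above: LoraConfig.rank_pattern maps module-name patterns to per-module ranks, and dynamic_rank_override applies the first exact, suffix, or regex match. A small sketch of that matching behaviour, assuming the semantics shown in the diff; the pattern keys and ranks are illustrative:

import re

# Illustrative rank_pattern as it could appear in a saved LoraConfig.
rank_pattern = {"model.layers.0.self_attn.q_proj": 256, r".*mlp\.down_proj": 64}
base_rank = 128

def resolve_rank(weight_key: str) -> int:
    # Mirrors the match order used in the diff: exact match, then suffix, then regex.
    weight_key = weight_key.lower()
    for pattern, rank in rank_pattern.items():
        pattern = pattern.lower()
        if weight_key == pattern or pattern.endswith(weight_key) or re.match(pattern, weight_key):
            return rank
    return base_rank

print(resolve_rank("model.layers.0.self_attn.q_proj"))   # 256 (exact match)
print(resolve_rank("model.layers.3.mlp.down_proj"))       # 64 (regex match)
print(resolve_rank("model.layers.3.self_attn.k_proj"))    # 128 (falls back to the base rank)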