Merge pull request #254 from lucyknada/main
add draft_gpu_split option for spec decoding
kingbri1 authored Feb 11, 2025
2 parents e290b88 + beb6d8f commit 2e49147
Showing 3 changed files with 59 additions and 16 deletions.
62 changes: 47 additions & 15 deletions backends/exllamav2/model.py
@@ -89,7 +89,8 @@ class ExllamaV2Container:
generation_config: Optional[GenerationConfig] = None

# GPU split vars
gpu_split: Optional[list] = None
gpu_split: List[float] = []
draft_gpu_split: List[float] = []
gpu_split_auto: bool = True
autosplit_reserve: List[float] = [96 * 1024**2]
use_tp: bool = False
@@ -180,6 +181,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
)
draft_model_path = draft_model_path / draft_model_name

self.draft_gpu_split = draft_args.get("draft_gpu_split")
self.draft_model_dir = draft_model_path
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
@@ -232,6 +234,15 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
for value in autosplit_reserve_megabytes
]

# Change the GPU device list only if gpu_split's list is too small
# This allows for an uneven list specification
if self.draft_gpu_split and len(self.draft_gpu_split) > len(self.gpu_split):
gpu_device_list = [
device_idx
for device_idx, memory in enumerate(self.draft_gpu_split)
if memory > 0
]

# Hardcode max output length to 16
self.config.max_output_len = 16

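The hunk above widens the GPU device list only when the draft split names more devices than the main model's split. A minimal sketch of that selection logic follows; the variable names mirror the diff, but the GB values are made up for illustration.

# Illustrative values only: main model on GPU 0, draft model on GPUs 0 and 2.
gpu_split = [20.0]
draft_gpu_split = [2.0, 0, 4.0]

# Same comprehension as the diff: keep every device index whose draft
# allocation is non-zero, so GPU 1 is skipped here.
if draft_gpu_split and len(draft_gpu_split) > len(gpu_split):
    gpu_device_list = [
        device_idx
        for device_idx, memory in enumerate(draft_gpu_split)
        if memory > 0
    ]

print(gpu_device_list)  # [0, 2]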
@@ -375,6 +386,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
# Set draft cache mode
self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")

# Edit the draft config size
if chunk_size:
self.draft_config.max_input_len = chunk_size
self.draft_config.max_attention_size = chunk_size**2
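A quick arithmetic illustration of the draft config edit above, using a hypothetical chunk size rather than a project default: the input length is capped at the chunk size and the attention size at its square.

chunk_size = 2048  # hypothetical value, not a project default
max_input_len = chunk_size
max_attention_size = chunk_size**2
print(max_input_len, max_attention_size)  # 2048 4194304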
@@ -619,21 +631,41 @@ def progress(loaded_modules: int, total_modules: int)

# Draft uses the autosplit loader, so create a cache that reflects this
draft_cache_class = self.get_cache_class(self.draft_cache_mode)
self.draft_cache = self.create_cache(
cache_class=draft_cache_class,
autosplit=True,
use_tp=False,
model=self.draft_model,
)

for value in self.draft_model.load_autosplit_gen(
self.draft_cache,
reserve_vram=autosplit_reserve,
last_id_only=True,
callback_gen=progress_callback,
):
if value:
yield value
if self.draft_gpu_split:
logger.info("Loading with a manual GPU split (or a one GPU setup)")

for value in self.draft_model.load_gen(
self.draft_gpu_split,
callback_gen=progress_callback,
):
if value:
yield value

self.draft_cache = self.create_cache(
cache_class=draft_cache_class,
autosplit=False,
use_tp=False,
model=self.draft_model,
)
else:
logger.info("Loading with autosplit")

self.draft_cache = self.create_cache(
cache_class=draft_cache_class,
autosplit=True,
use_tp=False,
model=self.draft_model,
)

for value in self.draft_model.load_autosplit_gen(
self.draft_cache,
reserve_vram=autosplit_reserve,
last_id_only=True,
callback_gen=progress_callback,
):
if value:
yield value

# Test VRAM allocation with a full-length forward pass
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
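The reworked loader takes one of two paths depending on whether draft_gpu_split is set. Below is a stubbed control-flow sketch (plain Python stand-ins, not the real exllamav2 API): with a manual split the weights are loaded across the listed GPUs first and a non-autosplit cache is built afterwards, while an empty split creates the cache up front and lets the autosplit loader place layers as it goes.

from typing import List

def load_draft_model(draft_gpu_split: List[float]) -> list:
    steps = []
    if draft_gpu_split:
        # Manual split (or single-GPU setup): load across the listed GPUs,
        # then create the cache with autosplit disabled.
        steps.append(f"load_gen(split={draft_gpu_split})")
        steps.append("create_cache(autosplit=False)")
    else:
        # No split given: create the cache first, then let the autosplit
        # loader distribute layers while reserving headroom per GPU.
        steps.append("create_cache(autosplit=True)")
        steps.append("load_autosplit_gen(cache, reserve_vram=...)")
    return steps

print(load_draft_model([2.0, 2.0]))  # manual-split path
print(load_draft_model([]))          # autosplit path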
7 changes: 7 additions & 0 deletions common/config_models.py
@@ -351,6 +351,13 @@ class DraftModelConfig(BaseConfigModel):
f"Possible values: {str(CACHE_SIZES)[15:-1]}."
),
)
draft_gpu_split: List[float] = Field(
default_factory=list,
description=(
"An integer array of GBs of VRAM to split between GPUs (default: []).\n"
"If this isn't filled in, the draft model is autosplit."
),
)


class LoraInstanceModel(BaseConfigModel):
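A hedged sketch of how the new field behaves, using a simplified stand-in for DraftModelConfig rather than the project's real class hierarchy (assumes Pydantic, which the config models already use): omitting the key falls back to the empty-list default, which the backend treats as "use autosplit", and integer GB values are coerced to floats.

from typing import List
from pydantic import BaseModel, Field

class DraftModelConfigSketch(BaseModel):
    # Simplified stand-in for the field added in common/config_models.py.
    draft_gpu_split: List[float] = Field(default_factory=list)

print(DraftModelConfigSketch().draft_gpu_split)                        # []
print(DraftModelConfigSketch(draft_gpu_split=[8, 8]).draft_gpu_split)  # [8.0, 8.0]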
6 changes: 5 additions & 1 deletion config_sample.yml
@@ -20,7 +20,7 @@ network:
# Turn on this option if you are ONLY connecting from localhost.
disable_auth: false

# Disable fetching external content in response to requests, such as images from URLs.
# Disable fetching external content in response to requests,such as images from URLs.
disable_fetch_requests: false

# Send tracebacks over the API (default: False).
Expand Down Expand Up @@ -166,6 +166,10 @@ draft_model:
# Possible values: 'FP16', 'Q8', 'Q6', 'Q4'.
draft_cache_mode: FP16

# An integer array of GBs of VRAM to split between GPUs (default: []).
# If this isn't filled in, the draft model is autosplit.
draft_gpu_split: []

# Options for Loras
lora:
# Directory to look for LoRAs (default: loras).
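For completeness, a hedged example of filling in the new option in a user config and reading it back; the GB values and surrounding keys are illustrative, and parsing here uses PyYAML directly rather than TabbyAPI's own config loader.

import yaml

snippet = """
draft_model:
  draft_model_dir: models
  draft_cache_mode: FP16
  # Give the draft model roughly 2 GB on each of two GPUs.
  draft_gpu_split: [2, 2]
"""

config = yaml.safe_load(snippet)
print(config["draft_model"]["draft_gpu_split"])  # [2, 2]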
