Merge branch 'vllm-project:main' into main

siddartha-RE authored Feb 5, 2025
2 parents 256475a + bf3b79e commit 34a8799
Showing 117 changed files with 5,471 additions and 2,585 deletions.
@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.6353
- name: "exact_match,flexible-extract"
value: 0.637
limit: null
num_fewshot: null
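For context, a minimal sketch of how a baseline config like this one might be checked against lm-eval output; the `RTOL` tolerance, the result-dict layout, and the helper name are assumptions for illustration, not the harness script referenced in the comment above.

```python
# Hypothetical checker for an lm-eval baseline config like the one above.
import yaml

RTOL = 0.05  # assumed relative tolerance; the real harness may use a different value


def check_lm_eval_results(config_path: str, results: dict) -> None:
    """Compare lm-eval results (results["results"][task][metric]) to the config."""
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    for task in cfg["tasks"]:
        measured = results["results"][task["name"]]
        for metric in task["metrics"]:
            got, want = measured[metric["name"]], metric["value"]
            assert got >= want * (1 - RTOL), (
                f'{task["name"]}/{metric["name"]}: got {got:.4f}, expected ~{want:.4f}'
            )
```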
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -128,6 +128,7 @@ steps:
- tests/spec_decode/e2e/test_integration_dist_tp4
- tests/compile
- examples/offline_inference/rlhf.py
- examples/offline_inference/ray_placement.py
commands:
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
@@ -136,6 +137,7 @@ steps:
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
- python3 ../examples/offline_inference/rlhf.py
- RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/ray_placement.py

- label: Metrics, Tracing Test # 10min
num_gpus: 2
8 changes: 6 additions & 2 deletions .github/workflows/reminder_comment.yml
@@ -2,7 +2,6 @@ name: PR Reminder Comment Bot
on:
pull_request_target:
types: [opened]

jobs:
pr_reminder:
runs-on: ubuntu-latest
@@ -15,7 +14,12 @@ jobs:
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org. \n\nOnce the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n To run CI, PR reviewers can do one of these:\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' +
'💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' +
'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' +
'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' +
'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' +
'🚀'
})
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
3 changes: 3 additions & 0 deletions csrc/cache.h
@@ -15,6 +15,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping);

void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
const torch::Tensor& block_mapping);

void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
82 changes: 70 additions & 12 deletions csrc/cache_kernels.cu
@@ -46,7 +46,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr());

const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
// We use the stride instead of numel in case the cache is padded for memory
// alignment reasons, we assume the blocks data (inclusive of any padding)
// is contiguous in memory
const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -93,6 +96,24 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
}
}

// Kernel for MLA, which works on a single joint kv_cache
// Grid: (num_layers, num_pairs)
template <typename scalar_t>
__global__ void copy_blocks_mla_kernel(
int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping,
const int mem_footprint_per_block) {
const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y;
scalar_t* cache = reinterpret_cast<scalar_t*>(cache_ptrs[layer_idx]);
int64_t src_block = block_mapping[2 * pair_idx];
int64_t dst_block = block_mapping[2 * pair_idx + 1];
int64_t src_offset = src_block * mem_footprint_per_block;
int64_t dst_offset = dst_block * mem_footprint_per_block;
for (int i = threadIdx.x; i < mem_footprint_per_block; i += blockDim.x) {
cache[dst_offset + i] = cache[src_offset + i];
}
}

} // namespace vllm

// Note: the key_caches and value_caches vectors are constant but
@@ -147,6 +168,42 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
}));
}

// copy blocks kernel for MLA (assumes a joint KV-cache)
void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
const torch::Tensor& block_mapping) {
int num_layers = kv_caches.size();
if (num_layers == 0) {
return;
}
torch::Device cache_device = kv_caches[0].device();
TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");

std::vector<int64_t> cache_ptrs(num_layers);
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
}
torch::Tensor cache_ptrs_tensor =
torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
.to(cache_device);

int num_pairs = block_mapping.size(0);
// We use the stride instead of numel in case the cache is padded for memory
// alignment reasons, we assume the blocks data (inclusive of any padding)
// is contiguous in memory
int mem_footprint_per_block = kv_caches[0].stride(0);
dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, mem_footprint_per_block));
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping.data_ptr<int64_t>(), mem_footprint_per_block);
}));
}

namespace vllm {

template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
@@ -254,6 +311,7 @@ __global__ void concat_and_cache_mla_kernel(
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
@@ -274,9 +332,8 @@ __global__ void concat_and_cache_mla_kernel(
int src_stride, int dst_stride, int size, int offset) {
for (int i = threadIdx.x; i < size; i += blockDim.x) {
const int64_t src_idx = token_idx * src_stride + i;
const int64_t dst_idx = block_idx * block_stride +
block_offset * (kv_lora_rank + pe_dim) + i +
offset;
const int64_t dst_idx =
block_idx * block_stride + block_offset * entry_stride + i + offset;
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst[dst_idx] = src[src_idx];
} else {
@@ -391,14 +448,14 @@ void reshape_and_cache_flash(
// KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride, \
k_pe_stride, kv_lora_rank, pe_dim, block_size, \
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));

void concat_and_cache_mla(
@@ -428,6 +485,7 @@ void concat_and_cache_mla(
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);

dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
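For orientation, a pure-PyTorch sketch of the semantics the new MLA paths in this file implement: `copy_blocks_mla` copies whole blocks of each layer's joint kv_cache, and `concat_and_cache_mla` now addresses entries as `block_idx * block_stride + block_offset * entry_stride`, taking both strides from the cache itself so padded layouts index correctly. The shapes and reference functions below are illustrative, not the kernels' launch code.

```python
import torch


def copy_blocks_mla_ref(kv_caches: list[torch.Tensor],
                        block_mapping: torch.Tensor) -> None:
    """Reference semantics of copy_blocks_mla.

    kv_caches: one joint MLA cache per layer, shaped
               [num_blocks, block_size, kv_lora_rank + pe_dim] (illustrative).
    block_mapping: [num_pairs, 2] int64 tensor of (src_block, dst_block) pairs.
    """
    for cache in kv_caches:
        for src_block, dst_block in block_mapping.tolist():
            # The CUDA kernel copies cache.stride(0) elements per block, so any
            # alignment padding travels with the block; this reference copies
            # only the visible elements.
            cache[dst_block].copy_(cache[src_block])


def mla_dst_index(block_idx: int, block_offset: int, block_stride: int,
                  entry_stride: int, i: int, offset: int) -> int:
    # Mirrors the updated dst_idx in concat_and_cache_mla_kernel: the per-entry
    # stride now comes from kv_cache.stride(1) instead of being reconstructed
    # as kv_lora_rank + pe_dim.
    return block_idx * block_stride + block_offset * entry_stride + i + offset
```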
4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
@@ -450,6 +450,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);

cache_ops.def(
"copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks_mla", torch::kCUDA, &copy_blocks_mla);

// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"
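A hedged sketch of exercising the `copy_blocks_mla` registration above from Python. The `torch.ops._C_cache_ops` namespace and the in-place mutation of `kv_caches` follow how the existing `copy_blocks` cache op is exposed, but the exact extension name (and any `vllm._custom_ops` wrapper) should be treated as assumptions.

```python
import torch

# One joint MLA cache per layer: [num_blocks, block_size, kv_lora_rank + pe_dim].
# The sizes here (8 blocks, block_size 16, 512 + 64 head dims) are illustrative.
kv_caches = [torch.zeros(8, 16, 576, device="cuda") for _ in range(2)]

# Each row is a (src_block, dst_block) pair, matching the Tensor block_mapping
# argument in the schema above.
block_mapping = torch.tensor([[0, 3], [1, 4]], dtype=torch.int64, device="cuda")

# Copies blocks 0 -> 3 and 1 -> 4 in every layer's cache, in place.
torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping)
```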
1 change: 0 additions & 1 deletion docs/source/conf.py
@@ -37,7 +37,6 @@
# ones.
extensions = [
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx.ext.linkcode",
"sphinx.ext.intersphinx",
"sphinx_copybutton",
6 changes: 5 additions & 1 deletion docs/source/contributing/model/multimodal.md
@@ -250,7 +250,11 @@ def get_max_image_tokens(self) -> int:
And thus, we can override the method as:

```python
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
```

6 changes: 0 additions & 6 deletions docs/source/features/quantization/auto_awq.md
@@ -2,12 +2,6 @@

# AutoAWQ

:::{warning}
Please note that AWQ support in vLLM is under-optimized at the moment. We would recommend using the unquantized version of the model for better
accuracy and higher throughput. Currently, you can use AWQ as a way to reduce memory footprint. As of now, it is more suitable for low latency
inference with small number of concurrent requests. vLLM's AWQ implementation have lower throughput than unquantized version.
:::

To create a new 4-bit quantized model, you can leverage [AutoAWQ](https://github.com/casper-hansen/AutoAWQ).
Quantizing reduces the model's precision from FP16 to INT4 which effectively reduces the file size by ~70%.
The main benefits are lower latency and memory usage.
12 changes: 6 additions & 6 deletions docs/source/features/spec_decode.md
@@ -131,7 +131,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="meta-llama/Meta-Llama-3.1-70B-Instruct",
tensor_parallel_size=4,
speculative_model="ibm-fms/llama3-70b-accelerator",
speculative_model="ibm-ai-platform/llama3-70b-accelerator",
speculative_draft_tensor_parallel_size=1,
)
outputs = llm.generate(prompts, sampling_params)
@@ -149,11 +149,11 @@ limitation will be fixed in a future release.

A variety of speculative models of this type are available on HF hub:

- [llama-13b-accelerator](https://huggingface.co/ibm-fms/llama-13b-accelerator)
- [llama3-8b-accelerator](https://huggingface.co/ibm-fms/llama3-8b-accelerator)
- [codellama-34b-accelerator](https://huggingface.co/ibm-fms/codellama-34b-accelerator)
- [llama2-70b-accelerator](https://huggingface.co/ibm-fms/llama2-70b-accelerator)
- [llama3-70b-accelerator](https://huggingface.co/ibm-fms/llama3-70b-accelerator)
- [llama-13b-accelerator](https://huggingface.co/ibm-ai-platform/llama-13b-accelerator)
- [llama3-8b-accelerator](https://huggingface.co/ibm-ai-platform/llama3-8b-accelerator)
- [codellama-34b-accelerator](https://huggingface.co/ibm-ai-platform/codellama-34b-accelerator)
- [llama2-70b-accelerator](https://huggingface.co/ibm-ai-platform/llama2-70b-accelerator)
- [llama3-70b-accelerator](https://huggingface.co/ibm-ai-platform/llama3-70b-accelerator)
- [granite-3b-code-instruct-accelerator](https://huggingface.co/ibm-granite/granite-3b-code-instruct-accelerator)
- [granite-8b-code-instruct-accelerator](https://huggingface.co/ibm-granite/granite-8b-code-instruct-accelerator)
- [granite-7b-instruct-accelerator](https://huggingface.co/ibm-granite/granite-7b-instruct-accelerator)
23 changes: 19 additions & 4 deletions docs/source/models/supported_models.md
@@ -726,14 +726,14 @@ See [this page](#generative-models) for more information on how to use generativ
* `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc.
*
* ✅︎
*
* \*
- * `Idefics3ForConditionalGeneration`
* Idefics3
* T + I
* `HuggingFaceM4/Idefics3-8B-Llama3` etc.
* ✅︎
*
*
* ✅︎
- * `InternVLChatModel`
* InternVL 2.5, Mono-InternVL, InternVL 2.0
* T + I<sup>E+</sup>
@@ -799,7 +799,7 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
- * `NVLM_D_Model`
* NVLM-D 1.0
* T + I<sup>E+</sup>
* T + I<sup>+</sup>
* `nvidia/NVLM-D-72B`, etc.
*
* ✅︎
@@ -846,6 +846,13 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
* ✅︎
- * `Qwen2_5_VLForConditionalGeneration`
* Qwen2.5-VL
* T + I<sup>E+</sup> + V<sup>E+</sup>
* `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc.
*
* ✅︎
* ✅︎
- * `UltravoxModel`
* Ultravox
* T + A<sup>E+</sup>
@@ -859,7 +866,11 @@ See [this page](#generative-models) for more information on how to use generativ
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.

:::{note}
To use `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
To use DeepSeek-VL2 series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
:::

:::{note}
H2O-VL series models will be available in V1 once we support backends other than FlashAttention.
:::

:::{note}
Expand All @@ -876,6 +887,10 @@ The chat template for Pixtral-HF is incorrect (see [discussion](https://huggingf
A corrected version is available at <gh-file:examples/template_pixtral_hf.jinja>.
:::

:::{note}
To use Qwen2.5-VL series models, you have to install the Hugging Face `transformers` library from source via `pip install git+https://github.com/huggingface/transformers`.
:::
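Put together, the notes above amount to offline-inference setup like the following. The Python `hf_overrides` keyword mirrors the documented `--hf_overrides` CLI flag, and the DeepSeek-VL2 checkpoint name is an assumption since its table row is not shown in this excerpt.

```python
from vllm import LLM

# DeepSeek-VL2: override the reported architecture, per the note above.
deepseek_vl2 = LLM(
    model="deepseek-ai/deepseek-vl2-tiny",  # checkpoint name assumed for illustration
    hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
)

# Qwen2.5-VL: install `transformers` from source first, per the note above:
#   pip install git+https://github.com/huggingface/transformers
qwen25_vl = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")
```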

### Pooling Models

See [this page](pooling-models) for more information on how to use pooling models.
2 changes: 1 addition & 1 deletion examples/offline_inference/mlpspeculator.py
@@ -51,7 +51,7 @@ def time_generation(llm: LLM, prompts: List[str],
# Create an LLM with spec decoding
llm = LLM(
model="meta-llama/Llama-2-13b-chat-hf",
speculative_model="ibm-fms/llama-13b-accelerator",
speculative_model="ibm-ai-platform/llama-13b-accelerator",
)

print("With speculation")