diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml index 3ce406948432..484948b28e34 100644 --- a/.github/workflows/nv-a6000.yml +++ b/.github/workflows/nv-a6000.yml @@ -47,7 +47,8 @@ jobs: - name: Install deepspeed run: | python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja - python -m pip install pydantic==1.10.11 + # Update packages included in the container that do not support pydantic 2+ to versions that do + python -m pip install thinc spacy confection --upgrade python -m pip install .[dev,1bit,autotuning,inf] ds_report - name: Python environment diff --git a/.github/workflows/xpu-max1100.yml b/.github/workflows/xpu-max1100.yml index 1042db100a21..adeeb0acade2 100644 --- a/.github/workflows/xpu-max1100.yml +++ b/.github/workflows/xpu-max1100.yml @@ -21,7 +21,7 @@ on: - "deepspeed/runtime/zero/parameter_offload.py" - "deepspeed/runtime/pipe/engine.py" - "deepspeed/runtime/utils.py" - - "opbuilder/xpu/**" + - "op_builder/xpu/**" concurrency: group: ${{ github.workflow }}-${{ github.ref }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 11a14256801a..c9b41439801b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -59,7 +59,7 @@ repos: # Do not check files that are automatically generated '--skip=docs/Gemfile.lock,tests/unit/gpt2-merges.txt,tests/unit/gpt2-vocab.json', '--ignore-regex=\\n', # Do not count the 'n' in an escaped newline as part of a word - '--ignore-words-list=youn,unsupport,noe', # Word used in error messages that need rewording + '--ignore-words-list=youn,unsupport,noe,cann', # Word used in error messages that need rewording --check-filenames, --check-hidden ] diff --git a/README.md b/README.md index 304169b56777..2f6661ef5860 100755 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat). 
- +* [2024/08] [DeepSpeed on Windows](https://github.com/microsoft/DeepSpeed/tree/master/blogs/windows/08-2024/README.md) [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/windows/08-2024/japanese/README.md)] * [2024/08] [DeepNVMe: Improving DL Applications through I/O Optimizations](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-gds/README.md) [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-gds/japanese/README.md)] * [2024/07] [DeepSpeed Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/README.md) [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/japanese/README.md)] * [2024/03] [DeepSpeed-FP6:The power of FP6-Centric Serving for Large Language Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)] diff --git a/blogs/windows/08-2024/japanese/README.md b/blogs/windows/08-2024/japanese/README.md new file mode 100644 index 000000000000..7e437f737f58 --- /dev/null +++ b/blogs/windows/08-2024/japanese/README.md @@ -0,0 +1,123 @@ +
+ +# DeepSpeed on Windows + +
+ +# Introduction + +DeepSpeed is a popular open-source deep learning optimization library that makes distributed training and inference simple and efficient. Thanks to its rich set of advanced optimizations (e.g., ZeRO, 3D parallelism, MoE), DeepSpeed has been widely used to train state-of-the-art models such as Phi-3, Megatron-Turing-530B, BLOOM-176B, and Arctic. However, because it did not natively support Microsoft Windows, the most widely used operating system, many AI developers and users were unable to take advantage of DeepSpeed's innovations. To fix this, we started an effort to make the full set of DeepSpeed features run natively on Windows with the same ease of use as on Linux. + +In this blog, we report the first results of that effort. DeepSpeed can now be installed on Windows and can natively run single-GPU training, fine-tuning, and inference. Importantly, installation and usage work exactly as they do on Linux. Through the fine-tuning and inference workloads, we confirmed that three important DeepSpeed features work correctly: integration with HuggingFace Transformers, LoRA support, and CPU offloading. This Windows support is available starting with version 0.14.5. In the rest of this blog, we present examples that demonstrate these results. + +# Test environment + +We ran the tests on a Surface Laptop Studio 2 running Windows 11 Version 23H2 and Build 22631.3880. The machine has a single NVIDIA RTX A2000 GPU with 4GB of VRAM. We used PyTorch version 2.3.0 and HuggingFace Transformers version 4.41.2. The example scripts used here are available in the [DeepSpeedExamples repository](https://github.com/microsoft/DeepSpeedExamples); please clone the repository before running the examples below. + +# Installation + +DeepSpeed can be installed on Windows in two ways. The easier way is to use the pip package manager; the other is to build from source. In both cases, Python 3.x and PyTorch with CUDA support are required. + +## Installing via pip + +To install DeepSpeed, simply run: `pip install deepspeed`. +This installs the latest version of DeepSpeed (0.14.5 at this time). Unlike the Linux version, the Windows version ships with all the operators already pre-built, so there is no need to install a CUDA SDK or a C++ compiler. + +
+ Installing DeepSpeed on Windows using pip +
+ + +## Building from source + +To build DeepSpeed from source, clone the DeepSpeed repository and run the `build_win.bat` compilation script. + +## Validating the installation + +Regardless of the installation method, you can run `ds_report` to check whether the installation succeeded. The output should look like this: + +
+ ds_report output confirming the DeepSpeed installation on Windows +
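Beyond `ds_report`, a quick check from a Python prompt can confirm that the package and the GPU are visible. The snippet below is purely illustrative and is not part of the blog's scripts:

```python
import torch
import deepspeed

# Illustrative post-install sanity check.
print("deepspeed:", deepspeed.__version__)
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
```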
+ +# Pretraining examples + +As examples of pretraining with DeepSpeed on Windows, we show runs of an image classification model on CIFAR10 and of the BERT language model. + +## Pretraining CIFAR10 + +The scripts and code required for CIFAR10 pretraining can be found at the following path: `DeepSpeedExamples\training\cifar` + +You can launch CIFAR10 pretraining with the following command: `deepspeed cifar10_deepspeed.py --deepspeed` + +The output should look like this. +
+ Pretraining the CIFAR10 model on Windows with DeepSpeed +
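For readers who want a sense of what such a training script does internally, here is a minimal, illustrative sketch of the `deepspeed.initialize` pattern that scripts like `cifar10_deepspeed.py` follow. The toy model, batch size, and config values below are assumptions for illustration only; the real script builds its model and config from command-line arguments:

```python
import torch
import deepspeed

# Toy model and config for illustration; real scripts parse these from args/JSON.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
ds_config = {
    "train_batch_size": 16,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

# deepspeed.initialize wraps the model in an engine that owns the optimizer,
# gradient handling, and (optionally) ZeRO/offload behavior.
model_engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                      model_parameters=model.parameters(),
                                                      config=ds_config)

for step in range(10):
    images = torch.randn(16, 3, 32, 32, device=model_engine.device)
    labels = torch.randint(0, 10, (16,), device=model_engine.device)
    loss = torch.nn.functional.cross_entropy(model_engine(images), labels)
    model_engine.backward(loss)  # the engine handles scaling and gradient accumulation
    model_engine.step()
```

The same pattern is what allows the unmodified example scripts to run on this single-GPU Windows machine when launched with the `deepspeed` command.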
+ + +## Pretraining BERT + +The scripts and code required for BERT pretraining can be found at the following path: `DeepSpeedExamples\training\HelloDeepSpeed` + +You can launch BERT pretraining with the following command: `deepspeed train_bert_ds.py --checkpoint_dir experiment_deepspeed` + +The output should look like this. + +
+ Pretraining the BERT model on Windows with DeepSpeed +
+ +# Fine-tuning example + +We demonstrate fine-tuning using the supervised fine-tuning (SFT) step of the DeepSpeed-Chat application. We run SFT of the HuggingFace `facebook/opt-125m` model with the LoRA and CPU offloading memory optimizations enabled. The command line for this example is: `deepspeed training\step1_supervised_finetuning\main.py --model_name_or_path facebook/opt-125m --gradient_accumulation_steps 8 --lora_dim 128 --only_optimize_lora --print_loss --zero_stage 2 --deepspeed --dtype bf16 --offload --output_dir output` + +The output should look like this. + +
+ Fine-tuning the facebook/opt-125m model on Windows with DeepSpeed +
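The `--zero_stage 2`, `--dtype bf16`, and `--offload` flags above are turned by the DeepSpeed-Chat script into a DeepSpeed runtime config. A roughly equivalent configuration fragment, shown only as an illustration (the exact config the script assembles may differ), looks like this:

```python
# Illustrative DeepSpeed config corresponding to
# "--zero_stage 2 --dtype bf16 --offload"; batch sizes are placeholders.
ds_config = {
    "train_batch_size": 16,
    "gradient_accumulation_steps": 8,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},
    },
}
```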
+ +# Inference example + +We demonstrate inference using ZeRO-Inference for token generation. ZeRO-Inference reduces the hardware cost of inference by offloading to CPU or NVMe memory. Here we use the example script to run token generation with the HuggingFace Llama-2-7B model. Since 4GB of VRAM is not enough to run both the model and the generation processing, we offload the model parameters to CPU memory. + +We use the following command line to generate 32 tokens from an 8-token prompt: `deepspeed run_model.py --model meta-llama/Llama-2-7b-hf --batch-size 64 --prompt-len 8 --gen-len 32 --cpu-offload` + +The output should look like this. +
+ Token generation with LLAMA2-7B on Windows using DeepSpeed ZeRO-Inference +
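The core ZeRO-Inference pattern behind the example script is to initialize the model with a ZeRO stage 3 config that offloads parameters to CPU and then generate from the wrapped module. The sketch below is a simplified illustration of that pattern (no batching or benchmarking, unlike `run_model.py`), and it assumes enough host memory for the offloaded weights:

```python
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
    },
}

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# ZeRO-Inference: parameters live in CPU memory and are streamed to the GPU
# during the forward pass of each layer.
engine = deepspeed.initialize(model=model, config=ds_config)[0]
engine.module.eval()

inputs = tokenizer("DeepSpeed on Windows is", return_tensors="pt").to(engine.device)
with torch.no_grad():
    output_ids = engine.module.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```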
+ +# Summary + +Enabling DeepSpeed, a deep learning framework, to run natively on Windows, the most widely used operating system, is an important step toward letting many more people and organizations benefit from the AI revolution that is now underway. In this blog, we shared the first results of our project toward this goal. Windows support is an ongoing effort, and we hope this work will be put to use and further improved by many users. Our next items on the roadmap include running on multiple GPUs, model parameter quantization, and detailed performance analysis. + +# Acknowledgements + +This project is the result of significant contributions from DeepSpeed team members, including Costin Eseanu, Logan Adams, Elton Zheng, Reza Yazdani Aminabadi, Martin Cai, and Olatunji Ruwase. We also acknowledge the important contributions of the DeepSpeed users who needed this capability, provided solutions to various issues and constructive feedback, and walked this journey with us. diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.h b/csrc/aio/py_lib/deepspeed_aio_op_desc.h index 7305f6920c91..350d28d29d58 100644 --- a/csrc/aio/py_lib/deepspeed_aio_op_desc.h +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.h @@ -16,7 +16,7 @@ struct io_op_desc_t { const std::string _filename; const long long int _file_num_bytes; const int _num_threads; - const int _num_bytes_per_thread; + const long long int _num_bytes_per_thread; torch::Tensor _contiguous_buffer; const bool _validate; diff --git a/deepspeed/comm/config.py b/deepspeed/comm/config.py index 1c441bb6bfe9..57501c9dd237 100644 --- a/deepspeed/comm/config.py +++ b/deepspeed/comm/config.py @@ -3,20 +3,12 @@ # DeepSpeed Team -from .constants import * -from ..pydantic_v1 import BaseModel - +from deepspeed.runtime.config_utils import DeepSpeedConfigModel -class CommsConfig(BaseModel): - - class Config: - validate_all = True - validate_assignment = True - use_enum_values = True - extra = 'forbid' +from .constants import * -class CommsLoggerConfig(CommsConfig): +class CommsLoggerConfig(DeepSpeedConfigModel): enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py index 1d5018aaa75b..c7c7684fff79 100644 --- a/deepspeed/inference/config.py +++ b/deepspeed/inference/config.py @@ -5,38 +5,25 @@ import torch import deepspeed -from deepspeed.pydantic_v1 import Field, validator +from pydantic import Field, field_validator from deepspeed.runtime.config_utils import DeepSpeedConfigModel from deepspeed.runtime.zero.config import DeepSpeedZeroConfig -from typing import Dict, Union +from typing import Dict, Union, Optional from enum import Enum class DtypeEnum(Enum): - # The torch dtype must always be the first value (so we return torch.dtype) - fp16 = torch.float16, "torch.float16", "fp16", "float16", "half" - fp32 = torch.float32, "torch.float32", "fp32", "float32", "float" - bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat" - int8 = torch.int8, "torch.int8", "int8" - - # Copied from https://stackoverflow.com/a/43210118 - # Allows us to use multiple values for each Enum index and returns first - # listed value when Enum is called - def __new__(cls, *values): - obj = object.__new__(cls) - # first value is canonical value - obj._value_ = values[0] - for other_value in values[1:]: - cls._value2member_map_[other_value] = obj - obj._all_values = values - return obj - - def __repr__(self): - return "<%s.%s: %s>" % ( - self.__class__.__name__, - self._name_, - ", ".join([repr(v) for v in self._all_values]), - ) + fp16 = (torch.float16, "torch.float16", "fp16", "float16", "half") + fp32 = (torch.float32, "torch.float32", "fp32", "float32", "float") + bf16 = (torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat") + int8 = (torch.int8, "torch.int8", "int8") + + @classmethod + def from_str(cls, value: str): + for dtype in cls: + if value in dtype.value: + return dtype
+ raise ValueError(f"'{value}' is not a valid DtypeEnum") class MoETypeEnum(str, Enum): @@ -91,24 +78,24 @@ class QuantTypeEnum(str, Enum): class BaseQuantConfig(DeepSpeedConfigModel): - enabled = True - num_bits = 8 + enabled: bool = True + num_bits: int = 8 q_type: QuantTypeEnum = QuantTypeEnum.sym q_groups: int = 1 class WeightQuantConfig(BaseQuantConfig): - enabled = True + enabled: bool = True quantized_initialization: Dict = {} post_init_quant: Dict = {} class ActivationQuantConfig(BaseQuantConfig): - enabled = True + enabled: bool = True class QKVQuantConfig(DeepSpeedConfigModel): - enabled = True + enabled: bool = True class QuantizationConfig(DeepSpeedConfigModel): @@ -120,9 +107,9 @@ class QuantizationConfig(DeepSpeedConfigModel): # todo: brainstorm on how to do ckpt loading for DS inference class InferenceCheckpointConfig(DeepSpeedConfigModel): - checkpoint_dir: str = None - save_mp_checkpoint_path: str = None - base_dir: str = None + checkpoint_dir: Optional[str] = None + save_mp_checkpoint_path: Optional[str] = None + base_dir: Optional[str] = None class DeepSpeedInferenceConfig(DeepSpeedConfigModel): @@ -136,7 +123,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): `(attention_output projection, transformer output projection)` """ - dtype: DtypeEnum = torch.float16 + dtype: torch.dtype = torch.float16 """ Desired model data type, will convert model to this type. Supported target types: `torch.half`, `torch.int8`, `torch.float` @@ -198,7 +185,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): """ #todo: refactor the following 3 into the new checkpoint_config - checkpoint: Union[str, Dict] = None + checkpoint: Optional[Union[str, Dict]] = None """ Path to deepspeed compatible checkpoint or path to JSON with load policy. """ @@ -214,7 +201,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): specifying whether the inference-module is created with empty or real Tensor """ - save_mp_checkpoint_path: str = None + save_mp_checkpoint_path: Optional[str] = None """ The path for which we want to save the loaded model with a checkpoint. This feature is used for adjusting the parallelism degree to help alleviate the @@ -243,19 +230,21 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): replace_method: str = Field( "auto", - deprecated=True, - deprecated_msg="This parameter is no longer needed, please remove from your call to DeepSpeed-inference") + json_schema_extra={ + "deprecated": True, + "deprecated_msg": "This parameter is no longer needed, please remove from your call to DeepSpeed-inference" + }) - injection_policy: Dict = Field(None, alias="injection_dict") + injection_policy: Optional[Dict] = Field(None, alias="injection_dict") """ Dictionary mapping a client nn.Module to its corresponding injection policy. 
e.g., `{BertLayer : deepspeed.inference.HFBertLayerPolicy}` """ - injection_policy_tuple: tuple = None + injection_policy_tuple: Optional[tuple] = None """ TODO: Add docs """ - config: Dict = Field(None, alias="args") # todo: really no need for this field if we can refactor + config: Optional[Dict] = Field(None, alias="args") # todo: really no need for this field if we can refactor max_out_tokens: int = Field(1024, alias="max_tokens") """ @@ -274,31 +263,49 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): transposed_mode: bool = Field(False, alias="transposed_mode") - mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size") + mp_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.tp_size"}) """ Desired model parallel size, default is 1 meaning no model parallelism. Deprecated, please use the ``tensor_parallel` config to control model parallelism. """ - mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu") - ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size") - ep_group: object = Field(None, alias="expert_group", deprecated=True, new_param="moe.ep_group") - ep_mp_group: object = Field(None, alias="expert_mp_group", deprecated=True, new_param="moe.ep_mp_group") - moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts") - moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type") - - @validator("moe") + mpu: object = Field(None, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.mpu"}) + ep_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "moe.ep_size"}) + ep_group: object = Field(None, + alias="expert_group", + json_schema_extra={ + "deprecated": True, + "new_param": "moe.ep_group" + }) + ep_mp_group: object = Field(None, + alias="expert_mp_group", + json_schema_extra={ + "deprecated": True, + "new_param": "moe.ep_mp_group" + }) + moe_experts: list = Field([1], json_schema_extra={"deprecated": True, "new_param": "moe.moe_experts"}) + moe_type: MoETypeEnum = Field(MoETypeEnum.standard, + json_schema_extra={ + "deprecated": True, + "new_param": "moe.type" + }) + + @field_validator("dtype", mode="before") + def validate_dtype(cls, field_value, values): + if isinstance(field_value, str): + return DtypeEnum.from_str(field_value).value[0] + if isinstance(field_value, torch.dtype): + return field_value + raise TypeError(f"Invalid type for dtype: {type(field_value)}") + + @field_validator("moe") def moe_backward_compat(cls, field_value, values): if isinstance(field_value, bool): return DeepSpeedMoEConfig(moe=field_value) return field_value - @validator("use_triton") + @field_validator("use_triton") def has_triton(cls, field_value, values): if field_value and not deepspeed.HAS_TRITON: raise ValueError('Triton needs to be installed to use deepspeed with triton kernels') return field_value - - class Config: - # Get the str representation of the datatype for serialization - json_encoders = {torch.dtype: lambda x: str(x)} diff --git a/deepspeed/inference/v2/checkpoint/huggingface_engine.py b/deepspeed/inference/v2/checkpoint/huggingface_engine.py index 77dfa7a23b1e..d88d99ebebfd 100644 --- a/deepspeed/inference/v2/checkpoint/huggingface_engine.py +++ b/deepspeed/inference/v2/checkpoint/huggingface_engine.py @@ -108,6 +108,12 @@ def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]: for checkpoint in self._all_ckpt_paths: inference_logger().info(f"Loading checkpoint: {checkpoint}") checkpoint_sd = 
self._checkpoint_load_fn(checkpoint) + + # If the model has tied embeddings, we need to make sure the lm_head weights are tied to the embeddings weights + if hasattr(self.model_config, "tie_word_embeddings") and self.model_config.tie_word_embeddings: + if self.model_config.model_type == "qwen2": + checkpoint_sd["lm_head.weight"] = checkpoint_sd["model.embed_tokens.weight"] + param_keys = list(checkpoint_sd.keys()) for param_name in param_keys: param = checkpoint_sd[param_name] diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py index 85e4b7a0e0a0..325b57d8f56a 100644 --- a/deepspeed/inference/v2/config_v2.py +++ b/deepspeed/inference/v2/config_v2.py @@ -3,8 +3,9 @@ # DeepSpeed Team +from pydantic import Field from typing import Optional -from deepspeed.pydantic_v1 import Field + from deepspeed.runtime.config_utils import DeepSpeedConfigModel from .ragged import DSStateManagerConfig diff --git a/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h index 2cc430ccfe34..f5104f899d9c 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h +++ b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h @@ -14,5 +14,8 @@ } else if (4 == N_TOP_K) { \ constexpr int CONST_TOP_K = 4; \ __VA_ARGS__(); \ + } else if (8 == N_TOP_K) { \ + constexpr int CONST_TOP_K = 8; \ + __VA_ARGS__(); \ } \ }() diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py index 7e1ec1a13cb9..aacbec0bd3ae 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py @@ -19,7 +19,7 @@ class BlockedRotaryEmbeddings(DSKernelBase): supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] supported_head_sizes = [64, 80, 96, 128] - supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71] + supported_q_ratios = [1, 2, 4, 5, 6, 7, 8, 16, 29, 35, 36, 71] def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype, rotary_dim: int, theta_base: float) -> None: diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu index fbafece5ccf2..f7bc693eefee 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu @@ -265,6 +265,8 @@ void launch_kv_rotary_kernel(T* kv_cache, LAUNCH_KV_ROTARY_FOR_Q_RATIO(2) LAUNCH_KV_ROTARY_FOR_Q_RATIO(4) LAUNCH_KV_ROTARY_FOR_Q_RATIO(5) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(6) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(7) LAUNCH_KV_ROTARY_FOR_Q_RATIO(8) LAUNCH_KV_ROTARY_FOR_Q_RATIO(16) LAUNCH_KV_ROTARY_FOR_Q_RATIO(29) diff --git a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py index ebdb59bca920..c5e02adaffc4 100644 --- a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py +++ b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py @@ -27,9 +27,9 @@ class TensorMetadata(DeepSpeedConfigModel): """ A class to represent a tensor specification. 
""" - dtype: Optional[str] - shape: Optional[Tuple[int, ...]] - strides: Optional[Tuple[int, ...]] + dtype: Optional[str] = None + shape: Optional[Tuple[int, ...]] = None + strides: Optional[Tuple[int, ...]] = None offset: int @@ -37,7 +37,7 @@ class ParameterMetadata(DeepSpeedConfigModel): """ A class to represent a parameter specification. """ - core_param: TensorMetadata = None + core_param: Optional[TensorMetadata] = None aux_params: Dict[str, TensorMetadata] = {} diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py index b4621257ff82..e499379da7e3 100644 --- a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py +++ b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py @@ -8,45 +8,45 @@ from ..common_parameters import * from ..layer_container_base import LayerContainer ''' - # HF Qwen1.5-MoE-A2.7B model looks like this: + # HF Qwen2-57B-A14B model looks like this: Qwen2MoeForCausalLM( (model): Qwen2MoeModel( - (embed_tokens): Embedding(151936, 2048) + (embed_tokens): Embedding(151936, 3584) (layers): ModuleList( - (0-23): 24 x Qwen2MoeDecoderLayer( + (0-27): 28 x Qwen2MoeDecoderLayer( (self_attn): Qwen2MoeSdpaAttention( - (q_proj): Linear(in_features=2048, out_features=2048, bias=True) - (k_proj): Linear(in_features=2048, out_features=2048, bias=True) - (v_proj): Linear(in_features=2048, out_features=2048, bias=True) - (o_proj): Linear(in_features=2048, out_features=2048, bias=False) + (q_proj): Linear(in_features=3584, out_features=3584, bias=True) + (k_proj): Linear(in_features=3584, out_features=512, bias=True) + (v_proj): Linear(in_features=3584, out_features=512, bias=True) + (o_proj): Linear(in_features=3584, out_features=3584, bias=False) (rotary_emb): Qwen2MoeRotaryEmbedding() ) (mlp): Qwen2MoeSparseMoeBlock( - (gate): Linear(in_features=2048, out_features=60, bias=False) + (gate): Linear(in_features=3584, out_features=64, bias=False) (experts): ModuleList( - (0-59): 60 x Qwen2MoeMLP( - (gate_proj): Linear(in_features=2048, out_features=1408, bias=False) - (up_proj): Linear(in_features=2048, out_features=1408, bias=False) - (down_proj): Linear(in_features=1408, out_features=2048, bias=False) + (0-63): 64 x Qwen2MoeMLP( + (gate_proj): Linear(in_features=3584, out_features=2560, bias=False) + (up_proj): Linear(in_features=3584, out_features=2560, bias=False) + (down_proj): Linear(in_features=2560, out_features=3584, bias=False) (act_fn): SiLU() ) ) (shared_expert): Qwen2MoeMLP( - (gate_proj): Linear(in_features=2048, out_features=5632, bias=False) - (up_proj): Linear(in_features=2048, out_features=5632, bias=False) - (down_proj): Linear(in_features=5632, out_features=2048, bias=False) + (gate_proj): Linear(in_features=3584, out_features=20480, bias=False) + (up_proj): Linear(in_features=3584, out_features=20480, bias=False) + (down_proj): Linear(in_features=20480, out_features=3584, bias=False) (act_fn): SiLU() ) - (shared_expert_gate): Linear(in_features=2048, out_features=1, bias=False) + (shared_expert_gate): Linear(in_features=3584, out_features=1, bias=False) ) - (input_layernorm): Qwen2MoeRMSNorm() - (post_attention_layernorm): Qwen2MoeRMSNorm() + (input_layernorm): Qwen2MoeRMSNorm((3584,), eps=1e-06) + (post_attention_layernorm): Qwen2MoeRMSNorm((3584,), eps=1e-06) ) ) - (norm): Qwen2MoeRMSNorm() + (norm): Qwen2MoeRMSNorm((3584,), eps=1e-06) ) - (lm_head): Linear(in_features=2048, out_features=151936, bias=False) + (lm_head): 
Linear(in_features=3584, out_features=151936, bias=False) ) ''' diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py index 7cddbf978369..c7841b24e5fc 100644 --- a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py +++ b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py @@ -73,7 +73,7 @@ def n_heads(self) -> int: @property def intermediate_dim(self) -> int: - return self._config.intermediate_size + return self._config.shared_expert_intermediate_size @property def n_heads_kv(self) -> int: diff --git a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py index bd90cbd5d697..a9b01d1233cd 100644 --- a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py +++ b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py @@ -42,7 +42,7 @@ def supports_config(config: DSMoEConfig) -> bool: if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16: return False - if config.top_k != 1 and config.top_k != 2 and config.top_k != 4: + if config.top_k != 1 and config.top_k != 2 and config.top_k != 4 and config.top_k != 8: return False return True diff --git a/deepspeed/inference/v2/ragged/manager_configs.py b/deepspeed/inference/v2/ragged/manager_configs.py index a5e98e5bcef1..17283b8bc0c4 100644 --- a/deepspeed/inference/v2/ragged/manager_configs.py +++ b/deepspeed/inference/v2/ragged/manager_configs.py @@ -6,7 +6,7 @@ from enum import Enum from typing import Tuple -from deepspeed.pydantic_v1 import PositiveInt, validator +from pydantic import PositiveInt, model_validator from deepspeed.runtime.config_utils import DeepSpeedConfigModel from ..inference_utils import DtypeEnum @@ -173,11 +173,9 @@ class DSStateManagerConfig(DeepSpeedConfigModel): Enable tracking for offloading KV-cache to host memory. Currently unsupported. """ - @validator("max_ragged_sequence_count") - def max_ragged_sequence_count_validator(cls, v: int, values: dict): + @model_validator(mode="after") + def max_ragged_sequence_count_validator(self): # If the attributes below failed their validation they won't appear in the values dict. 
- if "max_tracked_sequences" in values and v > values["max_tracked_sequences"]: - raise ValueError("max_ragged_sequence_count must be less than max_tracked_sequences") - if "max_ragged_batch_size" in values and v > values["max_ragged_batch_size"]: - raise ValueError("max_ragged_sequence_count must be less than max_ragged_batch_size") - return v + assert self.max_ragged_sequence_count <= self.max_tracked_sequences, "max_ragged_sequence_count must be less than max_tracked_sequences" + assert self.max_ragged_sequence_count <= self.max_ragged_batch_size, "max_ragged_sequence_count must be less than max_ragged_batch_size" + return self diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py index c09a11e213db..340bc82de508 100644 --- a/deepspeed/moe/sharded_moe.py +++ b/deepspeed/moe/sharded_moe.py @@ -208,7 +208,7 @@ def top1gating(logits: Tensor, mask1 = einsum("s,se->se", used_token, mask1) # gating decisions - exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') + exp_counts = torch.sum(mask1, dim=0).detach().to(logits.device) # if we don't want to drop any tokens if not drop_tokens: @@ -324,7 +324,7 @@ def top2gating(logits: Tensor, l_aux = torch.mean(me * ce) * num_experts * num_experts # gating decisions - exp_counts = torch.sum(mask1 + mask2, dim=0) + exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device) if drop_tokens: # Calculate configured capacity and remove locations outside capacity from mask @@ -368,7 +368,7 @@ def top2gating(logits: Tensor, combine_weights = combine1_sec + combine2_sec dispatch_mask = combine_weights.bool() - return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu') + return l_aux, combine_weights, dispatch_mask, exp_counts def topkgating( diff --git a/deepspeed/monitor/config.py b/deepspeed/monitor/config.py index c4200877089a..960ce1ba997a 100644 --- a/deepspeed/monitor/config.py +++ b/deepspeed/monitor/config.py @@ -5,7 +5,7 @@ from typing import Optional -from deepspeed.pydantic_v1 import root_validator +from pydantic import model_validator from deepspeed.runtime.config_utils import DeepSpeedConfigModel @@ -36,10 +36,10 @@ class WandbConfig(DeepSpeedConfigModel): enabled: bool = False """ Whether logging to WandB is enabled. Requires `wandb` package is installed. """ - group: str = None + group: Optional[str] = None """ Name for the WandB group. This can be used to group together runs. """ - team: str = None + team: Optional[str] = None """ Name for the WandB team. """ project: str = "deepspeed" @@ -137,8 +137,8 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel): csv_monitor: CSVConfig = {} """ Local CSV output of monitoring data. """ - @root_validator - def check_enabled(cls, values): - values["enabled"] = values.get("tensorboard").enabled or values.get("wandb").enabled or values.get( - "csv_monitor").enabled or values.get("comet").enabled - return values + @model_validator(mode="after") + def check_enabled(self): + enabled = self.tensorboard.enabled or self.wandb.enabled or self.csv_monitor.enabled or self.comet.enabled + self.__dict__["enabled"] = enabled + return self diff --git a/deepspeed/pydantic_v1.py b/deepspeed/pydantic_v1.py deleted file mode 100644 index 6aba072ad929..000000000000 --- a/deepspeed/pydantic_v1.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team -"""Pydantic v1 compatibility module. - -Pydantic v2 introduced breaking changes that hinder its adoption: -https://docs.pydantic.dev/latest/migration/. 
To provide deepspeed users the option to -migrate to pydantic v2 on their own timeline, deepspeed uses this compatibility module -as a pydantic-version-agnostic alias for pydantic's v1 API. -""" - -try: - from pydantic.v1 import * # noqa: F401 -except ImportError: - from pydantic import * # noqa: F401 diff --git a/deepspeed/runtime/config_utils.py b/deepspeed/runtime/config_utils.py index 5522a8e79d69..d5c3a1548360 100755 --- a/deepspeed/runtime/config_utils.py +++ b/deepspeed/runtime/config_utils.py @@ -5,11 +5,12 @@ """ Collection of DeepSpeed configuration utilities """ -import json import collections -import collections.abc +import json +import torch from functools import reduce -from deepspeed.pydantic_v1 import BaseModel +from pydantic import BaseModel, ConfigDict, field_serializer + from deepspeed.utils import logger @@ -54,67 +55,73 @@ def __init__(self, strict=False, **data): if (not strict): # This is temporary until we refactor all DS configs, allows HF to load models data = {k: v for k, v in data.items() if (v != "auto" or k == "replace_method")} super().__init__(**data) - self._deprecated_fields_check(self) + self._deprecated_fields_check() - def _process_deprecated_field(self, pydantic_config, field): + def _process_deprecated_field(self, dep_field): # Get information about the deprecated field - fields_set = pydantic_config.__fields_set__ - dep_param = field.name - kwargs = field.field_info.extra + pydantic_config = self + fields_set = pydantic_config.model_fields_set + kwargs = pydantic_config.model_fields[dep_field].json_schema_extra new_param_fn = kwargs.get("new_param_fn", lambda x: x) - param_value = new_param_fn(getattr(pydantic_config, dep_param)) - new_param = kwargs.get("new_param", "") + param_value = new_param_fn(getattr(pydantic_config, dep_field)) + new_field = kwargs.get("new_param", "") dep_msg = kwargs.get("deprecated_msg", "") - if dep_param in fields_set: - logger.warning(f"Config parameter {dep_param} is deprecated" + - (f" use {new_param} instead" if new_param else "") + (f". {dep_msg}" if dep_msg else "")) + if dep_field in fields_set: + logger.warning(f"Config parameter {dep_field} is deprecated" + + (f" use {new_field} instead" if new_field else "") + (f". 
{dep_msg}" if dep_msg else "")) # Check if there is a new param and if it should be set with a value - if new_param and kwargs.get("set_new_param", True): + if new_field and kwargs.get("set_new_param", True): # Remove the deprecate field if there is a replacing field try: - delattr(pydantic_config, dep_param) + delattr(pydantic_config, dep_field) except Exception as e: - logger.error(f"Tried removing deprecated '{dep_param}' from config") + logger.error(f"Tried removing deprecated '{dep_field}' from config") raise e # Set new param value - new_param_nested = new_param.split(".") + new_param_nested = new_field.split(".") if len(new_param_nested) > 1: # If the new param exists in a subconfig, we need to get # the fields set for that subconfig pydantic_config = reduce(getattr, new_param_nested[:-1], pydantic_config) - fields_set = pydantic_config.__fields_set__ + fields_set = pydantic_config.model_fields_set new_param_name = new_param_nested[-1] assert ( new_param_name not in fields_set - ), f"Cannot provide deprecated parameter '{dep_param}' and replacing parameter '{new_param}' together" + ), f"Cannot provide deprecated parameter '{dep_field}' and replacing parameter '{new_field}' together" # A custom function for converting the old param value to new param value can be provided try: setattr(pydantic_config, new_param_name, param_value) except Exception as e: - logger.error(f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'") + logger.error(f"Tried setting value for '{new_field}' with value from deprecated '{dep_field}'") raise e - def _deprecated_fields_check(self, pydantic_config): - fields = pydantic_config.__fields__ - for field in fields.values(): - if field.field_info.extra.get("deprecated", False): - self._process_deprecated_field(pydantic_config, field) + def _deprecated_fields_check(self): + fields = self.model_fields + for field_name, field_info in fields.items(): + if field_info.json_schema_extra and field_info.json_schema_extra.get("deprecated", False): + self._process_deprecated_field(field_name) + + model_config = ConfigDict( + validate_default=True, + validate_assignment=True, + use_enum_values=True, + populate_by_name=True, + extra="forbid", + arbitrary_types_allowed=True, + protected_namespaces=(), + ) - class Config: - validate_all = True - validate_assignment = True - use_enum_values = True - allow_population_by_field_name = True - extra = "forbid" - arbitrary_types_allowed = True + @field_serializer("dtype", check_fields=False) + def serialize_torch_dtype(dtype: torch.dtype) -> str: + return str(dtype) def get_config_default(config, field_name): - assert field_name in config.__fields__, f"'{field_name}' is not a field in {config}" - assert not config.__fields__.get( - field_name).required, f"'{field_name}' is a required field and does not have a default value" - return config.__fields__.get(field_name).default + assert field_name in config.model_fields, f"'{field_name}' is not a field in {config}" + assert not config.model_fields.get( + field_name).is_required(), f"'{field_name}' is a required field and does not have a default value" + return config.model_fields.get(field_name).get_default() class pp_int(int): diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 2089d59dbce4..1cfcd784e2ce 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -6,7 +6,7 @@ import sys from typing import Optional from enum import Enum -from deepspeed.pydantic_v1 import Field, validator, 
root_validator +from pydantic import Field, model_validator from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel from deepspeed.utils import logger from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum @@ -30,7 +30,7 @@ "reduce_bucket_size": 500000000, "load_from_fp32_weights": [true|false], "cpu_offload": [true|false] (deprecated), - "cpu_offload_params" : [true|false] (deprecated), + "cpu_offload_param" : [true|false] (deprecated), "cpu_offload_use_pin_memory": [true|false] (deprecated), "sub_group_size" : 1000000000000, "offload_param": {...}, @@ -128,7 +128,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): the allgather for large model sizes """ - overlap_comm: bool = None # None for dynamic default value (see validator `overlap_comm_valid` below) + overlap_comm: Optional[bool] = None # None for dynamic default value (see validator `overlap_comm_valid` below) """ Attempts to overlap the reduction of the gradients with backward computation """ @@ -168,27 +168,37 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): parameters). Used by ZeRO3-Offload and ZeRO-Infinity """ - cpu_offload_param: bool = Field( + cpu_offload_param: Optional[bool] = Field( None, - deprecated=True, - new_param="offload_param", - new_param_fn=(lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu) if val else None), + json_schema_extra={ + "deprecated": True, + "new_param": "offload_param", + "new_param_fn": (lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu) + if val else None) + }, ) """ Deprecated, please use ``offload_param`` """ - cpu_offload_use_pin_memory: bool = Field( + cpu_offload_use_pin_memory: Optional[bool] = Field( None, - deprecated=True, - new_param="offload_param or offload_optimizer", - set_new_param=False, + json_schema_extra={ + "deprecated": True, + "new_param": "offload_param or offload_optimizer", + "set_new_param": False + }, ) """ Deprecated, please use ``offload_param`` or ``offload_optimizer`` """ - cpu_offload: bool = Field( + cpu_offload: Optional[bool] = Field( None, - deprecated=True, - new_param="offload_optimizer", - new_param_fn=(lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu) if val else None), + json_schema_extra={ + "deprecated": + True, + "new_param": + "offload_optimizer", + "new_param_fn": (lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu) + if val else None) + }, ) """ Deprecated, please use ``offload_optimizer`` """ @@ -242,8 +252,10 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): """ stage3_gather_fp16_weights_on_model_save: bool = Field(False, - deprecated=True, - new_param="gather_16bit_weights_on_model_save") + json_schema_extra={ + "deprecated": True, + "new_param": "gather_16bit_weights_on_model_save" + }) """ Deprecated, please use ``gather_16bit_weights_on_model_save`` """ ignore_unused_parameters: bool = True @@ -309,16 +321,15 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): """ # Validators - @validator("overlap_comm") - def overlap_comm_valid(cls, field_value, values): - if field_value is None: - assert ("stage" in values), "DeepSpeedZeroConfig: 'stage' must be defined before 'overlap_comm'" - field_value = values["stage"] == ZeroStageEnum.weights - return field_value - - @root_validator - def offload_ratio_check(cls, values): - offload_config = getattr(values, "offload_optimizer", {}) + @model_validator(mode="after") + def overlap_comm_valid(self): + if 
self.overlap_comm is None: + self.overlap_comm = self.stage == ZeroStageEnum.weights + return self + + @model_validator(mode="after") + def offload_ratio_check(self): + offload_config = self.offload_optimizer if offload_config and offload_config.ratio < 1.0: - assert values.get("stage") == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3." - return values + assert self.stage == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3." + return self diff --git a/deepspeed/runtime/zero/mics.py b/deepspeed/runtime/zero/mics.py index 99e3bce9ecd0..c9ae58a121de 100755 --- a/deepspeed/runtime/zero/mics.py +++ b/deepspeed/runtime/zero/mics.py @@ -17,7 +17,7 @@ from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload from deepspeed.runtime.zero.partition_parameters import Init, AllGatherCoalescedHandle, ZeroParamStatus from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 -from deepspeed.utils import instrument_w_nvtx, log_dist +from deepspeed.utils import instrument_w_nvtx, log_dist, logger from deepspeed.accelerator import get_accelerator from torch import Tensor from torch.nn import Parameter @@ -88,6 +88,8 @@ def __init__(self, if it was constructed in the context. data_parallel_group (``deepspeed.comm`` process group, optional): The group of processes to partition among. Defaults to all processes. + Synonymous with sequence data parallel group for param partitioning + across both sequence and data parallel groups. mem_efficient_linear (bool, optional): Replace torch.nn.functional.linear with an implementation that allows DeepSpeed to partition parameters. Defaults to ``True``. @@ -149,16 +151,19 @@ def __init__(self, dist.init_distributed() assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm" - if data_parallel_group is None and sequence_data_parallel_group is None: + if data_parallel_group is None: ds_process_group = dist.get_world_group() - elif sequence_data_parallel_group is not None: - ds_process_group = sequence_data_parallel_group - elif data_parallel_group is not None: + else: ds_process_group = data_parallel_group - else: # both given - raise ValueError( - "Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments." - ) + + if sequence_data_parallel_group is not None: + logger.warning( + f"sequence_data_parallel_group' is deprecated and will be removed. Use 'data_parallel_group' instead.") + if data_parallel_group is not None: + raise ValueError( + "Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments." + ) + self.ds_process_group = sequence_data_parallel_group self.mics_comm_groups = create_mics_comm_groups( _ds_config.mics_shard_size, diff --git a/deepspeed/runtime/zero/offload_config.py b/deepspeed/runtime/zero/offload_config.py index b7adc13a0ea2..74a5673bc1bc 100644 --- a/deepspeed/runtime/zero/offload_config.py +++ b/deepspeed/runtime/zero/offload_config.py @@ -5,7 +5,9 @@ from enum import Enum from pathlib import Path -from deepspeed.pydantic_v1 import Field, validator +from pydantic import Field, model_validator +from typing import Optional + from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int @@ -25,7 +27,7 @@ class DeepSpeedZeroOffloadParamConfig(DeepSpeedConfigModel): `nvme`. """ - nvme_path: Path = None + nvme_path: Optional[Path] = None """ Filesystem path for NVMe device for parameter offloading. 
""" buffer_count: int = Field(5, ge=0) @@ -56,7 +58,7 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel): `nvme`. Optimizer computation is offload to CPU regardless of device option. """ - nvme_path: Path = None + nvme_path: Optional[Path] = None """ Filesystem path for NVMe device for optimizer state offloading. """ buffer_count: int = Field(4, ge=0) @@ -88,10 +90,11 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel): fast_init: bool = False """ Enable fast optimizer initialization when offloading to NVMe. """ - @validator("pipeline_read", "pipeline_write", always=True) - def set_pipeline(cls, field_value, values): - values["pipeline"] = field_value or values.get("pipeline", False) - return field_value - ratio: float = Field(1.0, ge=0.0, le=1.0) """ Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3.""" + + @model_validator(mode="after") + def set_pipeline(self): + pipeline = self.pipeline_read or self.pipeline_write + self.__dict__["pipeline"] = pipeline + return self diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index f76bcf0eb781..e01925cbd32b 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -814,24 +814,22 @@ class Init(InsertPostInitMethodToModuleSubClasses): apply_param_persistence = False override_module_apply = get_config_default(DeepSpeedZeroConfig, "override_module_apply") - def __init__( - self, - module=None, - data_parallel_group=None, - mem_efficient_linear=True, - remote_device=None, - pin_memory=False, - config_dict_or_path=None, - config=None, - enabled=True, - dtype=None, - mpu=None, - zero_param_parallel_group=None, - zero_quantized_weights=False, - zero_quantized_nontrainable_weights=False, - sequence_data_parallel_group=None, - param_swapper=None, - ): + def __init__(self, + module=None, + data_parallel_group=None, + mem_efficient_linear=True, + remote_device=None, + pin_memory=False, + config_dict_or_path=None, + config=None, + enabled=True, + dtype=None, + mpu=None, + zero_param_parallel_group=None, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + sequence_data_parallel_group=None, + param_swapper=None): """A context to enable massive model construction for training with ZeRO-3. Models are automatically partitioned (or, sharded) across the system and converted to half precision. @@ -841,6 +839,8 @@ def __init__( if it was constructed in the context. data_parallel_group (``deepspeed.comm`` process group, optional): The group of processes to partition among. Defaults to all processes. + Synonymous with sequence data parallel group for param partitioning + across both sequence and data parallel groups. mem_efficient_linear (bool, optional): Replace torch.nn.functional.linear with an implementation that allows DeepSpeed to partition parameters. Defaults to ``True``. 
@@ -940,16 +940,19 @@ def __init__( init_distributed() assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm" - if data_parallel_group is None and sequence_data_parallel_group is None: + if data_parallel_group is None: self.ds_process_group = dist.get_world_group() - elif sequence_data_parallel_group is not None: - self.ds_process_group = sequence_data_parallel_group - elif data_parallel_group is not None: + else: self.ds_process_group = data_parallel_group - else: # both given - raise ValueError( - "Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments." - ) + + if sequence_data_parallel_group is not None: + logger.warning( + f"sequence_data_parallel_group' is deprecated and will be removed. Use 'data_parallel_group' instead.") + if data_parallel_group is not None: + raise ValueError( + "Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments." + ) + self.ds_process_group = sequence_data_parallel_group self.rank = dist.get_rank(group=self.ds_process_group) self.dp_world_size = dist.get_world_size(group=self.ds_process_group) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 57e80911d645..83cf996ca019 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -725,8 +725,9 @@ def reduce_gradients(self, pipeline_parallel=False): def get_first_param_index(self, group_id, param_group, partition_id): for index, param in enumerate(param_group): param_id = self.get_param_id(param) - if partition_id in self.param_to_partition_ids[group_id][param_id]: - return index + if group_id in self.param_to_partition_ids and param_id in self.param_to_partition_ids[group_id]: + if partition_id in self.param_to_partition_ids[group_id][param_id]: + return index return None def initialize_gradient_partitioning_data_structures(self): diff --git a/docs/_tutorials/accelerator-abstraction-interface.md b/docs/_tutorials/accelerator-abstraction-interface.md index d7c153638c0d..ec92d5ebcb93 100644 --- a/docs/_tutorials/accelerator-abstraction-interface.md +++ b/docs/_tutorials/accelerator-abstraction-interface.md @@ -81,6 +81,7 @@ torch.distributed.init_process_group(get_accelerator().communication_backend_nam [Accelerator Setup Guide](accelerator-setup-guide.md) provides a guide on how to setup different accelerators for DeepSpeed. It also comes with simple example how to run deepspeed for different accelerators. The following guides are provided: 1. Run DeepSpeed model on CPU 2. Run DeepSpeed model on XPU +3. Run DeepSpeed model on Huawei Ascend NPU # Implement new accelerator extension It is possible to implement a new DeepSpeed accelerator extension to support new accelerator in DeepSpeed. An example to follow is _[Intel Extension For DeepSpeed](https://github.com/intel/intel-extension-for-deepspeed/)_. 
An accelerator extension contains the following components: diff --git a/docs/_tutorials/accelerator-setup-guide.md b/docs/_tutorials/accelerator-setup-guide.md index cf2d01d2b25c..36d33f31b5b3 100644 --- a/docs/_tutorials/accelerator-setup-guide.md +++ b/docs/_tutorials/accelerator-setup-guide.md @@ -8,6 +8,7 @@ tags: getting-started - [Introduction](#introduction) - [Intel Architecture (IA) CPU](#intel-architecture-ia-cpu) - [Intel XPU](#intel-xpu) +- [Huawei Ascend NPU](#huawei-ascend-npu) # Introduction DeepSpeed supports different accelerators from different companies. Setup steps to run DeepSpeed on certain accelerators might be different. This guide allows user to lookup setup instructions for the accelerator family and hardware they are using. @@ -132,3 +133,115 @@ accelerator: xpu ## More example for using DeepSpeed on Intel XPU Refer to https://github.com/intel/intel-extension-for-pytorch/tree/release/xpu/2.1.40/examples/gpu/inference/python/llm for more extensive guide. + + +# Huawei Ascend NPU + +DeepSpeed has been verified on the following Huawei Ascend NPU products: +* Atlas 300T A2 + +## Installation steps for Huawei Ascend NPU + +The following steps outline the process for installing DeepSpeed on a Huawei Ascend NPU: +1. Install the Huawei Ascend NPU Driver and Firmware +
+ Click to expand + + Before proceeding with the installation, please download the necessary files from [Huawei Ascend NPU Driver and Firmware](https://www.hiascend.com/en/hardware/firmware-drivers/commercial?product=4&model=11). + + The instructions below are sourced from the [Ascend Community](https://www.hiascend.com/document/detail/en/canncommercial/700/quickstart/quickstart/quickstart_18_0002.html) (refer to the [Chinese version](https://www.hiascend.com/document/detail/zh/canncommercial/700/quickstart/quickstart/quickstart_18_0002.html)): + + - Execute the following command to install the driver: + ``` + ./Ascend-hdk--npu-driver_x.x.x_linux-{arch}.run --full --install-for-all + ``` + + - Execute the following command to install the firmware: + ``` + ./Ascend-hdk--npu-firmware_x.x.x.x.X.run --full + ``` +
+ +2. Install CANN +
+ Click to expand + + Prior to installation, download the [CANN Toolkit](https://www.hiascend.com/en/software/cann/commercial). + + - Install third-party dependencies. + - Ubuntu (The operations are the same for Debian, UOS20, and Linux.) + ``` + apt-get install -y gcc g++ make cmake zlib1g zlib1g-dev openssl libsqlite3-dev libssl-dev libffi-dev unzip pciutils net-tools libblas-dev gfortran libblas3 + ``` + - openEuler (The operations are the same for EulerOS, CentOS, and BC-Linux.) + ``` + yum install -y gcc gcc-c++ make cmake unzip zlib-devel libffi-devel openssl-devel pciutils net-tools sqlite-devel lapack-devel gcc-gfortran + ``` + - Install the required Python dependencies: + ``` + pip3 install attrs numpy decorator sympy cffi pyyaml pathlib2 psutil protobuf scipy requests absl-py wheel typing_extensions + ``` + - Install the CANN Toolkit. + ``` + ./Ascend-cann-toolkit_x.x.x_linux-{arch}.run --install + ``` +
+ +3. Install PyTorch \ + `pip install torch torch_npu` + +4. Install DeepSpeed \ + `pip install deepspeed` + +You can view the installation results using the `ds_report` command, Here is an example: +``` +-------------------------------------------------- +DeepSpeed C++/CUDA extension op report +-------------------------------------------------- +NOTE: Ops not installed will be just-in-time (JIT) compiled at + runtime if needed. Op compatibility means that your system + meet the required dependencies to JIT install the op. +-------------------------------------------------- +JIT compiled ops requires ninja +ninja .................. [OKAY] +-------------------------------------------------- +op name ................ installed .. compatible +-------------------------------------------------- +deepspeed_not_implemented [NO] ....... [OKAY] +async_io ............... [NO] ....... [OKAY] +cpu_adagrad ............ [NO] ....... [OKAY] +cpu_adam ............... [NO] ....... [OKAY] +cpu_lion ............... [NO] ....... [OKAY] +fused_adam ............. [NO] ....... [OKAY] +transformer_inference .. [NO] ....... [OKAY] +-------------------------------------------------- +DeepSpeed general environment info: +torch install path ............... ['/root/miniconda3/envs/ds/lib/python3.10/site-packages/torch'] +torch version .................... 2.2.0 +deepspeed install path ........... ['/root/miniconda3/envs/ds/lib/python3.10/site-packages/deepspeed'] +deepspeed info ................... 0.14.4, unknown, unknown +deepspeed wheel compiled w. ...... torch 2.2 +torch_npu install path ........... ['/root/miniconda3/envs/ds/lib/python3.10/site-packages/torch_npu'] +torch_npu version ................ 2.2.0 +ascend_cann version .............. 8.0.RC2.alpha002 +shared memory (/dev/shm) size .... 20.00 GB +``` + +## How to launch DeepSpeed on Huawei Ascend NPU + +To validate the Huawei Ascend NPU availability and if the accelerator is correctly chosen, here is an example(Huawei Ascend NPU detection is automatic starting with DeepSpeed v0.12.6): +``` +>>> import torch +>>> print('torch:',torch.__version__) +torch: 2.2.0 +>>> import torch_npu +>>> print('torch_npu:',torch.npu.is_available(),",version:",torch_npu.__version__) +torch_npu: True ,version: 2.2.0 +>>> from deepspeed.accelerator import get_accelerator +>>> print('accelerator:', get_accelerator()._name) +accelerator: npu +``` + +## Multi-card parallel training using Huawei Ascend NPU + +To perform model training across multiple Huawei Ascend NPU cards using DeepSpeed, see the examples provided in [DeepSpeed Examples](https://github.com/microsoft/DeepSpeedExamples/blob/master/training/cifar/cifar10_deepspeed.py). diff --git a/docs/_tutorials/getting-started.md b/docs/_tutorials/getting-started.md index f9a4cfdc68b4..ce9e3ee9a892 100644 --- a/docs/_tutorials/getting-started.md +++ b/docs/_tutorials/getting-started.md @@ -11,6 +11,7 @@ tags: getting-started * To get started with DeepSpeed on AzureML, please see the [AzureML Examples GitHub](https://github.com/Azure/azureml-examples/tree/main/cli/jobs/deepspeed) * DeepSpeed has direct integrations with [HuggingFace Transformers](https://github.com/huggingface/transformers) and [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning). HuggingFace Transformers users can now easily accelerate their models with DeepSpeed through a simple ``--deepspeed`` flag + config file [See more details](https://huggingface.co/docs/transformers/main_classes/deepspeed). 
PyTorch Lightning provides easy access to DeepSpeed through the Lightning Trainer [See more details](https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html?highlight=deepspeed#deepspeed). * DeepSpeed on AMD can be used via our [ROCm images](https://hub.docker.com/r/deepspeed/rocm501/tags), e.g., `docker pull deepspeed/rocm501:ds060_pytorch110`. +* DeepSpeed also supports Intel Xeon CPU, Intel Data Center Max Series XPU, Intel Gaudi HPU, Huawei Ascend NPU, etc.; please refer to the [accelerator setup guide](/tutorials/accelerator-setup-guide/). @@ -226,6 +227,36 @@ deepspeed --include="worker-2:0,1" \ <client_entry.py> <client args> \ --deepspeed --deepspeed_config ds_config.json ``` +### Launching without passwordless SSH + +DeepSpeed now supports launching training jobs without the need for passwordless SSH. This mode is +particularly useful in cloud environments such as Kubernetes, where flexible container orchestration +is possible, and setting up a leader-worker architecture with passwordless SSH adds unnecessary +complexity. + +To use this mode, you need to run the DeepSpeed command separately on all nodes. The command should +be structured as follows: + +```bash +deepspeed --hostfile=myhostfile --no_ssh --node_rank=<n> \ + --master_addr=<addr> --master_port=<port> \ + <client_entry.py> <client args> \ + --deepspeed --deepspeed_config ds_config.json +``` + +- `--hostfile=myhostfile`: Specifies the hostfile that contains information about the nodes and GPUs. +- `--no_ssh`: Enables the no-SSH mode. +- `--node_rank=<n>`: Specifies the rank of the node. This should be a unique integer from 0 to n - 1. +- `--master_addr=<addr>`: The address of the leader node (rank 0). +- `--master_port=<port>`: The port of the leader node. + +In this setup, the hostnames in the hostfile do not need to be reachable via passwordless SSH. +However, the hostfile is still required for the launcher to collect information about the environment, +such as the number of nodes and the number of GPUs per node. + +Each node must be launched with a unique `node_rank`, and all nodes must be provided with the address +and port of the leader node (rank 0). This mode causes the launcher to act similarly to the `torchrun` +launcher, as described in the [PyTorch documentation](https://pytorch.org/docs/stable/elastic/run.html). 
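To make the per-node invocation concrete, here is a small illustrative helper (not part of DeepSpeed) that each node could run to start the launcher with its own rank. The hostfile name, leader address, port, and entry-script name are placeholder assumptions:

```python
import subprocess
import sys

def launch_without_ssh(node_rank: int,
                       master_addr: str = "10.0.0.1",   # assumed leader address
                       master_port: int = 29500) -> None:
    """Run the DeepSpeed launcher on this node in --no_ssh mode.

    The same command runs on every node; only --node_rank differs.
    """
    cmd = [
        "deepspeed", "--hostfile=myhostfile", "--no_ssh",
        f"--node_rank={node_rank}",
        f"--master_addr={master_addr}", f"--master_port={master_port}",
        "train.py",                      # placeholder client entry point
        "--deepspeed", "--deepspeed_config", "ds_config.json",
    ]
    subprocess.run(cmd, check=True)

if __name__ == "__main__":
    # e.g. `python launch_node.py 0` on the leader and `python launch_node.py 1` on the worker
    launch_without_ssh(int(sys.argv[1]))
```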
## Multi-Node Environment Variables diff --git a/op_builder/evoformer_attn.py b/op_builder/evoformer_attn.py index af3aa7429775..7f68ccf87290 100644 --- a/op_builder/evoformer_attn.py +++ b/op_builder/evoformer_attn.py @@ -52,11 +52,27 @@ def is_compatible(self, verbose=False): if verbose: self.warning("Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH") return False - with open(f'{self.cutlass_path}/CHANGELOG.md', 'r') as f: - if '3.1.0' not in f.read(): + if os.path.exists(f'{self.cutlass_path}/CHANGELOG.md'): + with open(f'{self.cutlass_path}/CHANGELOG.md', 'r') as f: + if '3.1.0' not in f.read(): + if verbose: + self.warning("Please use CUTLASS version >= 3.1.0") + return False + else: + # pip install nvidia-cutlass package + try: + import cutlass + except ImportError: + if verbose: + self.warning("Please pip install nvidia-cutlass if trying to pre-compile kernels") + return False + cutlass_major, cutlass_minor = cutlass.__version__.split('.')[:2] + cutlass_compatible = (int(cutlass_major) >= 3 and int(cutlass_minor) >= 1) + if not cutlass_compatible: if verbose: self.warning("Please use CUTLASS version >= 3.1.0") return False + cuda_okay = True if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda sys_cuda_major, _ = installed_cuda_version() diff --git a/op_builder/gds.py b/op_builder/gds.py index e024674e01d8..01c2d5a245d1 100644 --- a/op_builder/gds.py +++ b/op_builder/gds.py @@ -36,7 +36,13 @@ def extra_ldflags(self): return super().extra_ldflags() + ['-lcufile'] def is_compatible(self, verbose=False): - import torch.utils.cpp_extension + try: + import torch.utils.cpp_extension + except ImportError: + if verbose: + self.warning("Please install torch if trying to pre-compile GDS") + return False + CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME CUDA_LIB64 = os.path.join(CUDA_HOME, "lib64") gds_compatible = self.has_function(funcname="cuFileDriverOpen", diff --git a/op_builder/xpu/inference.py b/op_builder/xpu/inference.py index 9114dcc2c315..a9ac4f84c2ca 100644 --- a/op_builder/xpu/inference.py +++ b/op_builder/xpu/inference.py @@ -30,7 +30,10 @@ def cxx_args(self): def load(self): try: - import intel_extension_for_pytorch.deepspeed - return intel_extension_for_pytorch.deepspeed.transformer_inference.transformer_inference + import intel_extension_for_pytorch + if hasattr(intel_extension_for_pytorch, "deepspeed"): + return intel_extension_for_pytorch.deepspeed.transformer_inference.transformer_inference + else: + return intel_extension_for_pytorch.xpu.deepspeed except ImportError: raise ImportError("Please install intel-extension-for-pytorch >= 2.1.30 to include DeepSpeed kernels.") diff --git a/requirements/requirements-readthedocs.txt b/requirements/requirements-readthedocs.txt index 1a2ad18611e7..a48a47e4428d 100644 --- a/requirements/requirements-readthedocs.txt +++ b/requirements/requirements-readthedocs.txt @@ -1,10 +1,10 @@ -autodoc_pydantic +autodoc_pydantic>=2.0.0 docutils<0.18 hjson packaging psutil py-cpuinfo -pydantic<2.0.0 +pydantic>=2.0.0 recommonmark sphinx_rtd_theme torch diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 6840d6dbcc98..70c94a745435 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -4,6 +4,6 @@ numpy packaging>=20.0 psutil py-cpuinfo -pydantic +pydantic>=2.0.0 torch tqdm diff --git a/tests/unit/alexnet_model.py b/tests/unit/alexnet_model.py index dfab28aa7477..51e80e7f9e62 100644 --- a/tests/unit/alexnet_model.py +++ 
b/tests/unit/alexnet_model.py @@ -101,12 +101,14 @@ def cifar_trainset(fp16=False): dist.barrier() if local_rank != 0: dist.barrier() - data_root = os.getenv("TEST_DATA_DIR", "/tmp/") - trainset = torchvision.datasets.CIFAR10(root=os.path.join(data_root, "cifar10-data"), - train=True, - download=True, - transform=transform) + if os.getenv("CIFAR10_DATASET_PATH"): + data_root = os.getenv("CIFAR10_DATASET_PATH") + download = False + else: + data_root = os.path.join(os.getenv("TEST_DATA_DIR", "/tmp"), "cifar10-data") + download = True + trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=download, transform=transform) if local_rank == 0: dist.barrier() return trainset diff --git a/tests/unit/inference/test_checkpoint_sharding.py b/tests/unit/inference/test_checkpoint_sharding.py index 5bae9a151a27..f1e37ee26536 100644 --- a/tests/unit/inference/test_checkpoint_sharding.py +++ b/tests/unit/inference/test_checkpoint_sharding.py @@ -14,6 +14,7 @@ from huggingface_hub import snapshot_download from transformers.utils import is_offline_mode from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.accelerator import get_accelerator if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("This op had not been implemented on this system.", allow_module_level=True) @@ -44,6 +45,8 @@ def model_name(request): @pytest.fixture(params=[torch.float16, torch.int8], ids=["fp16", "int8"]) def dtype(request): + if request.param not in get_accelerator().supported_dtypes(): + pytest.skip(f"{request.param} not supported by {get_accelerator().device_name()}.") return request.param diff --git a/tests/unit/inference/test_model_profiling.py b/tests/unit/inference/test_model_profiling.py index 23e49f89025b..319055d0ea55 100644 --- a/tests/unit/inference/test_model_profiling.py +++ b/tests/unit/inference/test_model_profiling.py @@ -16,6 +16,9 @@ if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("This op had not been implemented on this system.", allow_module_level=True) +if torch.half not in get_accelerator().supported_dtypes(): + pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) + @pytest.mark.inference @pytest.mark.parametrize("use_cuda_events", [True, False]) diff --git a/tests/unit/inference/v2/ragged/test_manager_configs.py b/tests/unit/inference/v2/ragged/test_manager_configs.py index a5f270cced8c..bdd513445ddb 100644 --- a/tests/unit/inference/v2/ragged/test_manager_configs.py +++ b/tests/unit/inference/v2/ragged/test_manager_configs.py @@ -5,7 +5,7 @@ import pytest -from deepspeed.pydantic_v1 import ValidationError +from pydantic import ValidationError from deepspeed.inference.v2.ragged import DSStateManagerConfig diff --git a/tests/unit/ops/transformer/inference/inference_test_utils.py b/tests/unit/ops/transformer/inference/inference_test_utils.py index 9c7b428c0e68..9cfcae809f09 100644 --- a/tests/unit/ops/transformer/inference/inference_test_utils.py +++ b/tests/unit/ops/transformer/inference/inference_test_utils.py @@ -26,12 +26,7 @@ def get_tolerances(): def get_dtypes(): global DTYPES if DTYPES is None: - DTYPES = [torch.float16, torch.float32] - try: - if get_accelerator().is_bf16_supported(): - DTYPES.append(torch.bfloat16) - except (AssertionError, AttributeError): - pass + DTYPES = get_accelerator().supported_dtypes() return DTYPES diff --git a/tests/unit/runtime/test_ds_config_dict.py b/tests/unit/runtime/test_ds_config_dict.py index 
c11c63d04867..d06b35e208fe 100644 --- a/tests/unit/runtime/test_ds_config_dict.py +++ b/tests/unit/runtime/test_ds_config_dict.py @@ -67,13 +67,11 @@ def _batch_assert(status, ds_config, batch, micro_batch, gas, success): if not success: assert not status - print("Failed but All is well") return assert ds_config.train_batch_size == batch assert ds_config.train_micro_batch_size_per_gpu == micro_batch assert ds_config.gradient_accumulation_steps == gas - print("All is well") #Tests different batch config provided in deepspeed json file diff --git a/tests/unit/runtime/test_ds_config_model.py b/tests/unit/runtime/test_ds_config_model.py index 87ea747cf423..4d184b2858a8 100644 --- a/tests/unit/runtime/test_ds_config_model.py +++ b/tests/unit/runtime/test_ds_config_model.py @@ -4,18 +4,25 @@ # DeepSpeed Team import pytest -import os import json -from typing import List -from deepspeed.pydantic_v1 import Field, ValidationError +import os +from typing import List, Optional + +from pydantic import Field, ValidationError + from deepspeed.runtime import config as ds_config from deepspeed.runtime.config_utils import DeepSpeedConfigModel class SimpleConf(DeepSpeedConfigModel): param_1: int = 0 - param_2_old: str = Field(None, deprecated=True, new_param="param_2", new_param_fn=(lambda x: [x])) - param_2: List[str] = None + param_2_old: Optional[str] = Field(None, + json_schema_extra={ + "deprecated": True, + "new_param": "param_2", + "new_param_fn": (lambda x: [x]) + }) + param_2: Optional[List[str]] = None param_3: int = Field(0, alias="param_3_alias") diff --git a/version.txt b/version.txt index 226468ee5b2e..e815b861f023 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.14.6 +0.15.1
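For readers tracking the pydantic 2 migration in `test_ds_config_model.py`: `Field()` no longer accepts arbitrary keyword arguments, so DeepSpeed-specific deprecation metadata moves into `json_schema_extra`, and implicitly optional fields become explicit `Optional[...]`. The sketch below mirrors that field layout on a plain `pydantic.BaseModel` so it runs without DeepSpeed installed; the class name is invented for the example.

```python
# Sketch of the pydantic>=2 field style adopted by the test migration above.
from typing import List, Optional
from pydantic import BaseModel, Field

class SimpleConfSketch(BaseModel):
    param_1: int = 0
    # Library-specific metadata now lives under json_schema_extra instead of
    # being passed as free-form keyword arguments to Field().
    param_2_old: Optional[str] = Field(None,
                                       json_schema_extra={
                                           "deprecated": True,
                                           "new_param": "param_2",
                                           "new_param_fn": (lambda x: [x]),
                                       })
    param_2: Optional[List[str]] = None
    param_3: int = Field(0, alias="param_3_alias")

print(SimpleConfSketch(param_3_alias=7).param_3)  # aliases still validate: prints 7
```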