diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 6b63efbb23f7..ba97e336a2bb 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -36,6 +36,7 @@ class BF16_Optimizer(ZeROOptimizer): def __init__(self, init_optimizer, param_names, + bfloat16_config, mpu=None, clip_grad=0.0, norm_type=2, @@ -44,7 +45,6 @@ def __init__(self, timers=None, grad_acc_dtype=None, graph_harvesting=False, - immediate_grad_update=False, has_moe_layers=False): super().__init__() see_memory_usage('begin bf16_optimizer', force=True) @@ -53,10 +53,12 @@ def __init__(self, self.param_names = param_names self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim) + assert bfloat16_config.enabled, f"BF16Optimizer: requires bfloat16 to be enabled" assert grad_acc_dtype in [torch.float32, torch.bfloat16 ], f"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}" self.grad_acc_dtype = grad_acc_dtype - self.immediate_grad_update = immediate_grad_update + + self.immediate_grad_update = bfloat16_config.immediate_grad_update self.clip_grad = clip_grad self.norm_type = norm_type diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index b6dabc161e8c..7ae85af46656 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -14,13 +14,13 @@ import base64 from .constants import * -from .fp16.loss_scaler import ( - INITIAL_LOSS_SCALE, - SCALE_WINDOW, - DELAYED_SHIFT, - CONSECUTIVE_HYSTERESIS, - MIN_LOSS_SCALE, -) +# from .fp16.loss_scaler import ( +# INITIAL_LOSS_SCALE, +# SCALE_WINDOW, +# DELAYED_SHIFT, +# CONSECUTIVE_HYSTERESIS, +# MIN_LOSS_SCALE, +# ) from .config_utils import ( get_scalar_param, dict_raise_error_on_duplicate_keys, @@ -31,6 +31,7 @@ from ..comm.config import DeepSpeedCommsConfig from ..monitor.config import get_monitor_config from ..inference.config import WeightQuantConfig +from .precision_config import get_bfloat16_config, get_float16_config from deepspeed import comm as dist from deepspeed.runtime.config_utils import DeepSpeedConfigModel @@ -157,88 +158,64 @@ def get_amp_params(param_dict): return False -def get_fp16_enabled(param_dict): - if FP16 in param_dict.keys(): - return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT) - else: - return False - - -def get_bfloat16_enabled(param_dict): - for key in [BFLOAT16, BFLOAT16_OLD]: - if key in param_dict.keys(): - return get_scalar_param(param_dict[key], BFLOAT16_ENABLED, BFLOAT16_ENABLED_DEFAULT) - return False - - -def get_bfloat16_immediate_grad_update(param_dict): - for key in [BFLOAT16, BFLOAT16_OLD]: - if key in param_dict.keys(): - return get_scalar_param(param_dict[key], BFLOAT16_IMMEDIATE_GRAD_UPDATE, - BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT) - return False - - -def get_fp16_master_weights_and_grads_enabled(param_dict): - if get_fp16_enabled(param_dict): - return get_scalar_param(param_dict[FP16], FP16_MASTER_WEIGHTS_AND_GRADS, FP16_MASTER_WEIGHTS_AND_GRADS_DEFAULT) - else: - return False - - -def get_fp16_auto_cast(param_dict): - if get_fp16_enabled(param_dict): - return get_scalar_param(param_dict[FP16], FP16_AUTO_CAST, FP16_AUTO_CAST_DEFAULT) - - -def get_loss_scale(param_dict): - if get_fp16_enabled(param_dict): - return get_scalar_param(param_dict[FP16], FP16_LOSS_SCALE, FP16_LOSS_SCALE_DEFAULT) - elif get_bfloat16_enabled(param_dict): - return 1.0 - else: - return FP16_LOSS_SCALE_DEFAULT - - -def get_initial_dynamic_scale(param_dict): - if get_fp16_enabled(param_dict): - 
initial_scale_power = get_scalar_param(param_dict[FP16], FP16_INITIAL_SCALE_POWER, - FP16_INITIAL_SCALE_POWER_DEFAULT) - elif get_bfloat16_enabled(param_dict): - initial_scale_power = 0 - else: - initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT - - return 2**initial_scale_power - - -def get_dynamic_loss_scale_args(param_dict): - loss_scale_args = None - if get_fp16_enabled(param_dict): - fp16_dict = param_dict[FP16] - dynamic_loss_args = [ - FP16_INITIAL_SCALE_POWER, - FP16_LOSS_SCALE_WINDOW, - FP16_MIN_LOSS_SCALE, - FP16_HYSTERESIS, - FP16_CONSECUTIVE_HYSTERESIS, - ] - if any(arg in list(fp16_dict.keys()) for arg in dynamic_loss_args): - init_scale = get_scalar_param(fp16_dict, FP16_INITIAL_SCALE_POWER, FP16_INITIAL_SCALE_POWER_DEFAULT) - scale_window = get_scalar_param(fp16_dict, FP16_LOSS_SCALE_WINDOW, FP16_LOSS_SCALE_WINDOW_DEFAULT) - delayed_shift = get_scalar_param(fp16_dict, FP16_HYSTERESIS, FP16_HYSTERESIS_DEFAULT) - consecutive_hysteresis = get_scalar_param(fp16_dict, FP16_CONSECUTIVE_HYSTERESIS, - FP16_CONSECUTIVE_HYSTERESIS_DEFAULT) - min_loss_scale = get_scalar_param(fp16_dict, FP16_MIN_LOSS_SCALE, FP16_MIN_LOSS_SCALE_DEFAULT) - loss_scale_args = { - INITIAL_LOSS_SCALE: 2**init_scale, - SCALE_WINDOW: scale_window, - DELAYED_SHIFT: delayed_shift, - CONSECUTIVE_HYSTERESIS: consecutive_hysteresis, - MIN_LOSS_SCALE: min_loss_scale, - } - - return loss_scale_args +# def get_fp16_enabled(param_dict): +# if FP16 in param_dict.keys(): +# return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT) +# else: +# return False + +# def get_fp16_master_weights_and_grads_enabled(param_dict): +# if get_fp16_enabled(param_dict): +# return get_scalar_param(param_dict[FP16], FP16_MASTER_WEIGHTS_AND_GRADS, FP16_MASTER_WEIGHTS_AND_GRADS_DEFAULT) +# else: +# return False + +# def get_fp16_auto_cast(param_dict): +# if get_fp16_enabled(param_dict): +# return get_scalar_param(param_dict[FP16], FP16_AUTO_CAST, FP16_AUTO_CAST_DEFAULT) + +# def get_loss_scale(param_dict): +# if get_fp16_enabled(param_dict): +# return get_scalar_param(param_dict[FP16], FP16_LOSS_SCALE, FP16_LOSS_SCALE_DEFAULT) +# else: +# return FP16_LOSS_SCALE_DEFAULT + +# def get_initial_dynamic_scale(param_dict): +# if get_fp16_enabled(param_dict): +# initial_scale_power = get_scalar_param(param_dict[FP16], FP16_INITIAL_SCALE_POWER, +# FP16_INITIAL_SCALE_POWER_DEFAULT) +# else: +# initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT + +# return 2**initial_scale_power + +# def get_dynamic_loss_scale_args(param_dict): +# loss_scale_args = None +# if get_fp16_enabled(param_dict): +# fp16_dict = param_dict[FP16] +# dynamic_loss_args = [ +# FP16_INITIAL_SCALE_POWER, +# FP16_LOSS_SCALE_WINDOW, +# FP16_MIN_LOSS_SCALE, +# FP16_HYSTERESIS, +# FP16_CONSECUTIVE_HYSTERESIS, +# ] +# if any(arg in list(fp16_dict.keys()) for arg in dynamic_loss_args): +# init_scale = get_scalar_param(fp16_dict, FP16_INITIAL_SCALE_POWER, FP16_INITIAL_SCALE_POWER_DEFAULT) +# scale_window = get_scalar_param(fp16_dict, FP16_LOSS_SCALE_WINDOW, FP16_LOSS_SCALE_WINDOW_DEFAULT) +# delayed_shift = get_scalar_param(fp16_dict, FP16_HYSTERESIS, FP16_HYSTERESIS_DEFAULT) +# consecutive_hysteresis = get_scalar_param(fp16_dict, FP16_CONSECUTIVE_HYSTERESIS, +# FP16_CONSECUTIVE_HYSTERESIS_DEFAULT) +# min_loss_scale = get_scalar_param(fp16_dict, FP16_MIN_LOSS_SCALE, FP16_MIN_LOSS_SCALE_DEFAULT) +# loss_scale_args = { +# INITIAL_LOSS_SCALE: 2**init_scale, +# SCALE_WINDOW: scale_window, +# DELAYED_SHIFT: delayed_shift, +# CONSECUTIVE_HYSTERESIS: 
consecutive_hysteresis, +# MIN_LOSS_SCALE: min_loss_scale, +# } + +# return loss_scale_args def get_gradient_accumulation_steps(param_dict): @@ -827,18 +804,19 @@ def _initialize_params(self, param_dict): self.monitor_config = get_monitor_config(param_dict) self.gradient_clipping = get_gradient_clipping(param_dict) - self.fp16_enabled = get_fp16_enabled(param_dict) - self.fp16_auto_cast = get_fp16_auto_cast(param_dict) - self.bfloat16_enabled = get_bfloat16_enabled(param_dict) - self.bfloat16_immediate_grad_update = get_bfloat16_immediate_grad_update(param_dict) - assert not (self.fp16_enabled - and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' - self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(param_dict) + # self.fp16_enabled = get_fp16_enabled(param_dict) + # self.fp16_auto_cast = get_fp16_auto_cast(param_dict) + self.float16_config = get_float16_config(param_dict) + self.bfloat16_config = get_bfloat16_config(param_dict) + assert not (self.float16_config.enabled + and self.bfloat16_config.enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' + # self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(param_dict) + # self.loss_scale = get_loss_scale(param_dict) + # self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict) + # self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict) + self.amp_enabled = get_amp_enabled(param_dict) self.amp_params = get_amp_params(param_dict) - self.loss_scale = get_loss_scale(param_dict) - self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict) - self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict) self.compression_config = get_compression_config(param_dict) self.graph_harvesting = get_graph_harvesting(param_dict) @@ -1018,11 +996,11 @@ def _do_error_check(self): <= ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format( ZeroStageEnum.max_stage) - if self.fp16_master_weights_and_gradients: + if self.float16_config.fp16_master_weights_and_grads: assert self.zero_enabled and self.zero_optimization_stage == ZeroStageEnum.gradients, "Fp16_master_weights_and_grads is only supported with ZeRO Stage 2 for now." 
def _do_warning_check(self): - fp16_enabled = self.fp16_enabled + fp16_enabled = self.float16_config.enabled vocabulary_size = self._param_dict.get(VOCABULARY_SIZE, VOCABULARY_SIZE_DEFAULT) if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0: diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 55cfa8f59c91..82252c3e713a 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -117,7 +117,9 @@ BFLOAT16_FORMAT = ''' BFLOAT16 parameters should be of the format: "bf16": { - "enabled": true + "enabled": true, + "immediate_grad_update": false, + "check_overflow": false } ''' BFLOAT16 = "bf16" @@ -126,6 +128,9 @@ BFLOAT16_ENABLED = "enabled" BFLOAT16_ENABLED_DEFAULT = False +CHECK_OVERFLOW = "check_overflow" +BFLOAT16_CHECK_OVERFLOW_DEFAULT = False + # BFLOAT16 optimizer immediate gradient update BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update" BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = False diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 986b68dc1bb1..c96b47df5d18 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -906,13 +906,13 @@ def graph_harvesting(self): return self._config.graph_harvesting def fp16_enabled(self): - return self._config.fp16_enabled + return self._config.float16_config.enabled def bfloat16_enabled(self): - return self._config.bfloat16_enabled + return self._config.bfloat16_config.enabled def fp16_master_weights_and_gradients(self): - return self._config.fp16_master_weights_and_gradients + return self._config.float16_config.fp16_master_weights_and_grads def amp_enabled(self): return self._config.amp_enabled @@ -921,10 +921,10 @@ def amp_params(self): return self._config.amp_params def fp16_auto_cast(self): - return self._config.fp16_auto_cast + return self._config.float16_config.auto_cast def loss_scale(self): - return self._config.loss_scale + return self._config.float16_config.loss_scale def gradient_accumulation_steps(self): return self._config.gradient_accumulation_steps @@ -990,13 +990,13 @@ def gradient_clipping(self): return self._config.gradient_clipping def dynamic_loss_scale(self): - return self._config.loss_scale == 0 + return self._config.float16_config.loss_scale == 0 def initial_dynamic_scale(self): - return self._config.initial_dynamic_scale + return self._config.float16_config.initial_dynamic_scale() def dynamic_loss_scale_args(self): - return self._config.dynamic_loss_scale_args + return self._config.float16_config.dynamic_loss_scale_args() def swap_tensor_config(self): return self._config.swap_tensor_config @@ -1597,6 +1597,7 @@ def _configure_bf16_optimizer(self, optimizer): timers = self.timers if self.wall_clock_breakdown() else NoopTimer() optimizer = BF16_Optimizer(optimizer, self.param_names, + bfloat16_config=self._config.bfloat16_config, mpu=self.mpu, clip_grad=clip_grad, allgather_bucket_size=self.zero_allgather_bucket_size(), @@ -1604,7 +1605,6 @@ def _configure_bf16_optimizer(self, optimizer): timers=timers, grad_acc_dtype=self.get_data_types()[1], graph_harvesting=self.graph_harvesting(), - immediate_grad_update=self._config.bfloat16_immediate_grad_update, has_moe_layers=self.has_moe_layers) return optimizer @@ -1615,6 +1615,13 @@ def _configure_zero_optimizer(self, optimizer): mics_shard_size = self.mics_shard_size() model_dtype, gradient_accumulation_dtype = self.get_data_types() + if self.bfloat16_enabled(): + check_grad_overflow = self._config.bfloat16_config.check_grad_overflow + elif self.fp16_enabled(): 
+ check_grad_overflow = True + else: + check_grad_overflow = False + timers = self.timers if self.wall_clock_breakdown() else NoopTimer() if optimizer is None: @@ -1666,7 +1673,8 @@ def _configure_zero_optimizer(self, optimizer): fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(), gradient_accumulation_dtype=gradient_accumulation_dtype, communication_data_type=self.communication_data_type, - elastic_checkpoint=self.zero_elastic_checkpoint()) + elastic_checkpoint=self.zero_elastic_checkpoint(), + check_grad_overflow=check_grad_overflow) elif zero_stage == ZeroStageEnum.weights: assert not self.has_moe_layers, "MoE not supported with Stage 3" diff --git a/deepspeed/runtime/fp16/loss_scaler.py b/deepspeed/runtime/fp16/loss_scaler.py index 451451c51a32..579a779068b0 100755 --- a/deepspeed/runtime/fp16/loss_scaler.py +++ b/deepspeed/runtime/fp16/loss_scaler.py @@ -116,18 +116,17 @@ class DynamicLossScaler(LossScalerBase): """ def __init__(self, - init_scale=2**32, - scale_factor=2., - scale_window=1000, - min_scale=1, - delayed_shift=1, - consecutive_hysteresis=False, + init_scale, + scale_window, + min_scale, + delayed_shift, + consecutive_hysteresis, raise_error_at_min_scale=True, dtype=torch.half): super(DynamicLossScaler, self).__init__(init_scale) self.cur_iter = 0 self.last_overflow_iter = -1 - self.scale_factor = scale_factor + self.scale_factor = 2.0 self.scale_window = scale_window self.min_scale = min_scale self.delayed_shift = delayed_shift @@ -196,7 +195,9 @@ def update_scale(self, overflow): hysteresis_msg = f"Consecutive hysteresis is enabled. Restoring hysteresis to {self.delayed_shift}" logger.info(hysteresis_msg) self.cur_hysteresis = self.delayed_shift - if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: + + stable_interval = (self.cur_iter - self.last_overflow_iter) - 1 + if (stable_interval > 0) and (stable_interval % self.scale_window == 0): if not self.consecutive_hysteresis: self.cur_hysteresis = self.delayed_shift self.cur_scale *= self.scale_factor @@ -207,8 +208,7 @@ def update_scale(self, overflow): # we still create a scaler for other dtypes (fp32, bf16) which does not perform any scaling. def CreateLossScaler(dtype, static_loss_scale, dynamic_scaling, dynamic_loss_args): if dtype == torch.half and dynamic_scaling: - if dynamic_loss_args is None: - return DynamicLossScaler(dtype=dtype) + assert dynamic_loss_args is not None, f"Dynamic loss scaling parameters must be defined." 
return DynamicLossScaler(dtype=dtype, **dynamic_loss_args) loss_scale_value = static_loss_scale if dtype == torch.half else 1.0 diff --git a/deepspeed/runtime/hybrid_engine.py b/deepspeed/runtime/hybrid_engine.py index b6e417fd4764..b96228c90b02 100644 --- a/deepspeed/runtime/hybrid_engine.py +++ b/deepspeed/runtime/hybrid_engine.py @@ -78,9 +78,9 @@ def _replace_linear_layer(r_module, parent_type=None, prev_type=None): def new_inference_container(self, orig_layer, policy_cls, layer_id): policy = policy_cls(orig_layer, inference=True) - if self._config.fp16_enabled: + if self._config.float16_config.enabled: inference_dtype = torch.float16 - elif self._config.bfloat16_enabled: + elif self._config.bfloat16_config.enabled: inference_dtype = torch.bfloat16 else: inference_dtype = torch.float32 diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index deb44c2e71eb..81ad8f4dd14f 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -863,7 +863,7 @@ def _exec_backward_pass(self, buffer_id): if self.using_bf16_optimizer and not self.is_last_stage(): # manually call because we don't call optimizer.backward() - if not self._config.bfloat16_immediate_grad_update: + if not self._config.bfloat16_config.immediate_grad_update: self.optimizer.update_hp_grads(clear_lp_grads=False) # Free up the memory from the output of forward() diff --git a/deepspeed/runtime/precision_config.py b/deepspeed/runtime/precision_config.py new file mode 100644 index 000000000000..1c7fab8bd234 --- /dev/null +++ b/deepspeed/runtime/precision_config.py @@ -0,0 +1,146 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from .fp16.loss_scaler import ( + INITIAL_LOSS_SCALE, + SCALE_WINDOW, + DELAYED_SHIFT, + CONSECUTIVE_HYSTERESIS, + MIN_LOSS_SCALE, +) + +######################################### +# BFLOAT16 support +######################################### +# BFLOAT16 feature. By default, this feature is not enabled. +# Users can configure in ds_config.json as below example: +BFLOAT16_FORMAT = ''' +BFLOAT16 parameters should be of the format: +"bf16": { + "enabled": true, + "immediate_grad_update": false, + "check_grad_overflow": false +} +''' +BFLOAT16 = "bf16" +BFLOAT16_OLD = "bfloat16" # keeping for backwards compatibility + + +def get_bfloat16_config(param_dict): + bf16_config_dict = param_dict.get(BFLOAT16, None) + if bf16_config_dict is None: + bf16_config_dict = param_dict.get(BFLOAT16_OLD, {}) + return DeepSpeedBF16Config(**bf16_config_dict) + + +class DeepSpeedBF16Config(DeepSpeedConfigModel): + """ + For bfloat16 configuration + """ + + enabled: bool = False + """ + Enable bfloat16 mixed-precision training/inference + """ + + immediate_grad_update: bool = False + """ + Apply gradient updates immediately rather than delayed. + """ + + check_grad_overflow: bool = False + """ + Check for gradient overflows and underflows + """ + + +######################################### +# FP16 support +######################################### +# FP16 feature. By default, this feature is not enabled. 
+# Users can configure in ds_config.json as below example: +FP16_FORMAT = ''' +FP16 parameters should be of the format: +"fp16": { + "enabled": true, + "auto_cast": false, + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "consecutive_hysteresis": false, + "min_loss_scale": 1 +} +''' +FP16 = "fp16" + + +def get_float16_config(param_dict): + fp16_config_dict = param_dict.get(FP16, {}) + return DeepSpeedFP16Config(**fp16_config_dict) + + +class DeepSpeedFP16Config(DeepSpeedConfigModel): + """ + For float16 configuration + """ + + enabled: bool = False + """ + Enable fp16 mixed-precision training/inference + """ + + auto_cast: bool = False + """ + Automatically cast inputs to fp16 + """ + + loss_scale: float = 0 + """ + Loss scaling value. Default value of 0 means dynamic loss scaling instead of static loss scale. + """ + + initial_scale_power: int = 16 + """ + For dynamic loss scaling, set initial loss scale to 2^{initial_scale_power}. + """ + + loss_scale_window: int = 1000 + """ + Iteration intervals for raising/lowering dynamic loss scale value. + """ + + hysteresis: int = 2 + """ + Delay shift in dynamic loss scaling. + """ + + consecutive_hysteresis: bool = False + """ + Refill hysteresis if iteration does not overflow/underflow. + """ + + min_loss_scale: int = 1 + """ + Minimum dynamic loss scale value. + """ + + fp16_master_weights_and_grads: bool = False + """ + Maintain master weights in optimizer state as fp16 instead of fp32 (valid with DeepSpeedCPUAdam only). + """ + + def initial_dynamic_scale(self): + return 2**self.initial_scale_power + + def dynamic_loss_scale_args(self): + return { + INITIAL_LOSS_SCALE: 2**self.initial_scale_power, + SCALE_WINDOW: self.loss_scale_window, + DELAYED_SHIFT: self.hysteresis, + CONSECUTIVE_HYSTERESIS: self.consecutive_hysteresis, + MIN_LOSS_SCALE: self.min_loss_scale, + } diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index e8cb797b8a5b..b0d22ed19ccd 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -358,12 +358,12 @@ def _post_init_method(self, module): def _set_dtype(self, ds_config, dtype): if ds_config is not None and dtype is None: - if ds_config.bfloat16_enabled and ds_config.fp16_enabled: + if ds_config.bfloat16_config.enabled and ds_config.float16_config.enabled: raise RuntimeError("bfloat16 and fp16 cannot be enabled at once") - if ds_config.bfloat16_enabled: + if ds_config.bfloat16_config.enabled: self.dtype = torch.bfloat16 - elif ds_config.fp16_enabled: + elif ds_config.float16_config.enabled: self.dtype = torch.half else: self.dtype = torch.float diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 2bece09bffc4..292f8f927565 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -136,7 +136,8 @@ def __init__(self, round_robin_gradients=False, has_moe_layers=False, fp16_master_weights_and_gradients=False, - elastic_checkpoint=False): + elastic_checkpoint=False, + check_grad_overflow=True): if offload_optimizer_config is not None and offload_optimizer_config.device != OffloadDeviceEnum.none: self.cpu_offload = True @@ -155,6 +156,7 @@ def __init__(self, # 2. 
keep common stuff here in case we need to add ne552w fused optimizer later self.elastic_checkpoint = elastic_checkpoint + self.check_grad_overflow = check_grad_overflow self.param_names = param_names self.mpu = mpu # differences from apex.fp16_utils: @@ -557,6 +559,8 @@ def __init__(self, self._enable_universal_checkpoint() self._param_slice_mappings = self._create_param_mapping() + if self.cpu_offload: + self._create_optimizer_mapping() def destroy(self): for i, _ in enumerate(self.optimizer.param_groups): @@ -583,6 +587,12 @@ def _create_param_mapping(self): return param_mapping + def _create_optimizer_mapping(self): + for i, _ in enumerate(self.optimizer.param_groups): + for lp in self.bit16_groups[i]: + if lp._hp_mapping is not None: + lp._zero_optimizer = self + def _link_all_hp_params(self): if self.cpu_offload: self._get_offload_gradient_dict() @@ -1175,11 +1185,14 @@ def get_grad_position(self, group_id, tensor_list, first_offset, partition_size) ] current_offset += num_elements - def update_overflow_tracker_for_param_grad(self, param): - grad_accum = self.get_param_gradient_attribute(param) - if grad_accum is not None and self._has_inf_or_nan(grad_accum.data): + def update_offload_overflow_tracker(self, grad): + if grad is not None and self._has_inf_or_nan(grad.data): self.local_overflow = True + def update_offload_overflow_tracker_for_param_grad(self, param): + grad_accum = self.get_param_gradient_attribute(param) + self.update_offload_overflow_tracker(grad_accum) + def _get_offload_gradient_dict(self): for param_group_index, _ in enumerate(self.optimizer.param_groups): self.offload_gradient_dict[param_group_index] = [] @@ -1281,7 +1294,7 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param): src_tensor = src_tensor.float() dest_tensor.copy_(src_tensor, non_blocking=True) - param.grad = None #offload only + self.clear_grad_attribute(param) #offload only def complete_grad_norm_calculation_for_cpu_offload(self, params): total_norm = 0.0 @@ -1313,17 +1326,17 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): """ # Sum across all model parallel GPUs. - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group) + total_dev_norm = get_accelerator().FloatTensor([float(total_norm)]) + dist.all_reduce(total_dev_norm, op=dist.ReduceOp.SUM, group=self.dp_process_group) - self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) + self._model_parallel_all_reduce(tensor=total_dev_norm, op=dist.ReduceOp.SUM) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + total_norm = total_dev_norm[0].item()**(1. 
/ norm_type) if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 + total_norm = -1.0 - return total_norm + return torch.tensor(total_norm, device=self.device, dtype=torch.float) ############################################################################################ def copy_grads_in_partition(self, param): @@ -1335,7 +1348,7 @@ def copy_grads_in_partition(self, param): if self.is_gradient_accumulation_boundary: self.set_norm_for_param_grad_in_gpu(param) - self.update_overflow_tracker_for_param_grad(param) + self.update_offload_overflow_tracker_for_param_grad(param) self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param) @@ -1789,10 +1802,7 @@ def scaled_global_norm(self, norm_type=2): norm_groups = [] for i, group in enumerate(self.bit16_groups): if self.cpu_offload: - # complete complete_grad_norm_calculation_for_cpu_offload return python float, moving back to - # torch.tensor as else statement returns tensor as well - norm = torch.tensor(self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i]), - device=self.device) + norm = self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i]) norm_groups.append(norm) else: norm_groups.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i])) @@ -1832,8 +1842,8 @@ def step(self, closure=None): see_memory_usage(f"In step before checking overflow") # First compute norm for all group so we know if there is overflow - if self.dtype == torch.float16: - self.check_overflow() + if self.check_grad_overflow: + self.check_overflow(partition_gradients=self.partition_gradients) prev_scale = self.loss_scale self._update_scale(self.overflow) @@ -1843,7 +1853,8 @@ def step(self, closure=None): if self.cpu_offload: self.reset_cpu_buffers() else: - self.averaged_gradients = {} + for k in self.averaged_gradients.keys(): + self.averaged_gradients[k] = None see_memory_usage('After overflow after clearing gradients') @@ -2004,21 +2015,15 @@ def has_overflow_partitioned_grads_serial(self): return invalid_grad_count.bool() def has_overflow(self, partition_gradients=True): + overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial() + overflow_gpu = get_accelerator().ByteTensor([overflow]) if self.cpu_offload else overflow.byte().to( + get_accelerator().current_device_name()) + if partition_gradients: - overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial() - overflow_gpu = get_accelerator().ByteTensor([overflow]) if self.cpu_offload else overflow.byte().to( - get_accelerator().current_device_name()) '''This will capture overflow across all data parallel and expert parallel process Since expert parallel process are a subset of data parallel process''' dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group) - else: - params = [] - for group in self.bit16_groups: - for param in group: - params.append(param) - overflow_gpu = self.has_overflow_serial(params).byte().to(get_accelerator().current_device_name()) - # Since each model parallel GPU carries only part of the model, # make sure overflow flag is synced across all the model parallel GPUs self._model_parallel_all_reduce(tensor=overflow_gpu, op=dist.ReduceOp.MAX) diff --git a/deepspeed/utils/tensor_fragment.py b/deepspeed/utils/tensor_fragment.py index 053c8b5adad0..e97f32ac2bd8 100644 --- a/deepspeed/utils/tensor_fragment.py +++ b/deepspeed/utils/tensor_fragment.py @@ 
-127,6 +127,8 @@ def set_full_hp_grad(self, value): lp_frag_address = self._hp_mapping.lp_fragment_address value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel) lp_grad_fragment.data.copy_(value_fragment.data.reshape_as(lp_grad_fragment.data)) + if hasattr(self, '_zero_optimizer'): + self._zero_optimizer.update_offload_overflow_tracker(value) def safe_get_full_fp32_param(param): diff --git a/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py b/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py index 4b263172261c..ab104dda80dd 100644 --- a/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py +++ b/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py @@ -197,6 +197,9 @@ def test_no_overflow(self): def test_all_overflow(self): if not get_accelerator().is_fp16_supported(): pytest.skip("fp16 is not supported") + + min_loss_scale_value = 2.0 + config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -211,7 +214,7 @@ def test_all_overflow(self): "loss_scale": 0, "initial_scale_power": 4, "loss_scale_window": 2, - "min_loss_scale": 0.25 + "min_loss_scale": min_loss_scale_value } } hidden_dim = 1 @@ -219,7 +222,7 @@ def test_all_overflow(self): model, optim, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) expected_loss_scale = 2**4 - expected_min_loss_scale = 0.25 + expected_min_loss_scale = min_loss_scale_value # Ensure the dynamic loss scaler is correctly configured. assert optim.dynamic_loss_scale == True assert optim.cur_scale == expected_loss_scale diff --git a/tests/unit/runtime/half_precision/test_zero_optim_overflow.py b/tests/unit/runtime/half_precision/test_zero_optim_overflow.py new file mode 100644 index 000000000000..62995fbe104d --- /dev/null +++ b/tests/unit/runtime/half_precision/test_zero_optim_overflow.py @@ -0,0 +1,365 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import deepspeed +from deepspeed.accelerator import get_accelerator +import pytest +import numpy as np +from unit.common import DistributedTest +from unit.simple_model import SimpleModel, random_dataloader +from deepspeed.utils import safe_set_full_grad + + +def has_inf_or_nan(x): + float_x = x.float() + nan = float_x.isnan() + inf = float_x.isinf() + inf_or_nan = nan.logical_or(inf) + return inf_or_nan.float().max() + + +def run_model_step(model, x_sample, y_label, grad_value): + loss = model(x_sample, y_label) + model.backward(loss) + for p in model.parameters(): + grad = torch.empty_like(p, dtype=p.dtype) + grad.fill_(grad_value) + safe_set_full_grad(p, grad) + model.step() + + +@pytest.mark.parametrize("zero_stage", [1, 2]) +@pytest.mark.parametrize("offload_optimizer", [False, True]) +class TestZeROFloat16(DistributedTest): + world_size = 2 + + def test_no_overflow(self, zero_stage, offload_optimizer): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 8, + "loss_scale_window": 2 + }, + "zero_optimization": { + "stage": zero_stage + } + } + + if offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + model, optim, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + + expected_loss_scale = 2**8 + expected_scale_window = 2 + # Ensure the dynamic loss scaler is correctly configured. 
+ loss_scaler = optim.loss_scaler + + assert optim.dynamic_loss_scale == True + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.scale_window == expected_scale_window + + num_iterations = 10 + grad_values = np.random.uniform(-0.1, 0.1, num_iterations) + data_loader = random_dataloader(model=model, + total_samples=num_iterations, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) + for i, (batch, grad_value) in enumerate(zip(data_loader, grad_values)): + run_model_step(model, batch[0], batch[1], grad_value) + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.cur_iter == (i + 1) + + if loss_scaler.cur_iter % expected_scale_window == 0: + expected_loss_scale *= 2 + + def test_all_overflow(self, zero_stage, offload_optimizer): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + + overflow_gradients = [float('inf'), float('-inf')] + [float('nan')] * 6 + initial_scale_power = len(overflow_gradients) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": initial_scale_power, + "loss_scale_window": 2, + "hysteresis": 1, + }, + "zero_optimization": { + "stage": zero_stage, + } + } + + if offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + model, optim, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + + expected_loss_scale = 2**initial_scale_power + expected_scale_window = 2 + # Ensure the dynamic loss scaler is correctly configured. + loss_scaler = optim.loss_scaler + + assert optim.dynamic_loss_scale == True + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.scale_window == expected_scale_window + + data_loader = random_dataloader(model=model, + total_samples=len(overflow_gradients), + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) + for i, (batch, grad_value) in enumerate(zip(data_loader, overflow_gradients)): + run_model_step(model, batch[0], batch[1], grad_value) + expected_loss_scale = max(expected_loss_scale / 2, 1) + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.cur_iter == (i + 1) + + def test_some_overflow(self, zero_stage, offload_optimizer): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + initial_scale_power = 8 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": initial_scale_power, + "loss_scale_window": 2, + "hysteresis": 1, + }, + "zero_optimization": { + "stage": zero_stage, + } + } + + if offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + model, optim, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + + expected_loss_scale = 2**initial_scale_power + expected_scale_window = 2 + # Ensure the dynamic loss scaler is correctly configured. 
+ loss_scaler = optim.loss_scaler + + assert optim.dynamic_loss_scale == True + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.scale_window == expected_scale_window + + expected_iteration = 0 + + # Run model with overflows to decrease scale + overflow_gradients = [float('inf'), float('nan')] + expected_iteration += len(overflow_gradients) + data_loader = random_dataloader(model=model, + total_samples=len(overflow_gradients), + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) + for batch, grad_value in zip(data_loader, overflow_gradients): + run_model_step(model, batch[0], batch[1], grad_value) + + expected_loss_scale /= (2**len(overflow_gradients)) + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.cur_iter == expected_iteration + + # Run model scale_window + 1 times to increase scale once + normal_gradients = np.random.uniform(-0.1, 0.1, expected_scale_window + 1) + expected_iteration += len(normal_gradients) + data_loader = random_dataloader(model=model, + total_samples=len(normal_gradients), + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) + for batch, grad_value in zip(data_loader, normal_gradients): + run_model_step(model, batch[0], batch[1], grad_value) + + expected_loss_scale *= 2 + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.cur_iter == expected_iteration + + # Run model with overflows to decrease scale + overflow_gradients = [float('inf')] + expected_iteration += len(overflow_gradients) + data_loader = random_dataloader(model=model, + total_samples=len(overflow_gradients), + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) + for batch, grad_value in zip(data_loader, overflow_gradients): + run_model_step(model, batch[0], batch[1], grad_value) + + expected_loss_scale /= (2**len(overflow_gradients)) + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.cur_iter == expected_iteration + + +@pytest.mark.parametrize("zero_stage", [1, 2]) +@pytest.mark.parametrize("offload_optimizer", [False, True]) +class TestZeROBFloat16(DistributedTest): + world_size = 2 + + def test_no_overflow(self, zero_stage, offload_optimizer): + if not get_accelerator().is_bf16_supported(): + pytest.skip("bf16 is not supported") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "bf16": { + "enabled": True, + }, + "zero_optimization": { + "stage": zero_stage + } + } + + if offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + model, optim, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + + num_iterations = 10 + grad_values = np.random.uniform(-0.1, 0.1, num_iterations) + data_loader = random_dataloader(model=model, + total_samples=num_iterations, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.bfloat16) + for i, (batch, grad_value) in enumerate(zip(data_loader, grad_values)): + run_model_step(model, batch[0], batch[1], grad_value) + + assert model.skipped_steps == 0 + assert all([not has_inf_or_nan(p) for p in model.parameters()]) + + def test_detect_grad_overflow(self, zero_stage, offload_optimizer): + if not get_accelerator().is_bf16_supported(): + pytest.skip("bf16 is not supported") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + 
"type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "bf16": { + "enabled": True, + "check_grad_overflow": True + }, + "zero_optimization": { + "stage": zero_stage, + } + } + + if offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + model, optim, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + + overflow_gradients = [float('inf'), float('-inf')] + [float('nan')] * 6 + data_loader = random_dataloader(model=model, + total_samples=len(overflow_gradients), + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.bfloat16) + + for i, (batch, grad_value) in enumerate(zip(data_loader, overflow_gradients)): + run_model_step(model, batch[0], batch[1], grad_value) + assert model.skipped_steps == (i + 1) + + assert all([not has_inf_or_nan(p) for p in model.parameters()]) + + def test_ignore_grad_overflow(self, zero_stage, offload_optimizer): + if not get_accelerator().is_bf16_supported(): + pytest.skip("bf16 is not supported") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "bf16": { + "enabled": True, + "check_grad_overflow": False + }, + "zero_optimization": { + "stage": zero_stage, + } + } + + if offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + model, optim, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + + overflow_gradients = [float('inf'), float('-inf')] + [float('nan')] * 6 + data_loader = random_dataloader(model=model, + total_samples=len(overflow_gradients), + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.bfloat16) + + for i, (batch, grad_value) in enumerate(zip(data_loader, overflow_gradients)): + run_model_step(model, batch[0], batch[1], grad_value) + + assert model.skipped_steps == 0 + assert all([has_inf_or_nan(p) for p in model.parameters()]) diff --git a/tests/unit/runtime/test_ds_config_dict.py b/tests/unit/runtime/test_ds_config_dict.py index d06b35e208fe..618fd2eca65a 100644 --- a/tests/unit/runtime/test_ds_config_dict.py +++ b/tests/unit/runtime/test_ds_config_dict.py @@ -19,7 +19,8 @@ # A test on its own import deepspeed -from deepspeed.runtime.config import DeepSpeedConfig, get_bfloat16_enabled +from deepspeed.runtime.config import DeepSpeedConfig +from deepspeed.runtime.precision_config import get_bfloat16_config class TestBasicConfig(DistributedTest): @@ -151,7 +152,7 @@ def test_get_bfloat16_enabled(bf16_key): "enabled": True, }, } - assert get_bfloat16_enabled(cfg) == True + assert get_bfloat16_config(cfg).enabled == True class TestConfigLoad(DistributedTest): diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index 2ae2755086f8..dd56ab237387 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -1198,10 +1198,15 @@ def create_tensor(vals): _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) class TestParamPartitioningSkipInit(DistributedTest): world_size = 2 - def test(self): + def test(self, dtype): + + if not dtype in get_accelerator().supported_dtypes(): + pytest.skip("{dtype} is not supported") + config_dict = { "train_batch_size": 4, "steps_per_print": 1, @@ -1215,10 +1220,8 @@ 
def test(self): "stage": 3 }, } - if get_accelerator().is_fp16_supported(): - config_dict["fp16"] = {"enabled": True} - elif get_accelerator().is_bf16_supported(): - config_dict["bf16"] = {"enabled": True} + dtype_str = "fp16" if dtype == torch.float16 else "bf16" + config_dict[dtype_str] = {"enabled": True} hidden_dim = 10 class SubModel(torch.nn.Module): @@ -1266,7 +1269,11 @@ def forward(self, x, y): assert model.l4.module_list[0].weight.ds_tensor.numel() == ds_tensor_numel model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=16, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) dist.barrier() for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1])
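Editor's note (not part of the patch): the sketch below illustrates how the new `deepspeed/runtime/precision_config.py` module introduced above can be exercised on its own. The functions, classes, and field names (`get_bfloat16_config`, `get_float16_config`, `DeepSpeedBF16Config`, `DeepSpeedFP16Config`, `check_grad_overflow`, `immediate_grad_update`, `initial_dynamic_scale()`, `dynamic_loss_scale_args()`) come from the diff; the config values themselves are hypothetical.

```python
# Minimal sketch of the new precision config objects added in this PR.
# Assumes DeepSpeed with this patch applied; config values are illustrative only.
from deepspeed.runtime.precision_config import get_bfloat16_config, get_float16_config

ds_config = {
    "bf16": {
        "enabled": True,
        "immediate_grad_update": True,   # apply HP grad updates as grads arrive
        "check_grad_overflow": True,     # opt-in inf/NaN detection for bf16 + ZeRO 1/2
    },
    "fp16": {
        "enabled": False,
        "loss_scale": 0,                 # 0 => dynamic loss scaling
        "initial_scale_power": 16,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1,
    },
}

bf16_cfg = get_bfloat16_config(ds_config)   # also accepts the legacy "bfloat16" key
fp16_cfg = get_float16_config(ds_config)

# Mirrors the mutual-exclusion assert in DeepSpeedConfig._initialize_params.
assert not (bf16_cfg.enabled and fp16_cfg.enabled)
assert bf16_cfg.check_grad_overflow and bf16_cfg.immediate_grad_update

# The fp16 model replaces the old get_initial_dynamic_scale / get_dynamic_loss_scale_args
# helpers that this PR removes from runtime/config.py.
assert fp16_cfg.initial_dynamic_scale() == 2 ** 16
dynamic_args = fp16_cfg.dynamic_loss_scale_args()  # keyword dict consumed by DynamicLossScaler
```

Two related behavior notes, both visible in the diff: the engine now forwards `check_grad_overflow` to the ZeRO stage 1/2 optimizer (always `True` for fp16, taken from `bf16.check_grad_overflow` for bf16, `False` otherwise), so bf16 runs only pay for overflow checks when explicitly requested; and `CreateLossScaler` now asserts that dynamic loss-scale arguments are provided instead of silently falling back to defaults, while the reworked `DynamicLossScaler.update_scale` grows the scale only after `scale_window + 1` consecutive overflow-free steps following an overflow, matching the expectations encoded in `test_some_overflow` above.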