diff --git a/.github/workflows/cpu-torch-latest.yml b/.github/workflows/cpu-torch-latest.yml index 0bd42b3995bd..34eb4b0b2093 100644 --- a/.github/workflows/cpu-torch-latest.yml +++ b/.github/workflows/cpu-torch-latest.yml @@ -19,7 +19,7 @@ concurrency: jobs: unit-tests: - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/no-torch.yml b/.github/workflows/no-torch.yml index eb3ac9b03161..1a13c0f3f4f1 100644 --- a/.github/workflows/no-torch.yml +++ b/.github/workflows/no-torch.yml @@ -19,7 +19,7 @@ permissions: jobs: unit-tests: - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 @@ -30,6 +30,7 @@ jobs: - name: Python environment run: | pip uninstall torch --yes + pip install setuptools pip list - name: Build deepspeed diff --git a/.github/workflows/nv-ds-chat.yml b/.github/workflows/nv-ds-chat.yml index 2ad336cac4ed..329a1060f5eb 100644 --- a/.github/workflows/nv-ds-chat.yml +++ b/.github/workflows/nv-ds-chat.yml @@ -12,6 +12,7 @@ on: type: string pull_request: paths: + - ".github/workflows/nv-ds-chat.yml" - "deepspeed/runtime/zero/stage_1_and_2.py" - "deepspeed/runtime/zero/stage3.py" - "deepspeed/runtime/hybrid_engine.py" @@ -42,6 +43,7 @@ jobs: - name: Install deepspeed run: | + pip install transformers==4.45.2 pip install .[dev] ds_report diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml index 72ba8abbd95d..fc810bc190d0 100644 --- a/.github/workflows/nv-pre-compile-ops.yml +++ b/.github/workflows/nv-pre-compile-ops.yml @@ -21,7 +21,7 @@ concurrency: jobs: unit-tests: - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 container: image: deepspeed/gh-builder:ubuntu1804-py38-torch1131-cu116 diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 3103e3f36e84..37b68f1dbe80 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -21,10 +21,10 @@ jobs: unit-tests: strategy: matrix: - pyVersion: ["3.7", "3.8", "3.9", "3.10"] + pyVersion: ["3.8", "3.9", "3.10", "3.11", "3.12"] fail-fast: false - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 container: image: deepspeed/gh-builder:py${{ matrix.pyVersion }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 02881ef12f39..eb763792f0c4 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -7,7 +7,7 @@ on: jobs: deploy: - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 environment: release-env steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b89c872eefe5..b5d8afa8e0b4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/google/yapf - rev: v0.32.0 + rev: v0.40.0 hooks: - id: yapf @@ -65,7 +65,7 @@ repos: ] - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + rev: 5.0.4 hooks: - id: flake8 args: ['--config=.flake8'] diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu index 25a494111c54..bbb8a7f00b1f 100644 --- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu +++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu @@ -101,7 +101,15 @@ __global__ void apply_rotary_pos_half(T* mixed_query, #if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64 #define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ - if (threads_per_head == 64) { \ + if (threads_per_head == 4) { \ + LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \ + } else if (threads_per_head == 8) { \ + LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \ + } else if (threads_per_head == 16) { \ + LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \ + } else if (threads_per_head == 32) { \ + LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \ + } else if (threads_per_head == 64) { \ LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \ } else { \ assert(false); \ diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index d8655299282f..de405dc40edb 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -165,8 +165,8 @@ def initialize(args=None, if hasattr(args, "deepscale_config") and args.deepscale_config is not None: logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************") if hasattr(args, "deepspeed_config"): - assert (args.deepspeed_config is - None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" + assert (args.deepspeed_config + is None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" args.deepspeed_config = args.deepscale_config args.deepscale_config = None diff --git a/deepspeed/autotuning/autotuner.py b/deepspeed/autotuning/autotuner.py index dfd195bc37eb..a72b3c951e97 100755 --- a/deepspeed/autotuning/autotuner.py +++ b/deepspeed/autotuning/autotuner.py @@ -248,8 +248,8 @@ def mp_size(self): return self.autotuning_config.mp_size def max_train_micro_batch_size_per_gpu(self): - if self.max_train_batch_size( - ) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size + if self.max_train_batch_size() and self.max_train_batch_size( + ) > 0: # if the user specifies a max_train_batch_size max_train_micro_batch_size = self.max_train_batch_size() * self.mp_size() // ( self.exp_num_gpus * self.exp_num_nodes) # gradient accumulation steps >=1 return min(self.autotuning_config.max_train_micro_batch_size_per_gpu, max_train_micro_batch_size) @@ -964,8 +964,8 @@ def get_min_max_micro_batch_size(self, stage, min_micro_batch_size, calculated_m low = mid + 1 self.update_records(tuning_space_name, exp, metric_val, 1) used_micro_batch_sizes.append(mid) - if prev_metric_val and ( - (metric_val - prev_metric_val) / prev_metric_val) < METRIC_PERCENT_DIFF_CONST: + if prev_metric_val and ((metric_val - prev_metric_val) / + prev_metric_val) < METRIC_PERCENT_DIFF_CONST: logger.info(f"performance plateaus at mbs = {low}") break prev_metric_val = metric_val @@ -1026,8 +1026,8 @@ def get_tuning_micro_batch_size_list(self, min_micro_batch_size, max_micro_batch # NUM_GPUS=$(( ${NUM_WORKERS} * ${NUM_GPUS_PER_WORKER} )) # DP_SIZE=$(( ${NUM_GPUS} / (${PP_SIZE} * ${MP_SIZE}) )) # GRAD_ACC_STEPS=$(( ${TARGET_GLOBAL_BATCH_SIZE} / (${BATCH_SIZE} * ${DP_SIZE}) )) - if self.max_train_batch_size( - ) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size + if self.max_train_batch_size() and self.max_train_batch_size( + ) > 0: # if the user specifies a max_train_batch_size max_train_batch_size_per_gpu = self.max_train_batch_size() * self.mp_size() // (self.exp_num_gpus * self.exp_num_nodes) else: diff --git a/deepspeed/elasticity/elastic_agent.py b/deepspeed/elasticity/elastic_agent.py index c6a69dd2a49f..8fd4293d312c 100644 --- a/deepspeed/elasticity/elastic_agent.py +++ b/deepspeed/elasticity/elastic_agent.py @@ -160,8 +160,8 @@ def _invoke_run(self, role: str = "default") -> RunResult: f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish.") self._exit_barrier() return run_result - elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED - } or len(participants) > len(rdzv_handler._state_holder.state.participants): + elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED} or len(participants) > len( + rdzv_handler._state_holder.state.participants): if self._remaining_restarts > 0: log.info(f"[{role}] Worker group {state.name}. " f"{self._remaining_restarts}/{spec.max_restarts} attempts left;" diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 8b1455f20c69..1c5745dcf168 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -496,9 +496,10 @@ def conv2d_parallel_shard_weights(model, rank, world_size): if not dist.is_initialized() or dist.get_rank() == 0: print("Saving tp-sharded checkpoints") torch.save( - OrderedDict({k: v - for k, v in dict(replaced_module.state_dict()).items() - if transformer_name not in k}), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}') + OrderedDict({ + k: v + for k, v in dict(replaced_module.state_dict()).items() if transformer_name not in k + }), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}') dtype_reprs = { torch.float32: 'float32', diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 8be2f7ac4055..fb786f29722d 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -1012,8 +1012,8 @@ def _do_error_check(self): self.gradient_accumulation_steps), "DeepSpeedConfig: {} is not defined".format(GRADIENT_ACCUMULATION_STEPS) if self.zero_enabled: - assert (self.zero_optimization_stage <= - ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format( + assert (self.zero_optimization_stage + <= ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format( ZeroStageEnum.max_stage) if self.fp16_master_weights_and_gradients: diff --git a/deepspeed/runtime/eigenvalue.py b/deepspeed/runtime/eigenvalue.py index 36300eb904dd..a82d8b1d5c7a 100755 --- a/deepspeed/runtime/eigenvalue.py +++ b/deepspeed/runtime/eigenvalue.py @@ -114,8 +114,8 @@ def compute_eigenvalue(self, module, device=None, scale=1.0): eigenvalue_current, eigenvalue_previous = 1., 0. while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs( - (eigenvalue_current - eigenvalue_previous) / eigenvalue_current) >= - self.tol): # test convergence criteria + (eigenvalue_current - eigenvalue_previous) / eigenvalue_current) + >= self.tol): # test convergence criteria eigenvalue_previous = eigenvalue_current Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=True) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index b75270cbd306..deb44c2e71eb 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -640,9 +640,10 @@ def _aggregate_total_loss(self): self.dp_group_loss = losses[0].clone().detach() agg_loss = losses[1].clone().detach() if additional_losses is not None: - self.agg_additional_losses = OrderedDict( - {name: losses[2 + i].clone().detach() - for i, name in enumerate(additional_losses.keys())}) + self.agg_additional_losses = OrderedDict({ + name: losses[2 + i].clone().detach() + for i, name in enumerate(additional_losses.keys()) + }) return agg_loss def set_dataloader(self, loader): diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index b9617d3e632f..f48adb58c9bf 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -257,8 +257,8 @@ def has_overflow(self, params, has_moe_params=None): elif self.mpu is not None: if self.deepspeed is not None: using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') - if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce is False) or ( - not using_pipeline and self.deepspeed.enable_backward_allreduce is False): + if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce + is False) or (not using_pipeline and self.deepspeed.enable_backward_allreduce is False): dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_data_parallel_group()) dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group()) elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False: diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 90afaf03550a..4b0ddb7679a9 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -133,7 +133,6 @@ def __init__( self.persistent_parameters = self.mark_persistent_parameters(self.param_numel_persistence_threshold, self.model_persistence_threshold) - self.param_coordinators = {} self._prefetch_bucket_sz = int(prefetch_bucket_size) self._max_reuse_distance_in_numel = int(max_reuse_distance) self._max_available_parameters_in_numel = int(max_live_parameters) @@ -141,12 +140,21 @@ def __init__( ) if overlap_comm else get_accelerator().default_stream() if not hasattr(module, "ds_inflight_param_registry"): - module.ds_inflight_param_registry = dict() - # we need two registries, one for training and one for eval. They will be used when creating PartitionedParameterCoordinator - module.ds_inflight_param_registry[True] = InflightParamRegistry() - module.ds_inflight_param_registry[False] = InflightParamRegistry() + module.ds_inflight_param_registry = InflightParamRegistry() self.__inflight_param_registry = module.ds_inflight_param_registry + self.param_coordinator = PartitionedParameterCoordinator( + prefetch_bucket_sz=self._prefetch_bucket_sz, + max_reuse_distance_in_numel=self._max_reuse_distance_in_numel, + max_available_parameters_in_numel=self._max_available_parameters_in_numel, + allgather_stream=self.__allgather_stream, + inflight_param_registry=self.__inflight_param_registry, + prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme, + timers=self.timers, + zero_quantized_weights=self.zero_quantized_weights, + zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights, + ) + self.forward_hooks = [] self.backward_hooks = [] self.setup_zero_stage3_hooks() @@ -161,26 +169,13 @@ def partition_all_parameters(self): """Partitioning Parameters that were not partitioned usually if parameters of modules whose input parameters do not require grad computation do not trigger post call and will therefore will remain unpartitioned""" - self.get_param_coordinator(training=self.module.training).release_and_reset_all(self.module) + self.get_param_coordinator().release_and_reset_all(self.module) for param in iter_params(self.module, recurse=True): if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: raise RuntimeError(f"{param.ds_summary()} expected to be released") - def get_param_coordinator(self, training): - if not training in self.param_coordinators: - self.param_coordinators[training] = PartitionedParameterCoordinator( - prefetch_bucket_sz=self._prefetch_bucket_sz, - max_reuse_distance_in_numel=self._max_reuse_distance_in_numel, - max_available_parameters_in_numel=self._max_available_parameters_in_numel, - allgather_stream=self.__allgather_stream, - inflight_param_registry=self.__inflight_param_registry[training], - prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme, - timers=self.timers, - zero_quantized_weights=self.zero_quantized_weights, - zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights, - ) - - return self.param_coordinators[training] + def get_param_coordinator(self): + return self.param_coordinator def empty_partition_cache(self): self.partition_all_parameters() @@ -228,14 +223,14 @@ def setup_zero_stage3_hooks(self): #reset step if in inference mode @instrument_w_nvtx - def _end_of_forward_hook(module, *args): + def _start_of_forward_hook(module, *args): + + self.get_param_coordinator().reset_step() - if not torch._C.is_grad_enabled(): - self.get_param_coordinator(training=False).reset_step() + self.module.register_forward_pre_hook(_start_of_forward_hook) #likely one of them should be enough but just to be safe self._register_hooks_recursively(self.module) - self.module.register_forward_hook(_end_of_forward_hook) # Add top module to stack trace global FWD_MODULE_STACK @@ -447,7 +442,7 @@ def pre_sub_module_forward_function(self, sub_module): global FWD_MODULE_STACK FWD_MODULE_STACK.append(sub_module) - param_coordinator = self.get_param_coordinator(training=sub_module.training) + param_coordinator = self.get_param_coordinator() param_coordinator.trace_prologue(sub_module) if param_coordinator.is_record_trace(): param_coordinator.record_module(sub_module) @@ -460,7 +455,7 @@ def post_sub_module_forward_function(self, sub_module): see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", force=False) - param_coordinator = self.get_param_coordinator(training=sub_module.training) + param_coordinator = self.get_param_coordinator() param_coordinator.release_sub_module(sub_module) see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", @@ -468,8 +463,8 @@ def post_sub_module_forward_function(self, sub_module): @torch.no_grad() def pre_sub_module_backward_function(self, sub_module): - assert sub_module.training, "backward pass is invalid for module in evaluation mode" - param_coordinator = self.get_param_coordinator(training=True) + # assert sub_module.training, "backward pass is invalid for module in evaluation mode" + param_coordinator = self.get_param_coordinator() param_coordinator.trace_prologue(sub_module) if param_coordinator.is_record_trace(): param_coordinator.record_module(sub_module) @@ -477,12 +472,12 @@ def pre_sub_module_backward_function(self, sub_module): @torch.no_grad() def post_sub_module_backward_function(self, sub_module): - assert sub_module.training, "backward pass is invalid for module in evaluation mode" + # assert sub_module.training, "backward pass is invalid for module in evaluation mode" see_memory_usage( f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", force=False) - self.get_param_coordinator(training=True).release_sub_module(sub_module) + self.get_param_coordinator().release_sub_module(sub_module) see_memory_usage( f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 5780b2afd6de..49f477cc4a1b 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -18,6 +18,7 @@ from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id from deepspeed.accelerator import get_accelerator import deepspeed.runtime.compiler as compiler +from deepspeed.runtime.compiler import is_compiling import logging @@ -92,7 +93,7 @@ def __init__( # keeps track of the number of submodules invoked so far. self.__step_id: int = 0 # network tracing mode - self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.RECORD + self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.INVALID # sequence of submodules/parameters in forward pass + backward pass self.__submodule_order: Iterable[Module] = [] self.__param_order: Iterable[__class__.__ParamInTrace] = [] @@ -188,6 +189,9 @@ def trace_prologue(self, sub_module: Module) -> None: @compiler.disable def record_module(self, sub_module: Module) -> None: """adds sub module to trace""" + if is_compiling(): + return + if not self.is_record_trace(): raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}") @@ -195,6 +199,8 @@ def record_module(self, sub_module: Module) -> None: self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id) def record_parameters(self, sub_module: Module) -> None: + if is_compiling(): + return """adds sub module to trace""" if not self.is_record_trace(): raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}") @@ -209,8 +215,12 @@ def construct_parameter_trace_from_module_trace(self): for sub_module in self.__submodule_order: self.record_parameters(sub_module) + @compiler.disable def reset_step(self) -> None: """indicate that we have completed one fwd+bwd for the model""" + if is_compiling(): + return + self._clean_inflight_param_registry() if not self.is_complete_trace(): # not self.trace_complete: diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 65460eb72a2f..2c0c9d498d13 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -593,8 +593,8 @@ def defragment(tensors: List[Tensor]) -> Tensor: return device_buffer - def _get_param_coordinator(self, training): - return self.parameter_offload.get_param_coordinator(training) + def _get_param_coordinator(self): + return self.parameter_offload.get_param_coordinator() def _configure_offloading(self, offload_optimizer_config, offload_param_config): ###################### offload optimizer setup ################################## @@ -1874,7 +1874,7 @@ def _pre_step(self): see_memory_usage(f"In step before checking overflow", force=False) print_rank_0("Finished Tracing at Beginning of Step") - self._get_param_coordinator(training=True).hierarchy = 0 + self._get_param_coordinator().hierarchy = 0 print_rank_0("Finished Tracing at Beginning of Step") @@ -2258,8 +2258,6 @@ def backward(self, loss, retain_graph=False): else: self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) - self._get_param_coordinator(training=True).reset_step() - if self.swap_optimizer: self.optimizer_swapper.post_backward() diff --git a/docker/gh-builder/Dockerfile.py311 b/docker/gh-builder/Dockerfile.py311 new file mode 100644 index 000000000000..603fb614314f --- /dev/null +++ b/docker/gh-builder/Dockerfile.py311 @@ -0,0 +1,35 @@ +# Start with NGC container +FROM nvcr.io/nvidia/pytorch:24.03-py3 + +# Set noninteractive mode for apt-get +ARG DEBIAN_FRONTEND=noninteractive + +# Install necessary dependencies for building Python +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + libssl-dev \ + zlib1g-dev \ + libbz2-dev \ + libreadline-dev \ + libsqlite3-dev \ + curl \ + libncursesw5-dev \ + libgdbm-dev \ + libc6-dev \ + libffi-dev \ + tk-dev \ + && rm -rf /var/lib/apt/lists/* + +# Download and install Python 3.11 +RUN wget https://www.python.org/ftp/python/3.11.9/Python-3.11.9.tgz \ + && tar xzf Python-3.11.9.tgz \ + && cd Python-3.11.9 \ + && ./configure --enable-optimizations \ + && make altinstall \ + && cd .. \ + && rm -rf Python-3.11.9 Python-3.11.9.tgz + +# Set Python 3.11 as the default Python version +RUN update-alternatives --install /usr/bin/python python /usr/local/bin/python3.11 1 \ + && update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.11 1 diff --git a/docker/gh-builder/Dockerfile.py312 b/docker/gh-builder/Dockerfile.py312 new file mode 100644 index 000000000000..a0a7193201d4 --- /dev/null +++ b/docker/gh-builder/Dockerfile.py312 @@ -0,0 +1,35 @@ +# Start with NGC container +FROM nvcr.io/nvidia/pytorch:24.03-py3 + +# Set noninteractive mode for apt-get +ARG DEBIAN_FRONTEND=noninteractive + +# Install necessary dependencies for building Python +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + libssl-dev \ + zlib1g-dev \ + libbz2-dev \ + libreadline-dev \ + libsqlite3-dev \ + curl \ + libncursesw5-dev \ + libgdbm-dev \ + libc6-dev \ + libffi-dev \ + tk-dev \ + && rm -rf /var/lib/apt/lists/* + +# Download and install Python 3.12 +RUN wget https://www.python.org/ftp/python/3.12.5/Python-3.12.5.tgz \ + && tar xzf Python-3.12.5.tgz \ + && cd Python-3.12.5 \ + && ./configure --enable-optimizations \ + && make altinstall \ + && cd .. \ + && rm -rf Python-3.12.5 Python-3.12.5.tgz + +# Set Python 3.12 as the default Python version +RUN update-alternatives --install /usr/bin/python python /usr/local/bin/python3.12 1 \ + && update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.12 1 diff --git a/docker/Dockerfile.rocm b/docker/rocm/Dockerfile similarity index 100% rename from docker/Dockerfile.rocm rename to docker/rocm/Dockerfile diff --git a/tests/unit/ops/transformer/inference/test_rope.py b/tests/unit/ops/transformer/inference/test_rope.py new file mode 100644 index 000000000000..1f0ca0578e04 --- /dev/null +++ b/tests/unit/ops/transformer/inference/test_rope.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.accelerator import get_accelerator + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("Inference ops are not available on this system", allow_module_level=True) + + +@pytest.mark.inference_ops +@pytest.mark.parametrize("num_heads", [64, 32, 16, 8]) +def test_rope_warp_size_alignment(num_heads): + if get_accelerator().device_name() != "cuda": + pytest.skip("This test runs only on GPU") + + batch = 1 + head = 8 + seq_len = 1024 + head_dim = 32 + rotary_dim = 32 + offset = 8 + rotate_half = False + rope_theta = 2 + + cuda0 = torch.device('cuda:0') + query = torch.randn(batch, head, seq_len, head_dim, device=cuda0) + key = torch.randn(batch, head, seq_len, head_dim, device=cuda0) + + inference = InferenceBuilder().load() + # For num_heads values of 64, 32, 16, 8 + # corresponding threads_per_head (defined in apply_rotary_pos_emb.cu) values are 4, 8, 16, 32 + inference.apply_rotary_pos_emb(query, key, rotary_dim, offset, num_heads, rotate_half, rope_theta) diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index 7262a1b2c998..5dffd70aab68 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -1628,3 +1628,48 @@ def test_empty_param_groups(self, dtype, use_client_optimizer, empty_weight_grou optimizer=optimizer, config=config_dict, ) + + +class TestZero3SwitchModes(DistributedTest): + world_size = 2 + + @pytest.mark.parametrize("prefetch_ratio", [0.0, 0.5, 1.0]) + def test(self, prefetch_ratio, zero_stage=3): + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + + prefetch_bucket_size = int(sum([p.numel() for p in model.parameters(recurse=True)]) * prefetch_ratio) + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 2, + "zero_optimization": { + "stage": zero_stage, + "stage3_prefetch_bucket_size": prefetch_bucket_size + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 8 + } + } + + model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device) + + for _ in range(3): + model.train() + for batch in data_loader: + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + model.eval() + with torch.no_grad(): + for batch in data_loader: + loss = model(batch[0], batch[1]) diff --git a/tests/unit/runtime/zero/test_zero_context.py b/tests/unit/runtime/zero/test_zero_context.py index ec9e9e94aeaf..1d4fcd60022c 100644 --- a/tests/unit/runtime/zero/test_zero_context.py +++ b/tests/unit/runtime/zero/test_zero_context.py @@ -218,9 +218,9 @@ def test_throughput_calculation(self): engine.tput_timer.stop(global_step=global_step) duration = engine.tput_timer.end_time - engine.tput_timer.start_time # step elapsed time is reset after gradient accumulation steps - assert engine.tput_timer.step_elapsed_time == ( - 0 if engine.tput_timer.global_step_count != engine.tput_timer.start_step else current_duration + - duration) + assert engine.tput_timer.step_elapsed_time == (0 if engine.tput_timer.global_step_count + != engine.tput_timer.start_step else current_duration + + duration) assert engine.tput_timer.total_elapsed_time == total_duration + duration def test_ext_param_getattr(self): diff --git a/version.txt b/version.txt index 7ffdfa1cad65..1282fff53bfa 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.15.4 +0.15.5