From b020aaa42941fd911daaa5924daebb6fe8590cef Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 4 Oct 2024 08:52:51 +0000 Subject: [PATCH] clean gathered parameters. see also #6557 --- deepspeed/runtime/zero/parameter_offload.py | 4 ++-- deepspeed/runtime/zero/partitioned_param_coordinator.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 1ce2414a1e17..4ca71197c735 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -225,8 +225,8 @@ def setup_zero_stage3_hooks(self): @instrument_w_nvtx def _end_of_forward_hook(module, *args): - if not torch._C.is_grad_enabled(): - self.get_param_coordinator(training=False).reset_step() + self.get_param_coordinator(training=False).reset_step() + self.get_param_coordinator(training=True).reset_step() #likely one of them should be enough but just to be safe self._register_hooks_recursively(self.module) diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index bdec8a55fcbc..9f7dab077c3c 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -205,8 +205,10 @@ def construct_parameter_trace_from_module_trace(self): def reset_step(self) -> None: """indicate that we have completed one fwd+bwd for the model""" if self.__inflight_param_registry: - raise RuntimeError(f"still have inflight params " - f"{[p.ds_summary() for p in self.__inflight_param_registry.keys()]}") + for param, handle in self.__inflight_param_registry.items(): + handle.wait() + self.__release_param(param) + self.__inflight_param_registry.clear() if not self.is_complete_trace(): # not self.trace_complete: # Make sure that recorded submodule orders are identical across ranks @@ -409,7 +411,7 @@ def release_and_reset_all(self, module: Module) -> None: """release all module parameters""" for param in iter_params(module, recurse=True): if param in self.__inflight_param_registry: - raise RuntimeError(f"param {param.ds_summary()} still in flight") + self.__inflight_param_registry.pop(param).wait() # TODO. make this throw if if there are still active submodules. currently # there's a hook execution issue