From 7d751ee8903187cd7cd82f43816acf1d5d0907ba Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Wed, 9 Oct 2024 08:23:33 -0700 Subject: [PATCH] Clean up prefetched parameters (#6557) Parameters prefetched by ZeRO3 are sometimes not used. This occurs when the actual sub-module execution differs from previous tracing. As a result, the state of the allgather handle for such a parameter remains `INFLIGHT`, causing functions like `empty_partition_cache` to detect it and throw an error. This PR resolves the issue by ensuring that communication finishes and the parameters are freed. As this issue was mentioned in #6011, this includes the change of the branch. We need to merge #6011 first. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase --- .../runtime/zero/partitioned_param_coordinator.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index bdec8a55fcbc..5780b2afd6de 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -153,11 +153,18 @@ def is_invalid_trace(self) -> bool: def is_record_trace(self) -> bool: return self.__trace_mode == ZeRoTraceMode.RECORD + def _clean_inflight_param_registry(self) -> None: + for param, handle in self.__inflight_param_registry.items(): + handle.wait() + self.__release_param(param) + self.__inflight_param_registry.clear() + def _invalidate_trace(self) -> None: if self.is_invalid_trace(): raise RuntimeError("attempted to invalidate already invalid trace") self.__trace_mode = ZeRoTraceMode.INVALID self._clear_trace_structures() + self._clean_inflight_param_registry() def trace_prologue(self, sub_module: Module) -> None: if self.is_complete_trace(): @@ -204,9 +211,7 @@ def construct_parameter_trace_from_module_trace(self): def reset_step(self) -> None: """indicate that we have completed one fwd+bwd for the model""" - if self.__inflight_param_registry: - raise RuntimeError(f"still have inflight params " - f"{[p.ds_summary() for p in self.__inflight_param_registry.keys()]}") + self._clean_inflight_param_registry() if not self.is_complete_trace(): # not self.trace_complete: # Make sure that recorded submodule orders are identical across ranks @@ -409,7 +414,7 @@ def release_and_reset_all(self, module: Module) -> None: """release all module parameters""" for param in iter_params(module, recurse=True): if param in self.__inflight_param_registry: - raise RuntimeError(f"param {param.ds_summary()} still in flight") + self.__inflight_param_registry.pop(param).wait() # TODO. make this throw if if there are still active submodules. currently # there's a hook execution issue