diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py
index bdec8a55fcbc..5780b2afd6de 100644
--- a/deepspeed/runtime/zero/partitioned_param_coordinator.py
+++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -153,11 +153,18 @@ def is_invalid_trace(self) -> bool:
     def is_record_trace(self) -> bool:
         return self.__trace_mode == ZeRoTraceMode.RECORD
 
+    def _clean_inflight_param_registry(self) -> None:
+        for param, handle in self.__inflight_param_registry.items():
+            handle.wait()
+            self.__release_param(param)
+        self.__inflight_param_registry.clear()
+
     def _invalidate_trace(self) -> None:
         if self.is_invalid_trace():
             raise RuntimeError("attempted to invalidate already invalid trace")
         self.__trace_mode = ZeRoTraceMode.INVALID
         self._clear_trace_structures()
+        self._clean_inflight_param_registry()
 
     def trace_prologue(self, sub_module: Module) -> None:
         if self.is_complete_trace():
@@ -204,9 +211,7 @@ def construct_parameter_trace_from_module_trace(self):
 
     def reset_step(self) -> None:
         """indicate that we have completed one fwd+bwd for the model"""
-        if self.__inflight_param_registry:
-            raise RuntimeError(f"still have inflight params "
-                               f"{[p.ds_summary() for p in self.__inflight_param_registry.keys()]}")
+        self._clean_inflight_param_registry()
 
         if not self.is_complete_trace():  # not self.trace_complete:
             # Make sure that recorded submodule orders are identical across ranks
@@ -409,7 +414,7 @@ def release_and_reset_all(self, module: Module) -> None:
         """release all module parameters"""
         for param in iter_params(module, recurse=True):
             if param in self.__inflight_param_registry:
-                raise RuntimeError(f"param {param.ds_summary()} still in flight")
+                self.__inflight_param_registry.pop(param).wait()
 
             # TODO. make this throw if if there are still active submodules. currently
             # there's a hook execution issue