Skip to content

Commit

Permalink
Clean up prefetched parameters (#6557)
Browse files Browse the repository at this point in the history
Parameters prefetched by ZeRO3 are sometimes not used. This occurs when
the actual sub-module execution differs from previous tracing. As a
result, the state of the allgather handle for such a parameter remains
`INFLIGHT`, causing functions like `empty_partition_cache` to detect it
and throw an error.
This PR resolves the issue by ensuring that communication finishes and
the parameters are freed.

As this issue was mentioned in #6011, this includes the change of the
branch. We need to merge #6011 first.

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
3 people authored Oct 9, 2024
1 parent 55f7f37 commit 7d751ee
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,18 @@ def is_invalid_trace(self) -> bool:
def is_record_trace(self) -> bool:
return self.__trace_mode == ZeRoTraceMode.RECORD

def _clean_inflight_param_registry(self) -> None:
for param, handle in self.__inflight_param_registry.items():
handle.wait()
self.__release_param(param)
self.__inflight_param_registry.clear()

def _invalidate_trace(self) -> None:
if self.is_invalid_trace():
raise RuntimeError("attempted to invalidate already invalid trace")
self.__trace_mode = ZeRoTraceMode.INVALID
self._clear_trace_structures()
self._clean_inflight_param_registry()

def trace_prologue(self, sub_module: Module) -> None:
if self.is_complete_trace():
Expand Down Expand Up @@ -204,9 +211,7 @@ def construct_parameter_trace_from_module_trace(self):

def reset_step(self) -> None:
"""indicate that we have completed one fwd+bwd for the model"""
if self.__inflight_param_registry:
raise RuntimeError(f"still have inflight params "
f"{[p.ds_summary() for p in self.__inflight_param_registry.keys()]}")
self._clean_inflight_param_registry()

if not self.is_complete_trace(): # not self.trace_complete:
# Make sure that recorded submodule orders are identical across ranks
Expand Down Expand Up @@ -409,7 +414,7 @@ def release_and_reset_all(self, module: Module) -> None:
"""release all module parameters"""
for param in iter_params(module, recurse=True):
if param in self.__inflight_param_registry:
raise RuntimeError(f"param {param.ds_summary()} still in flight")
self.__inflight_param_registry.pop(param).wait()

# TODO. make this throw if if there are still active submodules. currently
# there's a hook execution issue
Expand Down

0 comments on commit 7d751ee

Please sign in to comment.