Skip to content

Commit

Permalink
[colossalai/checkpoint_io/...] fix bug in load_state_dict_into_model; format error msg (hpcaitech#6020)
Browse files: browse the repository at this point in the history

* fix bug in load_state_dict_into_model; format error msg

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

to support checking missing_keys

* Update general_checkpoint_io.py

fix bug in missing_keys error message

* retrigger tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
flymin and pre-commit-ci[bot] authored Sep 2, 2024
1 parent e96a076 commit e9032fb
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 14 deletions.
6 changes: 3 additions & 3 deletions colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ def load_sharded_model(
if strict:
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
if len(remain_keys) > 0:
error_msgs = "Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
)
error_msgs = [
"Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in remain_keys))
]
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
Expand Down
6 changes: 3 additions & 3 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,9 @@ def _load(name: str):
remain_keys = remain_keys.union(set(missing_file_keys))
if len(remain_keys) > 0:
if strict:
error_msgs = "Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
)
error_msgs = [
"Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))
]
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
Expand Down
10 changes: 5 additions & 5 deletions colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,10 +553,10 @@ def load_state_dict_into_model(

def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, unexpected_keys, error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
if strict or len([key for key in state_dict if key.startswith(prefix)]) > 0:
module._load_from_state_dict(*args)
if load_sub_module:
for name, child in module._modules.items():
Expand All @@ -570,9 +570,9 @@ def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True)

if strict:
if len(unexpected_keys) > 0:
error_msgs = "Unexpected key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in unexpected_keys)
)
error_msgs = [
"Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys))
]
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
)
Expand Down
6 changes: 3 additions & 3 deletions colossalai/inference/core/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def _load(name: str):
remain_keys = remain_keys.union(set(missing_file_keys))
if len(remain_keys) > 0:
if strict:
error_msgs = "Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
)
error_msgs = [
"Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))
]
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
Expand Down

0 comments on commit e9032fb

Please sign in to comment.