Skip to content

Commit

Permalink
Merge branch 'bugfix_yannick' into 'master'
Browse files Browse the repository at this point in the history
fix bug with ddp finetuning

See merge request mic/internal/nnu-net!8
  • Loading branch information
FabianIsensee committed Apr 5, 2024
2 parents c7f85b7 + de2f2aa commit 9e2a877
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion nnunetv2/run/load_pretrained_weights.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from torch._dynamo import OptimizedModule
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist


def load_pretrained_weights(network, fname, verbose=False):
Expand All @@ -15,7 +16,10 @@ def load_pretrained_weights(network, fname, verbose=False):
nnUNetTrainer.save_checkpoint takes care of that!
"""
saved_model = torch.load(fname)
if dist.is_initialized():
saved_model = torch.load(fname, map_location=torch.device('cuda', dist.get_rank()))
else:
saved_model = torch.load(fname)
pretrained_dict = saved_model['network_weights']

skip_strings_in_pretrained = [
Expand Down

0 comments on commit 9e2a877

Please sign in to comment.