support train methods: gpu -> DistributedDataParallel
1 parent 6de3c41 · commit f10116a · 1 changed file with 151 additions and 0 deletions
@@ -0,0 +1,151 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import torchvision
import torchvision.transforms as transforms

# define Net
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

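# Shape check (CIFAR-10 images are 3x32x32): conv1 -> 6x28x28, pool -> 6x14x14,
# conv2 -> 16x10x10, pool -> 16x5x5, which is why the first fully connected layer
# takes 16 * 5 * 5 = 400 input features.
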
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def get_dataset(mode="cpu", batch_size=64, num_workers=2, local_rank=0):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    trainsampler = None
    trainloader = None

    # 1. cpu or cuda or DataParallel
    if mode == "cpu" or mode == "gpu" or mode == "DataParallel":
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    # 2. DistributedDataParallel
    elif mode == "DistributedDataParallel":
        ####### 1. DistributedSampler handles the distributed sampling; with a sampler, shuffle must stay False; batch_size here is per GPU, not the global total #######
        trainsampler = torch.utils.data.distributed.DistributedSampler(trainset, rank=local_rank)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, sampler=trainsampler)
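        # Note: with N processes the effective global batch size is N * batch_size,
        # e.g. 2 GPUs with batch_size=4 give 8 samples per optimizer step.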

    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return trainset, testset, trainloader, testloader, trainsampler

# show some images
def sample_imshow(trainloader, batch_size=64, save_dir='./sample'):
    dataiter = iter(trainloader)
    images, labels = next(dataiter)  # dataiter.next() was removed in newer PyTorch; use the built-in next()
    os.makedirs(save_dir, exist_ok=True)
    filename = os.path.join(save_dir, 'sample_grid.png')
    sample_grid = torchvision.utils.make_grid(images, nrow=int(batch_size ** 0.5), normalize=True, pad_value=1)
    torchvision.utils.save_image(sample_grid, filename)
    print(">>> save image to {}".format(filename))
    # print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

def train(log_step_interval=100, save_step_interval=1000, eval_step_interval=200, save_path="ckpt", mode="cpu", epochs=10, batch_size=64, resume=""):
    save_path = f"{save_path}_{mode}"
    local_rank = 0  # default rank for the single-process modes (cpu / gpu / DataParallel)

    if mode == "cpu":
        device = torch.device("cpu")
    elif mode == "gpu" or mode == "DataParallel":
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    elif mode == "DistributedDataParallel":
        n_gpus = 2  # must match --nproc_per_node in the launch command below
        ####### 2. init_process_group #######
        torch.distributed.init_process_group("nccl", world_size=n_gpus)
        ####### 3. set local device #######
        local_rank = torch.distributed.get_rank()
        device = torch.device("cuda", local_rank)
        torch.cuda.set_device(local_rank)
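        # Note: with the default env:// rendezvous, init_process_group reads MASTER_ADDR,
        # MASTER_PORT, RANK and WORLD_SIZE from environment variables that
        # torch.distributed.launch (see the command at the bottom of this file) sets for
        # each spawned process.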
print(">>> use device : {}".format(device)) | ||
|
||
    # 1. load dataset
    _, _, trainloader, _, trainsampler = get_dataset(mode=mode, batch_size=batch_size, local_rank=local_rank)

    # 2. define model
    model = Net()

    # only rank 0 writes the sample grid and prints the parameter count
    if local_rank == 0:
        sample_imshow(trainloader)
        print(">>> model param sum : {}".format(sum(p.numel() for p in model.parameters())))

    # resume
    start_epoch, start_step = 0, 0
    checkpoint = None
    if resume != "":
        # load resume pth -> map onto the device in use [cpu/cuda/cuda:index]
        checkpoint = torch.load(resume, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        # the optimizer state is restored after the optimizer is created below
        start_epoch = checkpoint['epoch']
        start_step = checkpoint['step']

    ####### 4. move model to cuda && wrap it with DistributedDataParallel #######
    model.to(device)
    if mode == "DistributedDataParallel":
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
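    # DistributedDataParallel broadcasts the rank-0 parameters to every rank at construction
    # time and averages gradients across ranks with all-reduce during backward(), so each
    # process keeps an identical replica of the model.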
    # 3. define loss + optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # 4. train
    num_batches = len(trainloader)
    for epoch_index in range(start_epoch, start_epoch + epochs):
        ema_loss = 0.0
        ####### 5. reshuffle the data differently on every rank at each epoch #######
        if trainsampler is not None:
            trainsampler.set_epoch(epoch_index)
        for batch_index, data in enumerate(trainloader):
            step = num_batches * epoch_index + batch_index + 1
            ####### 6. data move to cuda #######
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # use loss.item() so the EMA does not keep the autograd graph alive
            ema_loss = 0.9 * ema_loss + 0.1 * loss.item()

            # log loss
            if step % log_step_interval == 0 and local_rank == 0:
                print(">>> epoch: {:5}, step: {:7}, loss: {:2.4f}".format(epoch_index, step, ema_loss))

            # 5. save model pth
            if step % save_step_interval == 0 and local_rank == 0:
                os.makedirs(save_path, exist_ok=True)
                save_file = os.path.join(save_path, f"cifar_net_step_{step}.pth")
                torch.save({'epoch': epoch_index,
                            'step': step,  # stored so that resume can read checkpoint['step']
                            ####### 7. use model.module.state_dict() instead of model.state_dict() #######
                            'model_state_dict': model.module.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'loss': loss,
                            }, save_file)
                print(">>> save pth to {}".format(save_file))
    if local_rank == 0:
        print('>>> Finished Training')
    if mode == "DistributedDataParallel":
        torch.distributed.destroy_process_group()

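# To evaluate such a checkpoint later, the plain (unwrapped) Net can load it directly, e.g. a
# minimal sketch: net = Net(); net.load_state_dict(torch.load(path)['model_state_dict'])
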
if __name__ == '__main__':
    ####### batch_size here is the per-GPU batch size, not the global batch size #######
    batch_size = 4
    train(log_step_interval=1000, save_step_interval=10000, eval_step_interval=200, mode="DistributedDataParallel", epochs=100, batch_size=batch_size)

# CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 04_cifar10_gpu_distributeddataparallel.py
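# On recent PyTorch versions the same script should also run under torchrun, which replaces
# torch.distributed.launch (an assumed equivalent, not verified here):
# CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 04_cifar10_gpu_distributeddataparallel.py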