Skip to content

Commit

Permalink
update lesson45
Browse files Browse the repository at this point in the history
  • Loading branch information
dragen1860 committed May 31, 2019
1 parent e2ba443 commit 39ae9ba
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 49 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions lesson45-Cifar10与ResNet18实战/.idea/encodings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion lesson45-Cifar10与ResNet18实战/.idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

61 changes: 28 additions & 33 deletions lesson45-Cifar10与ResNet18实战/.idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions lesson45-Cifar10与ResNet18实战/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@ def main():

cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]), download=True)
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]), download=True)
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

Expand Down
30 changes: 18 additions & 12 deletions lesson45-Cifar10与ResNet18实战/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,20 @@ def __init__(self):
super(ResNet18, self).__init__()

self.conv1 = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
nn.BatchNorm2d(16)
nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
nn.BatchNorm2d(64)
)
# followed 4 blocks
# [b, 64, h, w] => [b, 128, h ,w]
self.blk1 = ResBlk(16, 32, stride=2)
self.blk1 = ResBlk(64, 128, stride=2)
# [b, 128, h, w] => [b, 256, h, w]
self.blk2 = ResBlk(32, 64, stride=2)
self.blk2 = ResBlk(128, 256, stride=2)
# # [b, 256, h, w] => [b, 512, h, w]
self.blk3 = ResBlk(64, 128, stride=2)
self.blk3 = ResBlk(256, 512, stride=2)
# # [b, 512, h, w] => [b, 1024, h, w]
self.blk4 = ResBlk(128, 256, stride=2)
self.blk4 = ResBlk(512, 512, stride=2)

self.outlayer = nn.Linear(256*1*1, 10)
self.outlayer = nn.Linear(512*1*1, 10)

def forward(self, x):
"""
Expand All @@ -86,7 +86,11 @@ def forward(self, x):
x = self.blk3(x)
x = self.blk4(x)

# print(x.shape)

# print('after conv:', x.shape) #[b, 512, 2, 2]
# [b, 512, h, w] => [b, 512, 1, 1]
x = F.adaptive_avg_pool2d(x, [1, 1])
# print('after pool:', x.shape)
x = x.view(x.size(0), -1)
x = self.outlayer(x)

Expand All @@ -96,17 +100,19 @@ def forward(self, x):


def main():
blk = ResBlk(64, 128)

blk = ResBlk(64, 128, stride=4)
tmp = torch.randn(2, 64, 32, 32)
out = blk(tmp)
print('block:', out.shape)


x = torch.randn(2, 3, 32, 32)
model = ResNet18()
tmp = torch.randn(2, 3, 32, 32)
out = model(tmp)
out = model(x)
print('resnet:', out.shape)




if __name__ == '__main__':
main()

0 comments on commit 39ae9ba

Please sign in to comment.