-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_rl_attack_scratch.py
50 lines (40 loc) · 1.47 KB
/
train_rl_attack_scratch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torchvision.datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from generator_mask import Generator
from discriminator import Discriminator
from lp_pretrained_attack_func import PVRL_Attack
import resnet_model
# Run configuration.
use_cuda = True
epochs = 100
batch_size = 128
# Valid pixel range handed to the attack; transforms.ToTensor() below
# produces float images scaled to [0, 1].
BOX_MIN = 0
BOX_MAX = 1
num_classes = 10
# Location of the pretrained target-model checkpoint.
CHECKPOINT_PATH = "./save_temp/checkpoint.th"


def select_device(want_cuda):
    """Return the CUDA device if *want_cuda* is set and CUDA is available, else the CPU device."""
    return torch.device("cuda" if (want_cuda and torch.cuda.is_available()) else "cpu")


def main():
    """Load the target ResNet-32, build a fresh generator/discriminator and train the PVRL attack."""
    # Define what device we are using
    print("CUDA Available: ", torch.cuda.is_available())
    device = select_device(use_cuda)

    # Load target model. map_location lets a checkpoint saved on a GPU be
    # loaded on a CPU-only host, and the model is moved to `device` (not
    # unconditionally to CUDA) so the CPU fallback above actually works.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    # NOTE(review): the DataParallel wrapper presumably matches "module."-prefixed
    # keys in the saved state_dict — confirm against the training script.
    targeted_model = torch.nn.DataParallel(resnet_model.__dict__["resnet32"]())
    targeted_model.to(device)
    targeted_model.load_state_dict(checkpoint["state_dict"])
    targeted_model.eval()

    # The generator and discriminator are freshly initialised here (nothing is
    # loaded from disk); both are put in train mode because they are trained
    # from scratch.
    generator = Generator().to(device)
    generator.train()
    discriminator = Discriminator().to(device)
    discriminator.train()

    # CIFAR train dataset and dataloader declaration
    cifar_dataset = torchvision.datasets.CIFAR10(
        root="./data", train=True, transform=transforms.ToTensor(), download=True
    )
    dataloader = DataLoader(cifar_dataset, batch_size=batch_size, shuffle=True, num_workers=1)

    pvrl = PVRL_Attack(
        device=device,
        model=targeted_model,
        generator=generator,
        discriminator=discriminator,
        model_num_labels=num_classes,
        box_min=BOX_MIN,
        box_max=BOX_MAX,
    )
    pvrl.train(dataloader, epochs)


# Guard the entry point: with num_workers > 0 the DataLoader's worker processes
# re-import this module under the "spawn" start method (default on Windows and
# macOS), which would otherwise re-run the whole training script.
if __name__ == "__main__":
    main()