run.py
import logging

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from gnn.argparser import parse_arguments
from gnn.dataset import TrafficDataset
from gnn.models import P3D
from gnn.backlog.models import GCN, STGCN, SLGCN, GCRNN
from utils import load_adjacency_matrix, save_model_to_path, get_device

# P3D and the backlog models (GCN, STGCN, SLGCN, GCRNN) are imported so that
# the globals()[args.model] lookup below can resolve the model class by name.

logger = logging.getLogger(__name__)

MODEL_SAVE_PATH = "./saved_models/"

parser = parse_arguments()
args = parser.parse_args()
DEVICE = get_device(args.gpu)
def run_epoch(model, optimizer, dataloader, training=True):
    """Run one pass over the dataloader; return the mean denormalized MSE."""
    mu, std = dataloader.dataset.mu, dataloader.dataset.std
    mu = torch.tensor(mu, device=DEVICE)
    std = torch.tensor(std, device=DEVICE)
    bar = tqdm(dataloader)
    losses = []
    if training:
        model.train()
    else:
        model.eval()
    for sample_batched in bar:
        optimizer.zero_grad()
        x = sample_batched['features'].to(DEVICE).type(torch.float32)
        y = sample_batched['labels'].to(DEVICE).type(torch.float32)
        output = model(x)
        # Undo the dataset normalization so the reported metrics are in the
        # original units of the data.
        output_denormalized = output * std + mu
        y_denormalized = y * std + mu
        loss = F.mse_loss(output, y)  # optimized loss, on normalized values
        loss_mse = F.mse_loss(output_denormalized, y_denormalized)
        loss_mae = F.l1_loss(output_denormalized, y_denormalized)
        if training:
            loss.backward()
            optimizer.step()
        losses.append(loss_mse.item())
        bar.set_description('loss_mae: {:.1f}, loss_mse: {:.1f}'.format(
            loss_mae.item(), loss_mse.item()))
    return np.mean(losses)
if __name__ == "__main__":
    # load adjacency matrix
    adj = load_adjacency_matrix(args, DEVICE)

    # Model and optimizer: the model class is looked up by name via globals()
    model = globals()[args.model](adj, args).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    if args.log_file:
        logging.basicConfig(filename=args.log_file, level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO, format='# %(message)s')
    print(f"Training model {args.model_name}")
    logging.info(args)

    if args.mode == 'train':
        dataset_train = TrafficDataset(args, split='train')
        dataset_val = TrafficDataset(args, split='val')
        dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=1)
        dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, num_workers=1)

        # Training loop with best-model checkpointing on validation loss
        hist_loss = []
        best_loss = np.inf
        for epoch in range(args.n_epochs):
            ml_train = run_epoch(model, optimizer, dataloader_train)
            logger.info(f"epoch: {epoch}")
            logger.info('Mean train-loss over batch: {:.4f}'.format(ml_train))
            ml_val = run_epoch(model, optimizer, dataloader_val, training=False)
            logger.info('Mean validation-loss over batch: {:.4f}'.format(ml_val))
            hist_loss.append((ml_train, ml_val))

            # save the model if the validation loss is the lowest so far
            if ml_val < best_loss:
                best_loss = ml_val
                save_model_to_path(args, model)
                logger.info(f"Saved model to path on epoch {epoch}")
        np.save(f"./studies/losses/losses_on_{args.model_name}", hist_loss)

    if args.mode == 'test':
        dataset_test = TrafficDataset(args, split='test')
        dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, num_workers=1)
        model.load_state_dict(torch.load(MODEL_SAVE_PATH + args.model_name + '.pt'))
        # The optimizer is only needed to satisfy run_epoch's signature here;
        # no weights are updated when training=False.
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        logger.info('Iterate over the test split...')
        ml_test = run_epoch(model, optimizer, dataloader_test, training=False)
        logger.info('Mean loss over test dataset: {:.4f}'.format(ml_test))
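
Example invocation (a sketch, not taken from the repository): the flag names below are assumptions inferred from the args attributes referenced in run.py (model, model_name, mode, batch_size, lr, n_epochs, gpu); the actual interface is defined in gnn/argparser.py, which is not shown here.

    # Hypothetical flags -- verify against gnn.argparser.parse_arguments
    python run.py --model STGCN --model_name stgcn_run1 --mode train --n_epochs 50 --batch_size 32 --lr 0.001
    python run.py --model STGCN --model_name stgcn_run1 --mode test --batch_size 32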