-
Notifications
You must be signed in to change notification settings - Fork 51
/
Copy pathFCNVMB_test.py
121 lines (100 loc) · 4.61 KB
/
FCNVMB_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# -*- coding: utf-8 -*-
"""
Fully convolutional neural network (U-Net) for building velocity models
directly from prestack, unmigrated seismic data
Created on Feb 2018
@author: fangshuyang ([email protected])
"""
################################################
######## IMPORT LIBRARIES ########
################################################
from ParamConfig import *
from PathConfig import *
from LibConfig import *
################################################
######## LOAD NETWORK ########
################################################
# Select the compute device: use the GPU when CUDA is available, otherwise
# fall back to the CPU so the script also runs on machines without a GPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")

# The checkpoint path is built from the model name and the configured epoch count.
model_file = models_dir+modelname+'_epoch'+str(Epochs)+'.pkl'

net = UnetModel(n_classes=Nclasses, in_channels=Inchannels,
                is_deconv=True, is_batchnorm=True)

# map_location remaps the stored tensors onto the selected device. Without it,
# a checkpoint saved from a GPU cannot be loaded on a CPU-only machine.
net.load_state_dict(torch.load(model_file, map_location=device))

# Move the network to the same device the input batches are sent to.
net.to(device)
################################################
######## LOADING TESTING DATA ########
################################################
print('***************** Loading Testing DataSet *****************')

# Keyword arguments for the test-set loader, gathered in one place.
# NOTE(review): start=1601 is passed through unchanged; presumably the index
# of the first test sample -- confirm against DataLoad_Test.
test_load_args = dict(
    test_size=TestSize,
    test_data_dir=test_data_dir,
    data_dim=DataDim,
    in_channels=Inchannels,
    model_dim=ModelDim,
    data_dsp_blk=data_dsp_blk,
    label_dsp_blk=label_dsp_blk,
    start=1601,
    datafilename=datafilename,
    dataname=dataname,
    truthfilename=truthfilename,
    truthname=truthname,
)
test_set, label_set, data_dsp_dim, label_dsp_dim = DataLoad_Test(**test_load_args)

# Wrap the NumPy arrays as tensors and serve them in their stored order
# (shuffle=False), TestBatchSize samples at a time.
test = data_utils.TensorDataset(
    torch.from_numpy(test_set),
    torch.from_numpy(label_set),
)
test_loader = data_utils.DataLoader(test, batch_size=TestBatchSize, shuffle=False)
################################################
######## TESTING ########
################################################
# Print a banner (framed by blank lines) marking the start of the testing phase.
banner_rule = '*******************************************'
for banner_line in ('',
                    banner_rule,
                    banner_rule,
                    ' START TESTING ',
                    banner_rule,
                    banner_rule,
                    ''):
    print(banner_line)
# Initialization
since = time.time()
# Per-sample metrics, stored as a single row of TestSize entries.
TotPSNR = np.zeros((1,TestSize),dtype=float)
TotSSIM = np.zeros((1,TestSize),dtype=float)
# Predicted and ground-truth models for every test sample.
Prediction = np.zeros((TestSize,label_dsp_dim[0],label_dsp_dim[1]),dtype=float)
GT = np.zeros((TestSize,label_dsp_dim[0],label_dsp_dim[1]),dtype=float)
# Running index of the sample being evaluated across all batches.
total = 0

label_shape = (label_dsp_dim[0], label_dsp_dim[1])

# Switch to evaluation mode once, before the loop, instead of per batch.
net.eval()

# No gradients are needed for testing; disabling autograd avoids building the
# computation graph and lowers memory use.
with torch.no_grad():
    for images, labels in test_loader:
        # Use the actual size of this batch: the last batch is smaller than
        # TestBatchSize whenever TestSize is not a multiple of it.
        batch_len = images.size(0)
        images = images.view(batch_len, Inchannels, data_dsp_dim[0], data_dsp_dim[1])
        labels = labels.view(batch_len, Nclasses, label_dsp_dim[0], label_dsp_dim[1])
        images = images.to(device)

        # Predictions
        outputs = net(images, label_dsp_dim)
        outputs = outputs.view(batch_len, *label_shape)
        outputs = outputs.cpu().numpy()
        # Labels are only compared on the host, so they never need to be moved
        # to the device and back.
        gts = labels.numpy()

        # Calculate the PSNR, SSIM
        for k in range(batch_len):
            pred = turn(outputs[k].reshape(label_shape))
            truth = turn(gts[k].reshape(label_shape))
            # `total` is the global sample index, so it is also the storage slot.
            Prediction[total, :, :] = pred
            GT[total, :, :] = truth
            psnr = PSNR(pred, truth)
            TotPSNR[0, total] = psnr
            # SSIM is called with 4-D arrays, hence the (-1, 1, H, W) reshape.
            ssim = SSIM(pred.reshape(-1, 1, label_shape[0], label_shape[1]),
                        truth.reshape(-1, 1, label_shape[0], label_shape[1]))
            TotSSIM[0, total] = ssim
            print('The %d testing psnr: %.2f, SSIM: %.4f ' % (total,psnr,ssim))
            total += 1
# Persist the per-sample metrics together with all predictions and ground truths.
SaveTestResults(TotPSNR, TotSSIM, Prediction, GT, results_dir)

# Plot a single prediction next to its ground truth (sample index `num`).
num = 0
# NOTE(review): minvalue/maxvalue are handed to PlotComparison; presumably the
# value range used for display -- confirm against PlotComparison.
minvalue = 2000 if SimulateData else 1500
maxvalue = 4500

# The two font specifications differ only in size.
base_font = {'family': 'Times New Roman', 'weight': 'normal'}
font2 = dict(base_font, size=17)
font3 = dict(base_font, size=21)

PlotComparison(
    Prediction[num, :, :],
    GT[num, :, :],
    label_dsp_dim,
    label_dsp_blk,
    dh,
    minvalue,
    maxvalue,
    font2,
    font3,
    SavePath=results_dir,
)
# Report the wall-clock time spent since `since` was recorded.
time_elapsed = time.time() - since
# Round to whole seconds BEFORE splitting into minutes and seconds. Formatting
# `time_elapsed % 60` with {:.0f} could round up to 60 and print e.g. "1m 60s".
elapsed_min, elapsed_sec = divmod(int(round(time_elapsed)), 60)
print('Testing complete in {:.0f}m {:.0f}s' .format(elapsed_min, elapsed_sec))