-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict.py
110 lines (95 loc) · 4.79 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import argparse
import numpy as np
import torch
from scipy import ndimage
from congestion_model import CongestionModel
import matplotlib.pyplot as plt
import pandas as pd
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class CongestionPrediction:
    """Run a pretrained ``CongestionModel`` on feature maps loaded from disk.

    On construction, one ``.npy``-style array is loaded per feature directory,
    resized to 256x256, min-max normalized, and stacked into a single-sample
    batch. Model weights are then loaded from a state-dict checkpoint.
    """

    def __init__(self, datapath, features, model_weight_path, device):
        """Load the input features and the model weights.

        Args:
            datapath: Root directory containing one sub-directory per feature.
            features: Names of the feature sub-directories. Their order is the
                channel order of the stacked input tensor.
            model_weight_path: Path to a state dict for ``CongestionModel``.
            device: Torch device string or ``torch.device`` (e.g. ``"cpu"``,
                ``"cuda"``, ``"cuda:0"``).
        """
        super().__init__()
        self.datapath = datapath
        self.FeaturePathList = features
        self.device = device
        # Shape (1, num_features, 256, 256): add a batch dimension.
        self.feature = self.data_process(self.FeaturePathList).unsqueeze(0).to(device)
        self.model = CongestionModel(device).to(device)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        checkpoint = torch.load(model_weight_path, map_location=device)
        self.model.load_state_dict(checkpoint)
        self.model.eval()

    def resize(self, input, size=256):
        """Resample a 2-D array to ``size`` x ``size`` using cubic spline zoom.

        Args:
            input: 2-D array-like accepted by ``scipy.ndimage.zoom``.
            size: Target side length (default 256, the model's input size).
        """
        rows, cols = input.shape[0], input.shape[1]
        return ndimage.zoom(input, (size / rows, size / cols), order=3)

    def std(self, input):
        """Min-max normalize ``input`` to [0, 1] (despite the name, not a z-score).

        Works for NumPy arrays and torch tensors.
        - An empty input, or one whose maximum is 0, is returned unchanged.
        - A constant non-zero input maps to all zeros instead of producing NaN
          from a 0/0 division.
        """
        if 0 in input.shape or input.max() == 0:
            return input
        low, high = input.min(), input.max()
        if high == low:
            # Degenerate range: avoid 0/0 -> NaN.
            return input - low
        return (input - low) / (high - low)

    def data_process(self, FeaturePathList):
        """Load, resize and normalize one array per feature directory.

        Each directory ``<datapath>/<feature_name>`` is listed. The first
        entry in sorted order is loaded with ``np.load``.

        Returns:
            float32 tensor of shape (len(FeaturePathList), 256, 256).

        Raises:
            FileNotFoundError: if a feature directory is empty.
        """
        features = []
        for feature_name in FeaturePathList:
            feature_dir = os.path.join(self.datapath, feature_name)
            # os.listdir order is arbitrary; sort so every run picks the same file.
            names = sorted(os.listdir(feature_dir))
            if not names:
                raise FileNotFoundError(f"No feature file found in {feature_dir}")
            feature = np.load(os.path.join(feature_dir, names[0]))
            features.append(torch.as_tensor(self.std(self.resize(feature))))
        return torch.stack(features).type(torch.float32)

    def find_congestion_coord_and_value(self, tensor, threshold):
        """Collect every cell of ``tensor`` whose value is >= ``threshold``.

        ``tensor`` is indexed as (row, col). The column index is reported as x
        and the row index as y.

        Returns:
            float ndarray of shape (k, 3) with rows ``[x, y, value]``. The
            values are min-max normalized among the selected cells. When no
            cell passes the threshold, the shape is (0, 3).
        """
        indices = torch.where(tensor >= threshold)
        if indices[0].numel() == 0:
            return np.empty((0, 3))
        values = self.std(tensor[indices])
        return np.array([indices[1].tolist(), indices[0].tolist(), values.tolist()]).T

    def Prediction(self, congestion_threshold):
        """Run the model and extract cells at or above ``congestion_threshold``.

        Returns:
            ``(pred, pred_coord)``.
            - ``pred``: the sigmoid output tensor. Presumably shaped
              (1, 1, H, W), given the ``pred[0, 0]`` indexing in ``save``.
            - ``pred_coord``: a DataFrame with columns ``x``, ``y`` and
              ``congestion``. It is empty when no cell passes the threshold.
        """
        self.congestion_threshold = congestion_threshold
        # Mixed precision only applies on CUDA; works for str or torch.device.
        use_amp = torch.device(self.device).type == 'cuda'
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            if use_amp:
                with torch.cuda.amp.autocast():
                    self.pred = self.model.sigmoid(self.model(self.feature))
            else:
                self.pred = self.model.sigmoid(self.model(self.feature))
        coords = self.find_congestion_coord_and_value(self.pred.squeeze(), threshold=congestion_threshold)
        self.pred_coord = pd.DataFrame(coords, columns=['x', 'y', 'congestion'])
        return self.pred, self.pred_coord

    def ShowFig(self, fig_save_path):
        """Plot the first input feature with congested cells overlaid.

        The figure is saved as ``congestion_<threshold>.png`` in
        ``fig_save_path`` (the directory is created if needed) and then shown.
        Requires ``Prediction`` to have been called first.

        Raises:
            ValueError: if ``fig_save_path`` is None.
        """
        if fig_save_path is None:
            raise ValueError("Figure save path is not specified clear.")
        os.makedirs(fig_save_path, exist_ok=True)
        # Start a fresh figure so repeated calls don't stack plots/colorbars.
        plt.figure()
        plt.imshow(self.feature[0, 0].detach().cpu().numpy())
        plt.title(f"Congestion > {self.congestion_threshold}")
        pts = plt.scatter(x=self.pred_coord['x'], y=self.pred_coord['y'],
                          c=self.pred_coord['congestion'], cmap='jet', s=5)
        plt.colorbar(pts)
        plt.savefig(os.path.join(fig_save_path, f"congestion_{self.congestion_threshold}.png"))
        plt.show()

    def save(self, output_path):
        """Write ``PredArray.npy`` and ``PredCoord.csv`` into ``output_path``.

        The directory is created if it does not exist. Requires ``Prediction``
        to have been called first.
        """
        os.makedirs(output_path, exist_ok=True)
        np.save(os.path.join(output_path, "PredArray"), self.pred[0, 0].detach().cpu().numpy())
        self.pred_coord.to_csv(os.path.join(output_path, "PredCoord.csv"))
def parse_args(argv=None):
    """Parse the command-line options for the congestion prediction script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``data_path``, ``fig_save_path``,
        ``weight_path``, ``output_path``, ``congestion_threshold`` and
        ``device``.
    """
    parser = argparse.ArgumentParser(description="Input the Path for Prediction")
    parser.add_argument("--data_path", default="./data", type=str,
                        help='The path of the data file')
    parser.add_argument("--fig_save_path", default="./save_img", type=str,
                        help='The directory to save the figure in')
    parser.add_argument("--weight_path", default="./model_weight/congestion2_weights.pt", type=str,
                        help='The path of the model weight')
    parser.add_argument("--output_path", default="./output", type=str,
                        help='The directory to save the prediction outputs in')
    parser.add_argument("--congestion_threshold", default=0.2, type=float,
                        help='congestion_threshold [0,1]')
    parser.add_argument("--device", default='cpu', type=str,
                        help='If you have gpu type "cuda" will be faster!!')
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: predict congestion, print the hot cells, save the
    # outputs, and optionally plot them.
    import time

    # perf_counter is monotonic and high-resolution, suited to interval timing.
    start = time.perf_counter()
    # Feature sub-directory names under --data_path, in model channel order.
    feature_list = ['macro_region', 'RUDY', 'RUDY_pin']
    args = parse_args()
    predictionSystem = CongestionPrediction(datapath=args.data_path, features=feature_list,
                                            model_weight_path=args.weight_path, device=args.device)
    pred, pred_coord = predictionSystem.Prediction(congestion_threshold=args.congestion_threshold)
    print("-------------congestion point------------------")
    print(pred_coord)
    print("-----------------------------------------------")
    end = time.perf_counter()
    print(f"cost time:{end - start:f} sec")
    predictionSystem.save(args.output_path)
    if args.fig_save_path is not None:
        predictionSystem.ShowFig(args.fig_save_path)