-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Topic/tutorial (#147) * complete tutorial 2 * complete tutorial 3, restructure main readme * lint * update commands * Adding prostests analyser * Fixing requirements.txt and useless imports * Linted core and test scripts * Fixed requirements syntax * Fixed syntax errors in .yaml, moved utils.py into core. * Added absolute path for model * Return predictions as float instead of float32 * Linted core * Compatibility with imagededup * Automatic download of the model * Deleted the now invalid api key * Update README.md * Update README.md * Update README.md * Update README.md * Adding prostests analyser * Fixing requirements.txt and useless imports * Linted core and test scripts * Fixed requirements syntax * Fixed syntax errors in .yaml, moved utils.py into core. * Added absolute path for model * Return predictions as float instead of float32 * Linted core * Compatibility with imagededup * Automatic download of the model * Deleted the now invalid api key * fix import from commit * rm empty spaces * black fmt fix * black fmt fix 2 Co-authored-by: Lachlan Kermode <[email protected]>
- Loading branch information
1 parent
f591abc
commit 94b4ce1
Showing
9 changed files
with
188 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
GOOGLE_API_KEY=my_key_from_gcp_project | ||
GOOGLE_API_KEY= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import numpy as np | ||
import sys | ||
import json | ||
import os | ||
import torch | ||
from torch.autograd import Variable | ||
from PIL import Image | ||
|
||
from lib.common.analyser import Analyser | ||
from lib.common.etypes import Etype, Union, Array | ||
from lib.analysers.ProtestsPretrained.utils import transform, modified_resnet50, decode | ||
|
||
PTH_TAR = "/mtriage/model.pth.tar" | ||
|
||
# TODO cuda ? | ||
|
||
|
||
class ProtestsPretrained(Analyser): | ||
def pre_analyse(self, config): | ||
""" | ||
Init the logging, etc | ||
Init the model | ||
""" | ||
rLabels = config["labels"] | ||
self.THRESH = 0.0 | ||
|
||
t = transform() | ||
model = modified_resnet50() | ||
model.load_state_dict( | ||
torch.load( | ||
PTH_TAR, | ||
map_location=torch.device("cpu"), | ||
)["state_dict"] | ||
) | ||
model.eval() | ||
|
||
def get_preds(img_path): | ||
""" | ||
Gives labelds and probabilities for a single image | ||
This is were we preprocess the image, using a function defined in the model class | ||
""" | ||
# load image | ||
img = Image.open(img_path).convert("RGB") | ||
# process it | ||
x = t(img) | ||
# get in in the right format | ||
x = Variable(x).unsqueeze(0) | ||
# predictions | ||
output = model(x) | ||
# decode | ||
output = decode(output.cpu().data.numpy()[0]) | ||
# filter | ||
output = [(x[0], x[1]) for x in output if x[0] in rLabels] | ||
output = [(x[0], float(x[1])) for x in output if x[1] >= self.THRESH] | ||
|
||
return output | ||
|
||
self.get_preds = get_preds | ||
|
||
def analyse_element( | ||
self, element: Union(Array(Etype.Image), Etype.Json), _ | ||
) -> Etype.Json: | ||
self.logger(f"Running inference on frames in {element.id}...") | ||
val = Etype.CvJson.from_preds(element, self.get_preds) | ||
self.logger(f"Wrote predictions JSON for {element.id}.") | ||
self.disk.delete_local_on_write = True | ||
return val | ||
|
||
|
||
module = ProtestsPretrained |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
desc: Classify the presence of protests and violence in images. | ||
args: | ||
- name: labels | ||
desc: Filter results to a limited array of labels. | ||
required: true | ||
input: whitelist |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
RUN wget -O /mtriage/model.pth.tar https://www.dropbox.com/s/vgh2nwxrzembxpw/model.pth.tar?dl=0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
torch==1.4 | ||
torchvision==0.5 | ||
pillow==6.1.0 | ||
numpy<1.17 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import numpy as np | ||
import os | ||
import torch | ||
from torch.autograd import Variable | ||
from PIL import Image | ||
from utils import transform, modified_resnet50, decode | ||
|
||
|
||
def pre_analyse(): | ||
""" | ||
Init the logging, etc | ||
Init the model | ||
Same as KerasPretrained | ||
""" | ||
t = transform() | ||
model = modified_resnet50() | ||
model.load_state_dict( | ||
torch.load( | ||
"model.pth.tar", | ||
map_location=torch.device("cpu"), | ||
)["state_dict"] | ||
) | ||
model.eval() | ||
|
||
def get_preds(img_path): | ||
""" | ||
Gives labelds and probabilities for a single image | ||
This is were we preprocess the image, using a function defined in the model class | ||
""" | ||
# load image | ||
img = Image.open(img_path).convert("RGB") | ||
# process it | ||
x = t(img) | ||
# get in in the right format | ||
x = Variable(x).unsqueeze(0) | ||
# predictions | ||
output = model(x) | ||
# decode | ||
output = decode(output.cpu().data.numpy()[0]) | ||
|
||
# filter | ||
# return pred, proba | ||
return output | ||
|
||
return get_preds("image.jpg") | ||
|
||
|
||
print(pre_analyse()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
""" | ||
created by: Donghyeon Won | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
import torchvision.transforms as transforms | ||
import torchvision.models as models | ||
|
||
|
||
def transform(): | ||
return transforms.Compose( | ||
[ | ||
transforms.Resize(256), | ||
transforms.CenterCrop(224), | ||
transforms.ToTensor(), | ||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), | ||
] | ||
) | ||
|
||
|
||
def decode(preds): | ||
classes = [ | ||
"protest", | ||
"violence", | ||
"sign", | ||
"photo", | ||
"fire", | ||
"police", | ||
"children", | ||
"group_20", | ||
"group_100", | ||
"flag", | ||
"night", | ||
"shouting", | ||
] | ||
return [(x, preds[c]) for c, x in enumerate(classes)] | ||
|
||
|
||
class FinalLayer(nn.Module): | ||
"""modified last layer for resnet50 for our dataset""" | ||
|
||
def __init__(self): | ||
super(FinalLayer, self).__init__() | ||
self.fc = nn.Linear(2048, 12) | ||
self.sigmoid = nn.Sigmoid() | ||
|
||
def forward(self, x): | ||
out = self.fc(x) | ||
out = self.sigmoid(out) | ||
return out | ||
|
||
|
||
def modified_resnet50(): | ||
model = models.resnet50(pretrained=True) | ||
model.fc = FinalLayer() | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters