Skip to content

Commit

Permalink
Adding protests analyser (#156)
Browse files Browse the repository at this point in the history
* Topic/tutorial (#147)

* complete tutorial 2

* complete tutorial 3, restructure main readme

* lint

* update commands

* Adding prostests analyser

* Fixing requirements.txt and useless imports

* Linted core and test scripts

* Fixed requirements syntax

* Fixed syntax errors in .yaml, moved utils.py into core.

* Added absolute path for model

* Return predictions as float instead of float32

* Linted core

* Compatibility with imagededup

* Automatic download of the model

* Deleted the now invalid api key

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Adding prostests analyser

* Fixing requirements.txt and useless imports

* Linted core and test scripts

* Fixed requirements syntax

* Fixed syntax errors in .yaml, moved utils.py into core.

* Added absolute path for model

* Return predictions as float instead of float32

* Linted core

* Compatibility with imagededup

* Automatic download of the model

* Deleted the now invalid api key

* fix import from commit

* rm empty spaces

* black fmt fix

* black fmt fix 2

Co-authored-by: Lachlan Kermode <[email protected]>
  • Loading branch information
Smoltbob and breezykermo authored Nov 22, 2020
1 parent f591abc commit 94b4ce1
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1 +1 @@
GOOGLE_API_KEY=my_key_from_gcp_project
GOOGLE_API_KEY=
70 changes: 70 additions & 0 deletions src/lib/analysers/ProtestsPretrained/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np
import sys
import json
import os
import torch
from torch.autograd import Variable
from PIL import Image

from lib.common.analyser import Analyser
from lib.common.etypes import Etype, Union, Array
from lib.analysers.ProtestsPretrained.utils import transform, modified_resnet50, decode

PTH_TAR = "/mtriage/model.pth.tar"

# TODO cuda ?


class ProtestsPretrained(Analyser):
def pre_analyse(self, config):
"""
Init the logging, etc
Init the model
"""
rLabels = config["labels"]
self.THRESH = 0.0

t = transform()
model = modified_resnet50()
model.load_state_dict(
torch.load(
PTH_TAR,
map_location=torch.device("cpu"),
)["state_dict"]
)
model.eval()

def get_preds(img_path):
"""
Gives labelds and probabilities for a single image
This is were we preprocess the image, using a function defined in the model class
"""
# load image
img = Image.open(img_path).convert("RGB")
# process it
x = t(img)
# get in in the right format
x = Variable(x).unsqueeze(0)
# predictions
output = model(x)
# decode
output = decode(output.cpu().data.numpy()[0])
# filter
output = [(x[0], x[1]) for x in output if x[0] in rLabels]
output = [(x[0], float(x[1])) for x in output if x[1] >= self.THRESH]

return output

self.get_preds = get_preds

def analyse_element(
self, element: Union(Array(Etype.Image), Etype.Json), _
) -> Etype.Json:
self.logger(f"Running inference on frames in {element.id}...")
val = Etype.CvJson.from_preds(element, self.get_preds)
self.logger(f"Wrote predictions JSON for {element.id}.")
self.disk.delete_local_on_write = True
return val


module = ProtestsPretrained
Binary file added src/lib/analysers/ProtestsPretrained/image.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 6 additions & 0 deletions src/lib/analysers/ProtestsPretrained/info.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
desc: Classify the presence of protests and violence in images.
args:
- name: labels
desc: Filter results to a limited array of labels.
required: true
input: whitelist
1 change: 1 addition & 0 deletions src/lib/analysers/ProtestsPretrained/partial.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
RUN wget -O /mtriage/model.pth.tar https://www.dropbox.com/s/vgh2nwxrzembxpw/model.pth.tar?dl=0
4 changes: 4 additions & 0 deletions src/lib/analysers/ProtestsPretrained/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
torch==1.4
torchvision==0.5
pillow==6.1.0
numpy<1.17
48 changes: 48 additions & 0 deletions src/lib/analysers/ProtestsPretrained/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np
import os
import torch
from torch.autograd import Variable
from PIL import Image
from utils import transform, modified_resnet50, decode


def pre_analyse():
"""
Init the logging, etc
Init the model
Same as KerasPretrained
"""
t = transform()
model = modified_resnet50()
model.load_state_dict(
torch.load(
"model.pth.tar",
map_location=torch.device("cpu"),
)["state_dict"]
)
model.eval()

def get_preds(img_path):
"""
Gives labelds and probabilities for a single image
This is were we preprocess the image, using a function defined in the model class
"""
# load image
img = Image.open(img_path).convert("RGB")
# process it
x = t(img)
# get in in the right format
x = Variable(x).unsqueeze(0)
# predictions
output = model(x)
# decode
output = decode(output.cpu().data.numpy()[0])

# filter
# return pred, proba
return output

return get_preds("image.jpg")


print(pre_analyse())
56 changes: 56 additions & 0 deletions src/lib/analysers/ProtestsPretrained/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
created by: Donghyeon Won
"""
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models


def transform():
return transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)


def decode(preds):
classes = [
"protest",
"violence",
"sign",
"photo",
"fire",
"police",
"children",
"group_20",
"group_100",
"flag",
"night",
"shouting",
]
return [(x, preds[c]) for c, x in enumerate(classes)]


class FinalLayer(nn.Module):
"""modified last layer for resnet50 for our dataset"""

def __init__(self):
super(FinalLayer, self).__init__()
self.fc = nn.Linear(2048, 12)
self.sigmoid = nn.Sigmoid()

def forward(self, x):
out = self.fc(x)
out = self.sigmoid(out)
return out


def modified_resnet50():
model = models.resnet50(pretrained=True)
model.fc = FinalLayer()
return model
2 changes: 2 additions & 0 deletions src/lib/common/analyser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import glob
import os
import shutil
import traceback
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Generator, List, Union, Tuple
Expand Down Expand Up @@ -196,3 +197,4 @@ def __attempt_analyse(self, attempts, element):
raise e
else:
self.error_logger(f"{str(e)}: skipping element", element)
print(traceback.format_exc())

0 comments on commit 94b4ce1

Please sign in to comment.