-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathhelper_functions.py
93 lines (74 loc) · 3.04 KB
/
helper_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import cv2
import numpy as np
import tensorflow as tf
import config
CLASSES = config.CLASSES
# Define a list of colors for visualization
COLORS = np.random.randint(0, 255, size=(len(CLASSES), 3), dtype=np.uint8)
def preprocess_image(image_path, input_size):
''' Preprocess the input image to feed to the TFLite model
'''
img = tf.io.read_file(image_path)
img = tf.io.decode_image(img, channels=3)
img = tf.image.convert_image_dtype(img, tf.uint8)
original_image = img
resized_img = tf.image.resize(img, input_size)
resized_img = resized_img[tf.newaxis, :]
resized_img = tf.cast(resized_img, dtype=tf.uint8)
return resized_img, original_image
def detect_objects(interpreter, image, threshold):
''' Returns a list of detection results, each a dictionary of object info.
'''
signature_fn = interpreter.get_signature_runner()
# Feed the input image to the model
output = signature_fn(images=image)
# Get all outputs from the model
count = int(np.squeeze(output['output_0']))
scores = np.squeeze(output['output_1'])
classes = np.squeeze(output['output_2'])
boxes = np.squeeze(output['output_3'])
results = []
for i in range(count):
if scores[i] >= threshold:
result = {
'bounding_box': boxes[i],
'class_id': classes[i],
'score': scores[i]
}
results.append(result)
return results
def run_odt_and_draw_results(image_path, interpreter, threshold=0.5):
''' Run object detection on the input image and draw the detection results
'''
# Load the input shape required by the model
_, input_height, input_width, _ = interpreter.get_input_details()[0]['shape']
# Load the input image and preprocess it
preprocessed_image, original_image = preprocess_image(
image_path,
(input_height, input_width)
)
# Run object detection on the input image
results = detect_objects(interpreter, preprocessed_image, threshold=threshold)
# Plot the detection results on the input image
original_image_np = original_image.numpy().astype(np.uint8)
for obj in results:
# Convert the object bounding box from relative coordinates to absolute
# coordinates based on the original image resolution
ymin, xmin, ymax, xmax = obj['bounding_box']
xmin = int(xmin * original_image_np.shape[1])
xmax = int(xmax * original_image_np.shape[1])
ymin = int(ymin * original_image_np.shape[0])
ymax = int(ymax * original_image_np.shape[0])
# Find the class index of the current object
class_id = int(obj['class_id'])
# Draw the bounding box and label on the image
color = [int(c) for c in COLORS[class_id]]
cv2.rectangle(original_image_np, (xmin, ymin), (xmax, ymax), color, 2)
# Make adjustments to make the label visible for all objects
y = ymin - 15 if ymin - 15 > 15 else ymin + 15
label = "{}: {:.0f}%".format(CLASSES[class_id], obj['score'] * 100)
cv2.putText(original_image_np, label, (xmin, y),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
# Return the final image
original_uint8 = original_image_np.astype(np.uint8)
return original_uint8