From 14795a6626f97ec4ac4d560e121f70afb164b8d5 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Sat, 8 Apr 2023 17:31:52 +0200 Subject: [PATCH 01/18] =?UTF-8?q?=F0=9F=8F=9E=20updated=20plot=5Fimage=20a?= =?UTF-8?q?nd=20plot=5Fimages=5Fgrid?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/notebook/utils.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/supervision/notebook/utils.py b/supervision/notebook/utils.py index 8d7c79291..789eee12c 100644 --- a/supervision/notebook/utils.py +++ b/supervision/notebook/utils.py @@ -6,7 +6,9 @@ def plot_image( - image: np.ndarray, size: Tuple[int, int] = (10, 10), cmap: Optional[str] = "gray" + image: np.ndarray, + size: Tuple[int, int] = (12, 12), + cmap: Optional[str] = "gray" ) -> None: """ Plots image using matplotlib. @@ -27,12 +29,14 @@ def plot_image( >>> sv.plot_image(image, (16, 16)) ``` """ + plt.figure(figsize=size) + if image.ndim == 2: - plt.figure(figsize=size) plt.imshow(image, cmap=cmap) else: - plt.figure(figsize=size) plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + plt.axis('off') plt.show() @@ -41,6 +45,7 @@ def plot_images_grid( grid_size: Tuple[int, int], titles: Optional[List[str]] = None, size: Tuple[int, int] = (12, 12), + cmap: Optional[str] = "gray" ) -> None: """ Plots images in a grid using matplotlib. @@ -50,6 +55,7 @@ def plot_images_grid( grid_size (Tuple[int, int]): A tuple specifying the number of rows and columns for the grid. titles (Optional[List[str]]): A list of titles for each image. Defaults to None. size (Tuple[int, int]): A tuple specifying the width and height of the entire plot in inches. + cmap (str): the colormap to use for single channel images. Raises: ValueError: If the number of images exceeds the grid size. 
@@ -70,7 +76,6 @@ def plot_images_grid( >>> plot_images_grid(images, grid_size=(2, 2), titles=titles, figsize=(16, 16)) ``` """ - nrows, ncols = grid_size if len(images) > nrows * ncols: @@ -82,11 +87,13 @@ def plot_images_grid( for idx, ax in enumerate(axes.flat): if idx < len(images): - ax.imshow(cv2.cvtColor(images[idx], cv2.COLOR_BGR2RGB)) + if images[idx].ndim == 2: + ax.imshow(images[idx], cmap=cmap) + else: + ax.imshow(cv2.cvtColor(images[idx], cv2.COLOR_BGR2RGB)) + if titles is not None and idx < len(titles): ax.set_title(titles[idx]) - ax.axis("off") - else: - ax.axis("off") + ax.axis("off") plt.show() From d5d2035048c9bf17637629692b94a74aa48472f7 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Sat, 8 Apr 2023 19:01:48 +0200 Subject: [PATCH 02/18] =?UTF-8?q?=F0=9F=98=B7add=20mask=20support=20to=20D?= =?UTF-8?q?etection=20object?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/detection/annotate.py | 5 +- supervision/detection/core.py | 82 ++++++++++++++++++++++--------- 2 files changed, 61 insertions(+), 26 deletions(-) diff --git a/supervision/detection/annotate.py b/supervision/detection/annotate.py index 3c74653d0..9c49b8244 100644 --- a/supervision/detection/annotate.py +++ b/supervision/detection/annotate.py @@ -56,8 +56,9 @@ def annotate( np.ndarray: The image with the bounding boxes drawn on it """ font = cv2.FONT_HERSHEY_SIMPLEX - for i, (xyxy, confidence, class_id, tracker_id) in enumerate(detections): - x1, y1, x2, y2 = xyxy.astype(int) + for i in range(len(detections)): + x1, y1, x2, y2 = detections.xyxy[i].astype(int) + class_id = detections.class_id[i] if detections.class_id is not None else None idx = class_id if class_id is not None else i color = ( self.color.by_idx(idx) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 83fd76ef2..08bd88673 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -1,7 +1,7 @@ from __future__ 
import annotations from dataclasses import dataclass -from typing import Iterator, List, Optional, Tuple, Union +from typing import Iterator, List, Optional, Tuple, Union, Any import numpy as np @@ -9,6 +9,48 @@ from supervision.geometry.core import Position +def _validate_xyxy(xyxy: Any, n: int) -> None: + is_valid = isinstance(xyxy, np.ndarray) and xyxy.shape == (n, 4) + if not is_valid: + raise ValueError("xyxy must be 2d np.ndarray with (n, 4) shape") + + +def _validate_mask(mask: Any, n: int) -> None: + is_valid = mask is None or ( + isinstance(mask, np.ndarray) + and len(mask.shape) == 3 and mask[0] == n + ) + if not is_valid: + raise ValueError("mask must be 3d np.ndarray with (n, W, H) shape") + + +def _validate_class_id(class_id: Any, n: int) -> None: + is_valid = class_id is None or ( + isinstance(class_id, np.ndarray) + and class_id.shape == (n,) + ) + if not is_valid: + raise ValueError("class_id must be None or 1d np.ndarray with (n,) shape") + + +def _validate_confidence(confidence: Any, n: int) -> None: + is_valid = confidence is None or ( + isinstance(confidence, np.ndarray) + and confidence.shape == (n,) + ) + if not is_valid: + raise ValueError("confidence must be None or 1d np.ndarray with (n,) shape") + + +def _validate_tracker_id(tracker_id: Any, n: int) -> None: + is_valid = tracker_id is None or ( + isinstance(tracker_id, np.ndarray) + and tracker_id.shape == (n,) + ) + if not is_valid: + raise ValueError("tracker_id must be None or 1d np.ndarray with (n,) shape") + + @dataclass class Detections: """ @@ -16,40 +58,25 @@ class Detections: Attributes: xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes coordinates in format `[x1, y1, x2, y2]` + mask: (Optional[np.ndarray]): An array of shape `(n, W, H)` containing the segmentation masks. class_id (Optional[np.ndarray]): An array of shape `(n,)` containing the class ids of the detections. 
confidence (Optional[np.ndarray]): An array of shape `(n,)` containing the confidence scores of the detections. tracker_id (Optional[np.ndarray]): An array of shape `(n,)` containing the tracker ids of the detections. """ xyxy: np.ndarray + mask: np.Optional[np.ndarray] = None class_id: Optional[np.ndarray] = None confidence: Optional[np.ndarray] = None tracker_id: Optional[np.ndarray] = None def __post_init__(self): n = len(self.xyxy) - validators = [ - (isinstance(self.xyxy, np.ndarray) and self.xyxy.shape == (n, 4)), - self.class_id is None - or (isinstance(self.class_id, np.ndarray) and self.class_id.shape == (n,)), - self.confidence is None - or ( - isinstance(self.confidence, np.ndarray) - and self.confidence.shape == (n,) - ), - self.tracker_id is None - or ( - isinstance(self.tracker_id, np.ndarray) - and self.tracker_id.shape == (n,) - ), - ] - if not all(validators): - raise ValueError( - "xyxy must be 2d np.ndarray with (n, 4) shape, " - "class_id must be None or 1d np.ndarray with (n,) shape, " - "confidence must be None or 1d np.ndarray with (n,) shape, " - "tracker_id must be None or 1d np.ndarray with (n,) shape" - ) + _validate_xyxy(xyxy=self.xyxy, n=n) + _validate_mask(mask=self.mask, n=n) + _validate_class_id(class_id=self.class_id, n=n) + _validate_confidence(confidence=self.confidence, n=n) + _validate_tracker_id(tracker_id=self.tracker_id, n=n) def __len__(self): """ @@ -59,13 +86,14 @@ def __len__(self): def __iter__( self, - ) -> Iterator[Tuple[np.ndarray, Optional[float], int, Optional[Union[str, int]]]]: + ) -> Iterator[Tuple[np.ndarray, Optional[np.ndarray], Optional[float], Optional[int], Optional[int]]]: """ Iterates over the Detections object and yield a tuple of `(xyxy, confidence, class_id, tracker_id)` for each detection. 
""" for i in range(len(self.xyxy)): yield ( self.xyxy[i], + self.mask[i] if self.mask is not None else None, self.confidence[i] if self.confidence is not None else None, self.class_id[i] if self.class_id is not None else None, self.tracker_id[i] if self.tracker_id is not None else None, @@ -75,6 +103,12 @@ def __eq__(self, other: Detections): return all( [ np.array_equal(self.xyxy, other.xyxy), + any( + [ + self.mask is None and other.mask is None, + np.array_equal(self.mask, other.mask), + ] + ), any( [ self.class_id is None and other.class_id is None, From bcc4b25eb91e7e835d42ca465f655a05abeb5fc0 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Sat, 8 Apr 2023 19:05:07 +0200 Subject: [PATCH 03/18] =?UTF-8?q?=F0=9F=96=A4bake=20black=20happy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/detection/annotate.py | 4 +++- supervision/detection/core.py | 24 ++++++++++++++---------- supervision/notebook/utils.py | 8 +++----- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/supervision/detection/annotate.py b/supervision/detection/annotate.py index 9c49b8244..070a5d660 100644 --- a/supervision/detection/annotate.py +++ b/supervision/detection/annotate.py @@ -58,7 +58,9 @@ def annotate( font = cv2.FONT_HERSHEY_SIMPLEX for i in range(len(detections)): x1, y1, x2, y2 = detections.xyxy[i].astype(int) - class_id = detections.class_id[i] if detections.class_id is not None else None + class_id = ( + detections.class_id[i] if detections.class_id is not None else None + ) idx = class_id if class_id is not None else i color = ( self.color.by_idx(idx) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 08bd88673..40a391c65 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Iterator, List, Optional, Tuple, Union, Any +from typing import Any, Iterator, 
List, Optional, Tuple, Union import numpy as np @@ -17,8 +17,7 @@ def _validate_xyxy(xyxy: Any, n: int) -> None: def _validate_mask(mask: Any, n: int) -> None: is_valid = mask is None or ( - isinstance(mask, np.ndarray) - and len(mask.shape) == 3 and mask[0] == n + isinstance(mask, np.ndarray) and len(mask.shape) == 3 and mask[0] == n ) if not is_valid: raise ValueError("mask must be 3d np.ndarray with (n, W, H) shape") @@ -26,8 +25,7 @@ def _validate_mask(mask: Any, n: int) -> None: def _validate_class_id(class_id: Any, n: int) -> None: is_valid = class_id is None or ( - isinstance(class_id, np.ndarray) - and class_id.shape == (n,) + isinstance(class_id, np.ndarray) and class_id.shape == (n,) ) if not is_valid: raise ValueError("class_id must be None or 1d np.ndarray with (n,) shape") @@ -35,8 +33,7 @@ def _validate_class_id(class_id: Any, n: int) -> None: def _validate_confidence(confidence: Any, n: int) -> None: is_valid = confidence is None or ( - isinstance(confidence, np.ndarray) - and confidence.shape == (n,) + isinstance(confidence, np.ndarray) and confidence.shape == (n,) ) if not is_valid: raise ValueError("confidence must be None or 1d np.ndarray with (n,) shape") @@ -44,8 +41,7 @@ def _validate_confidence(confidence: Any, n: int) -> None: def _validate_tracker_id(tracker_id: Any, n: int) -> None: is_valid = tracker_id is None or ( - isinstance(tracker_id, np.ndarray) - and tracker_id.shape == (n,) + isinstance(tracker_id, np.ndarray) and tracker_id.shape == (n,) ) if not is_valid: raise ValueError("tracker_id must be None or 1d np.ndarray with (n,) shape") @@ -86,7 +82,15 @@ def __len__(self): def __iter__( self, - ) -> Iterator[Tuple[np.ndarray, Optional[np.ndarray], Optional[float], Optional[int], Optional[int]]]: + ) -> Iterator[ + Tuple[ + np.ndarray, + Optional[np.ndarray], + Optional[float], + Optional[int], + Optional[int], + ] + ]: """ Iterates over the Detections object and yield a tuple of `(xyxy, confidence, class_id, tracker_id)` for each 
detection. """ diff --git a/supervision/notebook/utils.py b/supervision/notebook/utils.py index 789eee12c..79434ca7e 100644 --- a/supervision/notebook/utils.py +++ b/supervision/notebook/utils.py @@ -6,9 +6,7 @@ def plot_image( - image: np.ndarray, - size: Tuple[int, int] = (12, 12), - cmap: Optional[str] = "gray" + image: np.ndarray, size: Tuple[int, int] = (12, 12), cmap: Optional[str] = "gray" ) -> None: """ Plots image using matplotlib. @@ -36,7 +34,7 @@ def plot_image( else: plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - plt.axis('off') + plt.axis("off") plt.show() @@ -45,7 +43,7 @@ def plot_images_grid( grid_size: Tuple[int, int], titles: Optional[List[str]] = None, size: Tuple[int, int] = (12, 12), - cmap: Optional[str] = "gray" + cmap: Optional[str] = "gray", ) -> None: """ Plots images in a grid using matplotlib. From 5b9eef363dd79ca09d739e39bcdc9bcbef6a16b5 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Sun, 9 Apr 2023 22:16:12 +0200 Subject: [PATCH 04/18] small fix :) --- supervision/detection/core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 40a391c65..e4508d271 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -17,7 +17,7 @@ def _validate_xyxy(xyxy: Any, n: int) -> None: def _validate_mask(mask: Any, n: int) -> None: is_valid = mask is None or ( - isinstance(mask, np.ndarray) and len(mask.shape) == 3 and mask[0] == n + isinstance(mask, np.ndarray) and len(mask.shape) == 3 and mask.shape[0] == n ) if not is_valid: raise ValueError("mask must be 3d np.ndarray with (n, W, H) shape") @@ -239,6 +239,10 @@ def from_roboflow(cls, roboflow_result: dict, class_list: List[str]) -> Detectio class_id=np.array(class_id).astype(int), ) + @classmethod + def from_segment_anything_model(cls, segment_anything_model_result: List[dict]) -> Detections: + pass + @classmethod def from_coco_annotations(cls, coco_annotation: dict) -> Detections: 
xyxy, class_id = [], [] From ff2788945daaf62f6cec727850cd46a3ce8925f1 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Sun, 9 Apr 2023 22:57:29 +0200 Subject: [PATCH 05/18] =?UTF-8?q?=F0=9F=98=B7=20Initial=20implementation?= =?UTF-8?q?=20of=20MaskAnnotator?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/detection/annotate.py | 41 +++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/supervision/detection/annotate.py b/supervision/detection/annotate.py index 070a5d660..f733eec1f 100644 --- a/supervision/detection/annotate.py +++ b/supervision/detection/annotate.py @@ -117,3 +117,44 @@ def annotate( lineType=cv2.LINE_AA, ) return scene + + +class MaskAnnotator: + def __init__( + self, + color: Union[Color, ColorPalette] = ColorPalette.default(), + ): + self.color: Union[Color, ColorPalette] = color + + def annotate( + self, + scene: np.ndarray, + detections: Detections, + opacity: float = 0.5 + ) -> np.ndarray: + + for i in range(len(detections.xyxy)): + if detections.mask is None: + continue + + class_id = ( + detections.class_id[i] if detections.class_id is not None else None + ) + idx = class_id if class_id is not None else i + color = ( + self.color.by_idx(idx) + if isinstance(self.color, ColorPalette) + else self.color + ) + + mask = detections.mask[i] + colored_mask = np.zeros_like(scene, dtype=np.uint8) + colored_mask[:] = color.as_rgb() + + scene = np.where( + np.expand_dims(mask, axis=-1), + np.uint8(opacity * colored_mask + (1 - opacity) * scene), + scene + ) + + return scene \ No newline at end of file From 511375f24b4daffab9ce35befeafe13d15f01554 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Sun, 9 Apr 2023 23:05:56 +0200 Subject: [PATCH 06/18] =?UTF-8?q?=F0=9F=98=B7=20Top=20level=20import=20of?= =?UTF-8?q?=20MaskAnnotator?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/__init__.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/supervision/__init__.py b/supervision/__init__.py index 1211432dd..4b58d6b58 100644 --- a/supervision/__init__.py +++ b/supervision/__init__.py @@ -1,7 +1,7 @@ __version__ = "0.4.0" from supervision.annotation.voc import detections_to_voc_xml -from supervision.detection.annotate import BoxAnnotator +from supervision.detection.annotate import BoxAnnotator, MaskAnnotator from supervision.detection.core import Detections from supervision.detection.line_counter import LineZone, LineZoneAnnotator from supervision.detection.polygon_zone import PolygonZone, PolygonZoneAnnotator From e439b29f075639fde1fc5dd73a418533d67ef6d1 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Sun, 9 Apr 2023 23:21:04 +0200 Subject: [PATCH 07/18] Initial from_sam connector for Detections --- supervision/detection/annotate.py | 9 +++------ supervision/detection/core.py | 16 +++++++++++++--- supervision/detection/utils.py | 7 +++++++ 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/supervision/detection/annotate.py b/supervision/detection/annotate.py index f733eec1f..20353ef86 100644 --- a/supervision/detection/annotate.py +++ b/supervision/detection/annotate.py @@ -127,10 +127,7 @@ def __init__( self.color: Union[Color, ColorPalette] = color def annotate( - self, - scene: np.ndarray, - detections: Detections, - opacity: float = 0.5 + self, scene: np.ndarray, detections: Detections, opacity: float = 0.5 ) -> np.ndarray: for i in range(len(detections.xyxy)): @@ -154,7 +151,7 @@ def annotate( scene = np.where( np.expand_dims(mask, axis=-1), np.uint8(opacity * colored_mask + (1 - opacity) * scene), - scene + scene, ) - return scene \ No newline at end of file + return scene diff --git a/supervision/detection/core.py b/supervision/detection/core.py index e4508d271..b2e1b6805 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -5,7 +5,7 @@ import numpy as np -from supervision.detection.utils import non_max_suppression 
+from supervision.detection.utils import non_max_suppression, xywh_to_xyxy from supervision.geometry.core import Position @@ -240,8 +240,18 @@ def from_roboflow(cls, roboflow_result: dict, class_list: List[str]) -> Detectio ) @classmethod - def from_segment_anything_model(cls, segment_anything_model_result: List[dict]) -> Detections: - pass + def from_segment_anything_model( + cls, segment_anything_model_result: List[dict] + ) -> Detections: + sorted_generated_masks = sorted( + segment_anything_model_result, key=lambda x: x["area"], reverse=True + ) + + xywh = np.array([mask["bbox"] for mask in sorted_generated_masks]) + + mask = np.array([mask["segmentation"] for mask in sorted_generated_masks]) + + return Detections(xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask) @classmethod def from_coco_annotations(cls, coco_annotation: dict) -> Detections: diff --git a/supervision/detection/utils.py b/supervision/detection/utils.py index 82b8be3ec..ec50200d7 100644 --- a/supervision/detection/utils.py +++ b/supervision/detection/utils.py @@ -115,3 +115,10 @@ def clip_boxes( result[:, [0, 2]] = result[:, [0, 2]].clip(0, width) result[:, [1, 3]] = result[:, [1, 3]].clip(0, height) return result + + +def xywh_to_xyxy(boxes_xywh: np.ndarray) -> np.ndarray: + xyxy = boxes_xywh.copy() + xyxy[:, 2] = boxes_xywh[:, 0] + boxes_xywh[:, 2] + xyxy[:, 3] = boxes_xywh[:, 1] + boxes_xywh[:, 3] + return xyxy From bed0e41cb77c47b5142508cb4e7ee5c77be826d9 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Sun, 9 Apr 2023 23:24:10 +0200 Subject: [PATCH 08/18] =?UTF-8?q?=F0=9F=91=8A=20version=20bump?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/supervision/__init__.py b/supervision/__init__.py index 4b58d6b58..2e0e3a49b 100644 --- a/supervision/__init__.py +++ b/supervision/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.0" +__version__ = "0.5.0" from 
supervision.annotation.voc import detections_to_voc_xml from supervision.detection.annotate import BoxAnnotator, MaskAnnotator From ef5ee71f0b782f837a221f3f70726f974cd3396b Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Mon, 10 Apr 2023 11:58:09 +0200 Subject: [PATCH 09/18] Update `sv.Detections.from_sam` API --- supervision/detection/core.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index b2e1b6805..c3288cdd7 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -151,7 +151,7 @@ def from_yolov5(cls, yolov5_results) -> Detections: >>> from supervision import Detections >>> model = torch.hub.load('ultralytics/yolov5', 'yolov5s') - >>> results = model(frame) + >>> results = model(IMAGE) >>> detections = Detections.from_yolov5(results) ``` """ @@ -179,8 +179,8 @@ def from_yolov8(cls, yolov8_results) -> Detections: >>> from supervision import Detections >>> model = YOLO('yolov8s.pt') - >>> results = model(frame)[0] - >>> detections = Detections.from_yolov8(results) + >>> yolov8_results = model(IMAGE)[0] + >>> detections = Detections.from_yolov8(yolov8_results) ``` """ return cls( @@ -240,15 +240,32 @@ def from_roboflow(cls, roboflow_result: dict, class_list: List[str]) -> Detectio ) @classmethod - def from_segment_anything_model( - cls, segment_anything_model_result: List[dict] - ) -> Detections: + def from_sam(cls, sam_result: List[dict]) -> Detections: + """ + Creates a Detections instance from Segment Anything Model (SAM) by Meta AI. + + Args: + sam_result (List[dict]): The output Results instance from SAM + + Returns: + Detections: A new Detections object. 
+ + Example: + ```python + >>> from segment_anything import sam_model_registry, SamAutomaticMaskGenerator + >>> import supervision as sv + + >>> sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE) + >>> mask_generator = SamAutomaticMaskGenerator(sam) + >>> sam_result = mask_generator.generate(IMAGE) + >>> detections = sv.Detections.from_sam(sam_result=sam_result) + ``` + """ sorted_generated_masks = sorted( - segment_anything_model_result, key=lambda x: x["area"], reverse=True + sam_result, key=lambda x: x["area"], reverse=True ) xywh = np.array([mask["bbox"] for mask in sorted_generated_masks]) - mask = np.array([mask["segmentation"] for mask in sorted_generated_masks]) return Detections(xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask) From 34dba17641946fb34ebae8a6d7a322131c543ea9 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Mon, 10 Apr 2023 12:00:00 +0200 Subject: [PATCH 10/18] =?UTF-8?q?=F0=9F=96=A4=20Make=20Black=20happy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/detection/core.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index c3288cdd7..e9e5e689a 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -242,25 +242,25 @@ def from_roboflow(cls, roboflow_result: dict, class_list: List[str]) -> Detectio @classmethod def from_sam(cls, sam_result: List[dict]) -> Detections: """ - Creates a Detections instance from Segment Anything Model (SAM) by Meta AI. + Creates a Detections instance from Segment Anything Model (SAM) by Meta AI. - Args: - sam_result (List[dict]): The output Results instance from SAM + Args: + sam_result (List[dict]): The output Results instance from SAM - Returns: - Detections: A new Detections object. + Returns: + Detections: A new Detections object. 
- Example: - ```python - >>> from segment_anything import sam_model_registry, SamAutomaticMaskGenerator - >>> import supervision as sv + Example: + ```python + >>> from segment_anything import sam_model_registry, SamAutomaticMaskGenerator + >>> import supervision as sv - >>> sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE) - >>> mask_generator = SamAutomaticMaskGenerator(sam) - >>> sam_result = mask_generator.generate(IMAGE) - >>> detections = sv.Detections.from_sam(sam_result=sam_result) - ``` - """ + >>> sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE) + >>> mask_generator = SamAutomaticMaskGenerator(sam) + >>> sam_result = mask_generator.generate(IMAGE) + >>> detections = sv.Detections.from_sam(sam_result=sam_result) + ``` + """ sorted_generated_masks = sorted( sam_result, key=lambda x: x["area"], reverse=True ) From aab7e238e41f6872400f1d35ec14fc2dde0b18ce Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Mon, 10 Apr 2023 18:46:13 +0200 Subject: [PATCH 11/18] =?UTF-8?q?=F0=9F=98=B7=20to=20=F0=9F=93=A6=20initia?= =?UTF-8?q?l=20version=20of=20mask=5Fto=5Fxyxy=20util?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/detection/utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/supervision/detection/utils.py b/supervision/detection/utils.py index ec50200d7..e10e74898 100644 --- a/supervision/detection/utils.py +++ b/supervision/detection/utils.py @@ -122,3 +122,18 @@ def xywh_to_xyxy(boxes_xywh: np.ndarray) -> np.ndarray: xyxy[:, 2] = boxes_xywh[:, 0] + boxes_xywh[:, 2] xyxy[:, 3] = boxes_xywh[:, 1] + boxes_xywh[:, 3] return xyxy + + +def mask_to_xyxy(masks: np.ndarray) -> np.ndarray: + n = masks.shape[0] + bboxes = np.zeros((n, 4), dtype=int) + + for i, mask in enumerate(masks): + rows, cols = np.where(mask) + + if len(rows) > 0 and len(cols) > 0: + x_min, x_max = np.min(cols), np.max(cols) + y_min, y_max = np.min(rows), 
np.max(rows) + bboxes[i, :] = [x_min, y_min, x_max, y_max] + + return bboxes From a14bb94b46a48dd8cb3e4863033e1ffd3b9e5ab3 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Mon, 10 Apr 2023 18:58:07 +0200 Subject: [PATCH 12/18] =?UTF-8?q?=F0=9F=98=B7=20to=20=F0=9F=93=A6=20initia?= =?UTF-8?q?l=20version=20of=20mask=5Fto=5Fxyxy=20util?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/supervision/__init__.py b/supervision/__init__.py index 2e0e3a49b..43278acb0 100644 --- a/supervision/__init__.py +++ b/supervision/__init__.py @@ -5,7 +5,7 @@ from supervision.detection.core import Detections from supervision.detection.line_counter import LineZone, LineZoneAnnotator from supervision.detection.polygon_zone import PolygonZone, PolygonZoneAnnotator -from supervision.detection.utils import generate_2d_mask +from supervision.detection.utils import generate_2d_mask, mask_to_xyxy from supervision.draw.color import Color, ColorPalette from supervision.draw.utils import draw_filled_rectangle, draw_polygon, draw_text from supervision.geometry.core import Point, Position, Rect From c37590ef3a6ed14744e2bbce2d493b3b8fb27675 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Mon, 10 Apr 2023 19:17:53 +0200 Subject: [PATCH 13/18] add `area` and `box_area` to `Detections` object --- supervision/detection/core.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index e9e5e689a..12a0e3385 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -333,6 +333,20 @@ def __getitem__(self, index: np.ndarray) -> Detections: @property def area(self) -> np.ndarray: + """ + Calculate the area of each detection in the set of object detections. If the mask field is defined, the property + returns the area of each mask. If only boxes are given, the property returns the area of each box. 
+ + Returns: + np.ndarray: An array of floats containing the area of each detection in the format of `(area_1, area_2, ..., area_n)`, where n is the number of detections. + """ + if self.mask is not None: + return np.ndarray([np.sum(mask) for mask in self.mask]) + else: + return self.box_area + + @property + def box_area(self) -> np.ndarray: """ Calculate the area of each bounding box in the set of object detections. From 5bc436b1c54e541b37abe977226261853f49bd99 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Mon, 10 Apr 2023 19:58:44 +0200 Subject: [PATCH 14/18] add `area` and `box_area` to `Detections` object --- supervision/detection/annotate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/supervision/detection/annotate.py b/supervision/detection/annotate.py index 20353ef86..a13dbbf44 100644 --- a/supervision/detection/annotate.py +++ b/supervision/detection/annotate.py @@ -146,7 +146,7 @@ def annotate( mask = detections.mask[i] colored_mask = np.zeros_like(scene, dtype=np.uint8) - colored_mask[:] = color.as_rgb() + colored_mask[:] = color.as_bgr() scene = np.where( np.expand_dims(mask, axis=-1), From f3c4a966c41cbab91ef787ba9fb68fbf6eebcf74 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Mon, 10 Apr 2023 22:57:08 +0200 Subject: [PATCH 15/18] changelog added --- docs/changelog.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index adf05cb1b..481e017df 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,3 +1,10 @@ +### 0.5.0 April 10, 2023 + +- Added [[#58](https://github.com/roboflow/supervision/pull/58)]: `Detections.mask` to enable segmentation support. +- Added [[#58](https://github.com/roboflow/supervision/pull/58)]: `MaskAnnotator` to allow easy `Detections.mask` annotation. +- Added [[#58](https://github.com/roboflow/supervision/pull/58)]: `Detections.from_sam` to enable native Segment Anything Model (SAM) support. 
+- Changed [[#58](https://github.com/roboflow/supervision/pull/58)]: `Detections.area` behaviour to work not only with boxes but also with masks. + ### 0.4.0 April 5, 2023 - Added [[#46](https://github.com/roboflow/supervision/discussions/48)]: `Detections.empty` to allow easy creation of empty `Detections` objects. From d02bac75c4b3a1ce464cfc9a9759b4a65a464249 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Mon, 10 Apr 2023 23:07:53 +0200 Subject: [PATCH 16/18] Document changes in docs --- docs/detection/annotate.md | 6 +++++- docs/detection/utils.md | 6 +++++- supervision/detection/annotate.py | 16 ++++++++++++++++ supervision/detection/core.py | 2 +- supervision/detection/utils.py | 9 +++++++++ 5 files changed, 36 insertions(+), 3 deletions(-) diff --git a/docs/detection/annotate.md b/docs/detection/annotate.md index fcebdec70..3062ee8e2 100644 --- a/docs/detection/annotate.md +++ b/docs/detection/annotate.md @@ -1,3 +1,7 @@ ## BoxAnnotator -:::supervision.detection.annotate.BoxAnnotator \ No newline at end of file +:::supervision.detection.annotate.BoxAnnotator + +## MaskAnnotator + +:::supervision.detection.annotate.MaskAnnotator \ No newline at end of file diff --git a/docs/detection/utils.md b/docs/detection/utils.md index c19ba03c3..ace7e1f03 100644 --- a/docs/detection/utils.md +++ b/docs/detection/utils.md @@ -8,4 +8,8 @@ ## non_max_suppression -:::supervision.detection.utils.non_max_suppression \ No newline at end of file +:::supervision.detection.utils.non_max_suppression + +## mask_to_xyxy + +:::supervision.detection.utils.mask_to_xyxy \ No newline at end of file diff --git a/supervision/detection/annotate.py b/supervision/detection/annotate.py index a13dbbf44..78dd899c3 100644 --- a/supervision/detection/annotate.py +++ b/supervision/detection/annotate.py @@ -124,12 +124,28 @@ def __init__( self, color: Union[Color, ColorPalette] = ColorPalette.default(), ): + """ + A class for overlaying masks on an image using detections provided. 
+ + Attributes: + color (Union[Color, ColorPalette]): The color to fill the mask, can be a single color or a color palette + """ self.color: Union[Color, ColorPalette] = color def annotate( self, scene: np.ndarray, detections: Detections, opacity: float = 0.5 ) -> np.ndarray: + """ + Overlays the masks on the given image based on the provided detections, with a specified opacity. + Parameters: + scene (np.ndarray): The image on which the masks will be overlaid + detections (Detections): The detections for which the masks will be overlaid + opacity (float): The opacity of the masks, between 0 and 1, default is 0.5 + + Returns: + np.ndarray: The image with the masks overlaid + """ for i in range(len(detections.xyxy)): if detections.mask is None: continue diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 12a0e3385..6a00a9232 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -92,7 +92,7 @@ def __iter__( ] ]: """ - Iterates over the Detections object and yield a tuple of `(xyxy, confidence, class_id, tracker_id)` for each detection. + Iterates over the Detections object and yield a tuple of `(xyxy, mask, confidence, class_id, tracker_id)` for each detection. """ for i in range(len(self.xyxy)): yield ( diff --git a/supervision/detection/utils.py b/supervision/detection/utils.py index e10e74898..4db14be74 100644 --- a/supervision/detection/utils.py +++ b/supervision/detection/utils.py @@ -125,6 +125,15 @@ def xywh_to_xyxy(boxes_xywh: np.ndarray) -> np.ndarray: def mask_to_xyxy(masks: np.ndarray) -> np.ndarray: + """ + Converts a 3D `np.array` of 2D bool masks into a 2D `np.array` of bounding boxes. 
+ + Parameters: + masks (np.ndarray): A 3D `np.array` of shape `(N, W, H)` containing 2D bool masks + + Returns: + np.ndarray: A 2D `np.array` of shape `(N, 4)` containing the bounding boxes `(x_min, y_min, x_max, y_max)` for each mask + """ n = masks.shape[0] bboxes = np.zeros((n, 4), dtype=int) From c1ccff474f0df7efdfc823300060115ead9f617c Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Mon, 10 Apr 2023 23:15:35 +0200 Subject: [PATCH 17/18] Documentation for PolygonZone and PolygonZoneAnnotator --- docs/detection/tools/polygon_zone.md | 7 +++ mkdocs.yml | 2 + supervision/__init__.py | 2 +- supervision/detection/annotate.py | 12 ++--- supervision/detection/tools/__init__.py | 0 .../detection/{ => tools}/polygon_zone.py | 44 +++++++++++++++++++ 6 files changed, 60 insertions(+), 7 deletions(-) create mode 100644 docs/detection/tools/polygon_zone.md create mode 100644 supervision/detection/tools/__init__.py rename supervision/detection/{ => tools}/polygon_zone.py (57%) diff --git a/docs/detection/tools/polygon_zone.md b/docs/detection/tools/polygon_zone.md new file mode 100644 index 000000000..54a66da92 --- /dev/null +++ b/docs/detection/tools/polygon_zone.md @@ -0,0 +1,7 @@ +## PolygonZone + +:::supervision.detection.tools.polygon_zone.PolygonZone + +## PolygonZoneAnnotator + +:::supervision.detection.tools.polygon_zone.PolygonZoneAnnotator \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 89d444adc..14d963f79 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -29,6 +29,8 @@ nav: - Core: detection/core.md - Annotate: detection/annotate.md - Utils: detection/utils.md + - Tools: + - Polygon Zone: detection/tools/polygon_zone.md - Draw: - Utils: draw/utils.md - Annotations: diff --git a/supervision/__init__.py b/supervision/__init__.py index 43278acb0..eefd48c40 100644 --- a/supervision/__init__.py +++ b/supervision/__init__.py @@ -4,7 +4,7 @@ from supervision.detection.annotate import BoxAnnotator, MaskAnnotator from supervision.detection.core import 
Detections from supervision.detection.line_counter import LineZone, LineZoneAnnotator -from supervision.detection.polygon_zone import PolygonZone, PolygonZoneAnnotator +from supervision.detection.tools.polygon_zone import PolygonZone, PolygonZoneAnnotator from supervision.detection.utils import generate_2d_mask, mask_to_xyxy from supervision.draw.color import Color, ColorPalette from supervision.draw.utils import draw_filled_rectangle, draw_polygon, draw_text diff --git a/supervision/detection/annotate.py b/supervision/detection/annotate.py index 78dd899c3..98c5244e2 100644 --- a/supervision/detection/annotate.py +++ b/supervision/detection/annotate.py @@ -120,16 +120,16 @@ def annotate( class MaskAnnotator: + """ + A class for overlaying masks on an image using detections provided. + + Attributes: + color (Union[Color, ColorPalette]): The color to fill the mask, can be a single color or a color palette + """ def __init__( self, color: Union[Color, ColorPalette] = ColorPalette.default(), ): - """ - A class for overlaying masks on an image using detections provided. - - Attributes: - color (Union[Color, ColorPalette]): The color to fill the mask, can be a single color or a color palette - """ self.color: Union[Color, ColorPalette] = color def annotate( diff --git a/supervision/detection/tools/__init__.py b/supervision/detection/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/supervision/detection/polygon_zone.py b/supervision/detection/tools/polygon_zone.py similarity index 57% rename from supervision/detection/polygon_zone.py rename to supervision/detection/tools/polygon_zone.py index da60a62c7..fb53ea30b 100644 --- a/supervision/detection/polygon_zone.py +++ b/supervision/detection/tools/polygon_zone.py @@ -13,6 +13,16 @@ class PolygonZone: + """ + A class for defining a polygon-shaped zone within a frame for detecting objects. 
+ + Attributes: + polygon (np.ndarray): A numpy array defining the polygon vertices + frame_resolution_wh (Tuple[int, int]): The frame resolution (width, height) + triggering_position (Position): The position within the bounding box that triggers the zone (default: Position.BOTTOM_CENTER) + current_count (int): The current count of detected objects within the zone + mask (np.ndarray): The 2D bool mask for the polygon zone + """ def __init__( self, polygon: np.ndarray, @@ -30,6 +40,16 @@ def __init__( ) def trigger(self, detections: Detections) -> np.ndarray: + """ + Determines if the detections are within the polygon zone. + + Parameters: + detections (Detections): The detections to be checked against the polygon zone + + Returns: + np.ndarray: A boolean numpy array indicating if each detection is within the polygon zone + """ + clipped_xyxy = clip_boxes( boxes_xyxy=detections.xyxy, frame_resolution_wh=self.frame_resolution_wh ) @@ -43,6 +63,20 @@ def trigger(self, detections: Detections) -> np.ndarray: class PolygonZoneAnnotator: + """ + A class for annotating a polygon-shaped zone within a frame with a count of detected objects. 
+ + Attributes: + zone (PolygonZone): The polygon zone to be annotated + color (Color): The color to draw the polygon lines + thickness (int): The thickness of the polygon lines, default is 2 + text_color (Color): The color of the text on the polygon, default is black + text_scale (float): The scale of the text on the polygon, default is 0.5 + text_thickness (int): The thickness of the text on the polygon, default is 1 + text_padding (int): The padding around the text on the polygon, default is 10 + font (int): The font type for the text on the polygon, default is cv2.FONT_HERSHEY_SIMPLEX + center (Tuple[int, int]): The center of the polygon for text placement + """ def __init__( self, zone: PolygonZone, @@ -64,6 +98,16 @@ def __init__( self.center = get_polygon_center(polygon=zone.polygon) def annotate(self, scene: np.ndarray, label: Optional[str] = None) -> np.ndarray: + """ + Annotates the polygon zone within a frame with a count of detected objects. + + Parameters: + scene (np.ndarray): The image on which the polygon zone will be annotated + label (Optional[str]): An optional label for the count of detected objects within the polygon zone (default: None) + + Returns: + np.ndarray: The image with the polygon zone and count of detected objects + """ annotated_frame = draw_polygon( scene=scene, polygon=self.zone.polygon, From ddc8a8cccef939f52c45474c91b3facf90666dc0 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Mon, 10 Apr 2023 23:16:04 +0200 Subject: [PATCH 18/18] =?UTF-8?q?=F0=9F=96=A4=20Make=20black=20happy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/detection/annotate.py | 9 +++++---- supervision/detection/tools/polygon_zone.py | 2 ++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/supervision/detection/annotate.py b/supervision/detection/annotate.py index 98c5244e2..e78ee6e29 100644 --- a/supervision/detection/annotate.py +++ b/supervision/detection/annotate.py @@ -121,11 +121,12 @@ 
def annotate( class MaskAnnotator: """ - A class for overlaying masks on an image using detections provided. + A class for overlaying masks on an image using detections provided. + + Attributes: + color (Union[Color, ColorPalette]): The color to fill the mask, can be a single color or a color palette + """ - Attributes: - color (Union[Color, ColorPalette]): The color to fill the mask, can be a single color or a color palette - """ def __init__( self, color: Union[Color, ColorPalette] = ColorPalette.default(), diff --git a/supervision/detection/tools/polygon_zone.py b/supervision/detection/tools/polygon_zone.py index fb53ea30b..48c3eaa01 100644 --- a/supervision/detection/tools/polygon_zone.py +++ b/supervision/detection/tools/polygon_zone.py @@ -23,6 +23,7 @@ class PolygonZone: current_count (int): The current count of detected objects within the zone mask (np.ndarray): The 2D bool mask for the polygon zone """ + def __init__( self, polygon: np.ndarray, @@ -77,6 +78,7 @@ class PolygonZoneAnnotator: font (int): The font type for the text on the polygon, default is cv2.FONT_HERSHEY_SIMPLEX center (Tuple[int, int]): The center of the polygon for text placement """ + def __init__( self, zone: PolygonZone,