-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #643 from roboflow/owlv2
Owlv2
- Loading branch information
Showing
12 changed files
with
512 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
|
||
<a href="https://arxiv.org/abs/2306.09683" target="_blank">OWLv2</a> is an open set object detectio model trained by Google. OWLv2 was primarily trained to detect objects from text. The implementation in `Inference` currently only supports detecting objects from visual examples of that object. | ||
|
||
### Installation | ||
|
||
To install inference with the extra dependencies necessary to run OWLv2, run | ||
|
||
```pip install inference[transformers]``` | ||
|
||
or | ||
|
||
```pip install inference-gpu[transformers]``` | ||
|
||
### How to Use OWLv2 | ||
|
||
Create a new Python file called `app.py` and add the following code: | ||
|
||
```python | ||
import inference | ||
from inference.models.owlv2.owlv2 import OwlV2 | ||
from inference.core.entities.requests.owlv2 import OwlV2InferenceRequest | ||
from PIL import Image | ||
import io | ||
import base64 | ||
|
||
model = OwlV2() | ||
|
||
|
||
im_url = "https://media.roboflow.com/inference/seawithdock.jpeg" | ||
image = { | ||
"type": "url", | ||
"value": im_url | ||
} | ||
request = OwlV2InferenceRequest( | ||
image=image, | ||
training_data=[ | ||
{ | ||
"image": image, | ||
"boxes": [{"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post"}], | ||
} | ||
], | ||
visualize_predictions=True, | ||
confidence=0.9999, | ||
) | ||
|
||
response = OwlV2().infer_from_request(request) | ||
|
||
def load_image_from_base64(base64_str): | ||
image = Image.open(io.BytesIO(base64_str)) | ||
return image | ||
|
||
visualization = load_image_from_base64(response.visualization) | ||
visualization.save("owlv2_visualization.jpg") | ||
``` | ||
|
||
In this code, we run OWLv2 on an image, using example objects from that image. Above, replace: | ||
|
||
1. `training_data` with the locations of the objects you want to detect. | ||
2. `im_url` with the image you would like to perform inference on. | ||
|
||
Then, run the Python script you have created: | ||
|
||
``` | ||
python app.py | ||
``` | ||
|
||
The result from your model will be save to disk at `owlv2_visualization.jpg` | ||
|
||
Note the blue bounding boxes surrounding each pole of the dock. | ||
|
||
|
||
![OWLv2 results](https://media.roboflow.com/inference/owlv2_visualization.jpg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from typing import List, Optional, Union | ||
|
||
from pydantic import BaseModel, Field, validator | ||
|
||
from inference.core.entities.common import ApiKey | ||
from inference.core.entities.requests.inference import ( | ||
BaseRequest, | ||
InferenceRequestImage, | ||
) | ||
from inference.core.env import OWLV2_VERSION_ID | ||
|
||
|
||
class TrainBox(BaseModel): | ||
x: int = Field(description="Center x coordinate in pixels of train box") | ||
y: int = Field(description="Center y coordinate in pixels of train box") | ||
w: int = Field(description="Width in pixels of train box") | ||
h: int = Field(description="Height in pixels of train box") | ||
cls: str = Field(description="Class name of object this box encloses") | ||
|
||
|
||
class TrainingImage(BaseModel): | ||
boxes: List[TrainBox] = Field( | ||
description="List of boxes and corresponding classes of examples for the model to learn from" | ||
) | ||
image: InferenceRequestImage = Field( | ||
description="Image data that `boxes` describes" | ||
) | ||
|
||
|
||
class OwlV2InferenceRequest(BaseRequest): | ||
"""Request for gaze detection inference. | ||
Attributes: | ||
api_key (Optional[str]): Roboflow API Key. | ||
owlv2_version_id (Optional[str]): The version ID of Gaze to be used for this request. | ||
image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) for inference. | ||
training_data (List[TrainingImage]): Training data to ground the model on | ||
confidence (float): Confidence threshold to filter predictions by | ||
""" | ||
|
||
owlv2_version_id: Optional[str] = Field( | ||
default=OWLV2_VERSION_ID, | ||
examples=["owlv2-base-patch16-ensemble"], | ||
description="The version ID of owlv2 to be used for this request.", | ||
) | ||
model_id: Optional[str] = Field( | ||
default=None, description="Model id to be used in the request." | ||
) | ||
|
||
image: Union[List[InferenceRequestImage], InferenceRequestImage] = Field( | ||
description="Images to run the model on" | ||
) | ||
training_data: List[TrainingImage] = Field( | ||
description="Training images for the owlvit model to learn form" | ||
) | ||
confidence: Optional[float] = Field( | ||
default=0.99, | ||
examples=[0.99], | ||
description="Default confidence threshold for owlvit predictions. " | ||
"Needs to be much higher than you're used to, probably 0.99 - 0.9999", | ||
) | ||
visualize_predictions: Optional[bool] = Field( | ||
default=False, | ||
examples=[False], | ||
description="If true, return visualized predictions as a base64 string", | ||
) | ||
visualization_labels: Optional[bool] = Field( | ||
default=False, | ||
examples=[False], | ||
description="If true, labels will be rendered on prediction visualizations", | ||
) | ||
visualization_stroke_width: Optional[int] = Field( | ||
default=1, | ||
examples=[1], | ||
description="The stroke width used when visualizing predictions", | ||
) | ||
visualize_predictions: Optional[bool] = Field( | ||
default=False, | ||
examples=[False], | ||
description="If true, the predictions will be drawn on the original image and returned as a base64 string", | ||
) | ||
|
||
# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. | ||
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. | ||
@validator("model_id", always=True, allow_reuse=True) | ||
def validate_model_id(cls, value, values): | ||
if value is not None: | ||
return value | ||
if values.get("owl2_version_id") is None: | ||
return None | ||
return f"google/{values['owl2_version_id']}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.