Commit 982a260

Merge pull request #643 from roboflow/owlv2

Owlv2

PawelPeczek-Roboflow authored Sep 17, 2024
2 parents 0f6d249 + a2b8c7d commit 982a260
Showing 12 changed files with 512 additions and 0 deletions.
1 change: 1 addition & 0 deletions docker/dockerfiles/Dockerfile.onnx.cpu
@@ -73,5 +73,6 @@ ENV WORKFLOWS_STEP_EXECUTION_MODE=local
ENV WORKFLOWS_MAX_CONCURRENT_STEPS=1
ENV API_LOGGING_ENABLED=True
ENV CORE_MODEL_SAM2_ENABLED=True
ENV CORE_MODEL_OWLV2_ENABLED=True

ENTRYPOINT uvicorn cpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT
1 change: 1 addition & 0 deletions docker/dockerfiles/Dockerfile.onnx.gpu
@@ -78,5 +78,6 @@ ENV WORKFLOWS_MAX_CONCURRENT_STEPS=1
ENV API_LOGGING_ENABLED=True
ENV LMM_ENABLED=True
ENV CORE_MODEL_SAM2_ENABLED=True
ENV CORE_MODEL_OWLV2_ENABLED=True

ENTRYPOINT uvicorn gpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT
72 changes: 72 additions & 0 deletions docs/foundation/owlv2.md
@@ -0,0 +1,72 @@

<a href="https://arxiv.org/abs/2306.09683" target="_blank">OWLv2</a> is an open-set object detection model trained by Google. OWLv2 was primarily trained to detect objects from text prompts. The implementation in `Inference` currently supports detecting objects only from visual examples of the target object.

### Installation

To install inference with the extra dependencies necessary to run OWLv2, run

```pip install inference[transformers]```

or

```pip install inference-gpu[transformers]```

### How to Use OWLv2

Create a new Python file called `app.py` and add the following code:

```python
from inference.models.owlv2.owlv2 import OwlV2
from inference.core.entities.requests.owlv2 import OwlV2InferenceRequest
from PIL import Image
import io
import base64

# Instantiate the model once and reuse it for inference.
model = OwlV2()

im_url = "https://media.roboflow.com/inference/seawithdock.jpeg"
image = {
    "type": "url",
    "value": im_url,
}
request = OwlV2InferenceRequest(
    image=image,
    training_data=[
        {
            "image": image,
            "boxes": [{"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post"}],
        }
    ],
    visualize_predictions=True,
    confidence=0.9999,
)

response = model.infer_from_request(request)


def load_image_from_base64(base64_str):
    # The visualization is returned as a base64-encoded image.
    image = Image.open(io.BytesIO(base64.b64decode(base64_str)))
    return image


visualization = load_image_from_base64(response.visualization)
visualization.save("owlv2_visualization.jpg")
```

In this code, we run OWLv2 on an image, using example objects drawn from that same image. In the code above, replace:

1. `training_data` with the locations of the example objects you want the model to detect (a sketch follows below).
2. `im_url` with the URL of the image on which you would like to run inference.
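
For example, here is a rough sketch of a `training_data` payload that provides visual prompts for two classes from one reference image; the second box's coordinates and the `boat` class are illustrative placeholders, not values from the snippet above:

```python
# Hypothetical example: visual prompts for two classes taken from one reference image.
# Per the request schema, x/y are the box center in pixels; w/h are width and height.
reference_image = {"type": "url", "value": "https://media.roboflow.com/inference/seawithdock.jpeg"}

training_data = [
    {
        "image": reference_image,
        "boxes": [
            {"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post"},
            {"x": 540, "y": 310, "w": 60, "h": 120, "cls": "boat"},  # placeholder coordinates
        ],
    }
]
```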

Then, run the Python script you have created:

```
python app.py
```

The results from your model will be saved to disk at `owlv2_visualization.jpg`.

Note the blue bounding boxes surrounding each pole of the dock.


![OWLv2 results](https://media.roboflow.com/inference/owlv2_visualization.jpg)
91 changes: 91 additions & 0 deletions inference/core/entities/requests/owlv2.py
@@ -0,0 +1,91 @@
from typing import List, Optional, Union

from pydantic import BaseModel, Field, validator

from inference.core.entities.common import ApiKey
from inference.core.entities.requests.inference import (
BaseRequest,
InferenceRequestImage,
)
from inference.core.env import OWLV2_VERSION_ID


class TrainBox(BaseModel):
    x: int = Field(description="Center x coordinate in pixels of train box")
    y: int = Field(description="Center y coordinate in pixels of train box")
    w: int = Field(description="Width in pixels of train box")
    h: int = Field(description="Height in pixels of train box")
    cls: str = Field(description="Class name of object this box encloses")


class TrainingImage(BaseModel):
    boxes: List[TrainBox] = Field(
        description="List of boxes and corresponding classes of examples for the model to learn from"
    )
    image: InferenceRequestImage = Field(
        description="Image data that `boxes` describes"
    )


class OwlV2InferenceRequest(BaseRequest):
    """Request for OWLv2 inference.

    Attributes:
        api_key (Optional[str]): Roboflow API Key.
        owlv2_version_id (Optional[str]): The version ID of OWLv2 to be used for this request.
        image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) for inference.
        training_data (List[TrainingImage]): Training data to ground the model on.
        confidence (float): Confidence threshold to filter predictions by.
    """

    owlv2_version_id: Optional[str] = Field(
        default=OWLV2_VERSION_ID,
        examples=["owlv2-base-patch16-ensemble"],
        description="The version ID of OWLv2 to be used for this request.",
    )
    model_id: Optional[str] = Field(
        default=None, description="Model id to be used in the request."
    )

    image: Union[List[InferenceRequestImage], InferenceRequestImage] = Field(
        description="Images to run the model on"
    )
    training_data: List[TrainingImage] = Field(
        description="Training images for the OWLv2 model to learn from"
    )
    confidence: Optional[float] = Field(
        default=0.99,
        examples=[0.99],
        description="Default confidence threshold for OWLv2 predictions. "
        "Needs to be much higher than for most detectors, typically 0.99 - 0.9999",
    )
    visualize_predictions: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, the predictions will be drawn on the original image and returned as a base64 string",
    )
    visualization_labels: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, labels will be rendered on prediction visualizations",
    )
    visualization_stroke_width: Optional[int] = Field(
        default=1,
        examples=[1],
        description="The stroke width used when visualizing predictions",
    )

    # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
    @validator("model_id", always=True, allow_reuse=True)
    def validate_model_id(cls, value, values):
        if value is not None:
            return value
        if values.get("owlv2_version_id") is None:
            return None
        return f"google/{values['owlv2_version_id']}"
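
As a quick illustration of the validator above (a sketch, not code from this PR): when no `model_id` is supplied, it is derived from `owlv2_version_id`, so with the default version it resolves to `google/owlv2-base-patch16-ensemble`.

```python
# Sketch: model_id falls back to "google/<owlv2_version_id>" when not provided.
request = OwlV2InferenceRequest(
    image={"type": "url", "value": "https://media.roboflow.com/inference/seawithdock.jpeg"},
    training_data=[],
)
print(request.model_id)  # "google/owlv2-base-patch16-ensemble" with the default version ID
```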
3 changes: 3 additions & 0 deletions inference/core/env.py
@@ -73,6 +73,7 @@

# Gaze version ID, default is "L2CS"
GAZE_VERSION_ID = os.getenv("GAZE_VERSION_ID", "L2CS")
OWLV2_VERSION_ID = os.getenv("OWLV2_VERSION_ID", "owlv2-base-patch16-ensemble")

# Gaze model ID
GAZE_MODEL_ID = f"gaze/{CLIP_VERSION_ID}"
@@ -108,6 +109,8 @@
CORE_MODEL_SAM_ENABLED = str2bool(os.getenv("CORE_MODEL_SAM_ENABLED", True))
CORE_MODEL_SAM2_ENABLED = str2bool(os.getenv("CORE_MODEL_SAM2_ENABLED", True))

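# Flag to enable OWLv2 core model, default is False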
CORE_MODEL_OWLV2_ENABLED = str2bool(os.getenv("CORE_MODEL_OWLV2_ENABLED", False))

# Flag to enable GAZE core model, default is True
CORE_MODEL_GAZE_ENABLED = str2bool(os.getenv("CORE_MODEL_GAZE_ENABLED", True))

38 changes: 38 additions & 0 deletions inference/core/interfaces/http/http_api.py
@@ -34,6 +34,7 @@
LMMInferenceRequest,
ObjectDetectionInferenceRequest,
)
from inference.core.entities.requests.owlv2 import OwlV2InferenceRequest
from inference.core.entities.requests.sam import (
SamEmbeddingRequest,
SamSegmentationRequest,
@@ -96,6 +97,7 @@
CORE_MODEL_DOCTR_ENABLED,
CORE_MODEL_GAZE_ENABLED,
CORE_MODEL_GROUNDINGDINO_ENABLED,
CORE_MODEL_OWLV2_ENABLED,
CORE_MODEL_SAM2_ENABLED,
CORE_MODEL_SAM_ENABLED,
CORE_MODEL_YOLO_WORLD_ENABLED,
@@ -691,6 +693,7 @@ def load_core_model(
"""

load_yolo_world_model = partial(load_core_model, core_model="yolo_world")
load_owlv2_model = partial(load_core_model, core_model="owlv2")
"""Loads the YOLO World model into the model manager.
Args:
@@ -1537,6 +1540,41 @@ async def sam2_segment_image(
)
return model_response

if CORE_MODEL_OWLV2_ENABLED:

    @app.post(
        "/owlv2/infer",
        response_model=ObjectDetectionInferenceResponse,
        summary="Owlv2 image prompting",
        description="Run the google owlv2 model to few-shot object detect",
    )
    @with_route_exceptions
    async def owlv2_infer(
        inference_request: OwlV2InferenceRequest,
        request: Request,
        api_key: Optional[str] = Query(
            None,
            description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
        ),
    ):
        """
        Runs few-shot object detection with Google's OWLv2 model.

        Args:
            inference_request (OwlV2InferenceRequest): The request containing the image to run inference on and the visual training examples.
            api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
            request (Request, default Body()): The HTTP request.

        Returns:
            ObjectDetectionInferenceResponse: The detections predicted by the model.
        """
        logger.debug("Reached /owlv2/infer")
        owl2_model_id = load_owlv2_model(inference_request, api_key=api_key)
        model_response = await self.model_manager.infer_from_request(
            owl2_model_id, inference_request
        )
        return model_response

if CORE_MODEL_GAZE_ENABLED:

@app.post(
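For reference, here is a minimal sketch of calling the new endpoint over HTTP with the `requests` library; the localhost port and the API key placeholder are assumptions, not values from this PR:

```python
import requests

# Assumed local inference server address; adjust host/port for your deployment.
url = "http://localhost:9001/owlv2/infer"

image = {"type": "url", "value": "https://media.roboflow.com/inference/seawithdock.jpeg"}

payload = {
    "api_key": "YOUR_ROBOFLOW_API_KEY",  # placeholder
    "image": image,
    "training_data": [
        {
            "image": image,
            "boxes": [{"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post"}],
        }
    ],
    "confidence": 0.9999,
}

response = requests.post(url, json=payload)
print(response.json())  # JSON matching ObjectDetectionInferenceResponse
```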
1 change: 1 addition & 0 deletions inference/core/registries/roboflow.py
@@ -36,6 +36,7 @@
"cogvlm": ("llm", "cogvlm"),
"paligemma": ("llm", "paligemma"),
"yolo_world": ("object-detection", "yolo-world"),
"owlv2": ("object-detection", "owlv2"),
}

STUB_VERSION_ID = "0"

0 comments on commit 982a260