Skip to content

Commit

Permalink
refactor(api): move request validation to pydantic models
Browse files Browse the repository at this point in the history
Use Python builtin ContextVar and context manager together with FastAPI
dependencies to provide request context such as path variables to
pydantic model validation.
  • Loading branch information
matthiasschaub committed Nov 24, 2024
1 parent 300f766 commit de22d59
Show file tree
Hide file tree
Showing 11 changed files with 189 additions and 140 deletions.
34 changes: 14 additions & 20 deletions ohsome_quality_api/api/api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
import logging
import os
from typing import Any, Union
from typing import Annotated, Any, Union

from fastapi import FastAPI, HTTPException, Request, status
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -24,6 +24,7 @@
__version__,
oqt,
)
from ohsome_quality_api.api.request_context import set_request_context
from ohsome_quality_api.api.request_models import (
AttributeCompletenessRequest,
IndicatorDataRequest,
Expand Down Expand Up @@ -69,16 +70,11 @@
OhsomeApiError,
SizeRestrictionError,
TopicDataSchemaError,
ValidationError,
)
from ohsome_quality_api.utils.helper import (
get_class_from_key,
json_serialize,
)
from ohsome_quality_api.utils.validators import (
validate_attribute_topic_combination,
validate_indicator_topic_combination,
)

MEDIA_TYPE_GEOJSON = "application/geo+json"
MEDIA_TYPE_JSON = "application/json"
Expand Down Expand Up @@ -125,6 +121,10 @@
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


class CommonDepedencies:
RequestContext = Annotated[None, Depends(set_request_context)]


@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html(request: Request):
root_path = request.scope.get("root_path")
Expand Down Expand Up @@ -180,13 +180,8 @@ async def validation_exception_handler(
@app.exception_handler(TopicDataSchemaError)
@app.exception_handler(OhsomeApiError)
@app.exception_handler(SizeRestrictionError)
@app.exception_handler(ValidationError)
async def custom_exception_handler(
_: Request,
exception: TopicDataSchemaError
| OhsomeApiError
| SizeRestrictionError
| ValidationError,
_: Request, exception: TopicDataSchemaError | OhsomeApiError | SizeRestrictionError
):
"""Exception handler for custom exceptions."""
return JSONResponse(
Expand Down Expand Up @@ -229,7 +224,10 @@ def empty_api_response() -> dict:


@app.post("/indicators/mapping-saturation/data", include_in_schema=False)
async def post_indicator_ms(parameters: IndicatorDataRequest) -> CustomJSONResponse:
async def post_indicator_ms(
parameters: IndicatorDataRequest,
_: CommonDepedencies.RequestContext,
) -> CustomJSONResponse:
"""Legacy support for computing the Mapping Saturation indicator for given data."""
indicators = await oqt.create_indicator(
key="mapping-saturation",
Expand Down Expand Up @@ -269,13 +267,9 @@ async def post_indicator_ms(parameters: IndicatorDataRequest) -> CustomJSONRespo
async def post_attribute_completeness(
request: Request,
parameters: AttributeCompletenessRequest,
_: CommonDepedencies.RequestContext,
) -> Any:
"""Request the Attribute Completeness indicator for your area of interest."""
for attribute in parameters.attribute_keys:
validate_attribute_topic_combination(
attribute.value, parameters.topic_key.value
)

return await _post_indicator(request, "attribute-completeness", parameters)


Expand All @@ -300,6 +294,7 @@ async def post_indicator(
request: Request,
key: IndicatorEnumRequest,
parameters: IndicatorRequest,
_: CommonDepedencies.RequestContext,
) -> Any:
"""Request an indicator for your area of interest."""
return await _post_indicator(request, key.value, parameters)
Expand All @@ -308,7 +303,6 @@ async def post_indicator(
async def _post_indicator(
request: Request, key: str, parameters: IndicatorRequest
) -> Any:
validate_indicator_topic_combination(key, parameters.topic_key.value)
attribute_keys = getattr(parameters, "attribute_keys", None)
if attribute_keys:
attribute_keys = [attribute.value for attribute in attribute_keys]
Expand Down
32 changes: 32 additions & 0 deletions ohsome_quality_api/api/request_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from contextlib import asynccontextmanager
from contextvars import ContextVar
from dataclasses import dataclass

from fastapi import Request


@dataclass
class RequestContext:
path_parameters: dict


request_context: ContextVar[RequestContext] = ContextVar("request_context")


@asynccontextmanager
async def request_context_manager(path_parameters: dict):
token = request_context.set(RequestContext(path_parameters=path_parameters))
try:
yield
finally:
request_context.reset(token)


async def set_request_context(request: Request):
"""Set request context for the duration of a request.
After leaving the context manager (after the request is processed)
the request context is resented again.
"""
async with request_context_manager(request.path_params):
yield
70 changes: 65 additions & 5 deletions ohsome_quality_api/api/request_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from typing import Dict, List
from typing import Dict, List, Self

import geojson
from geojson_pydantic import Feature, FeatureCollection, MultiPolygon, Polygon
from pydantic import BaseModel, ConfigDict, Field, field_validator

from ohsome_quality_api.attributes.definitions import AttributeEnum
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
)

from ohsome_quality_api.api.request_context import RequestContext, request_context
from ohsome_quality_api.attributes.definitions import AttributeEnum, get_attributes
from ohsome_quality_api.indicators.definitions import get_valid_indicators
from ohsome_quality_api.topics.definitions import TopicEnum
from ohsome_quality_api.topics.models import TopicData
from ohsome_quality_api.utils.helper import snake_to_lower_camel
Expand All @@ -19,6 +27,12 @@ class BaseConfig(BaseModel):
)


class BaseRequestContext(BaseModel):
@property
def request_context(self) -> RequestContext | None:
return request_context.get()


FeatureCollection_ = FeatureCollection[Feature[Polygon | MultiPolygon, Dict]]


Expand Down Expand Up @@ -56,14 +70,30 @@ def transform(cls, value) -> geojson.FeatureCollection:
return geojson.loads(value.model_dump_json())


class IndicatorRequest(BaseBpolys):
class IndicatorRequest(BaseBpolys, BaseRequestContext):
topic_key: TopicEnum = Field(
...,
title="Topic Key",
alias="topic",
)
include_figure: bool = True

@model_validator(mode="after")
def validate_indicator_topic_combination(self) -> Self:
if self.request_context is not None:
indicator = self.request_context.path_parameters["key"]
else:
raise TypeError("Request context for /indicators should never be None.")

valid_indicators = get_valid_indicators(self.topic_key.value)
if indicator not in valid_indicators:
raise ValueError(
"Invalid combination of indicator and topic: {} and {}".format(
indicator, self.topic_key.value
)
)
return self


class AttributeCompletenessRequest(IndicatorRequest):
attribute_keys: List[AttributeEnum] = Field(
Expand All @@ -72,6 +102,36 @@ class AttributeCompletenessRequest(IndicatorRequest):
alias="attributes",
)

@model_validator(mode="after")
def validate_indicator_topic_combination(self) -> Self:
valid_indicators = get_valid_indicators(self.topic_key.value)
if "attribute-completeness" not in valid_indicators:
raise ValueError(
"Invalid combination of indicator and topic: {} and {}".format(
"attribute-completeness",
self.topic_key.value,
)
)
return self

@model_validator(mode="after")
def validate_attributes(self) -> Self:
valid_attributes = tuple(get_attributes()[self.topic_key.value].keys())
for attribute in self.attribute_keys:
if attribute.value not in valid_attributes:
raise ValueError(
(
"Invalid combination of attribute {} and topic {}. "
"Topic {} supports these attributes: {}"
).format(
attribute.value,
self.topic_key.value,
self.topic_key.value,
", ".join(valid_attributes),
)
)
return self


class IndicatorDataRequest(BaseBpolys):
"""Model for the `/indicators/mapping-saturation/data` endpoint.
Expand Down
2 changes: 1 addition & 1 deletion ohsome_quality_api/oqt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def create_indicator(
if "id" not in feature.keys():
feature["id"] = i
# Only enforce size limit if ohsome API data is not provided
# Disable size limit for the Mapping Saturation indicator
# or for certain indicators
if isinstance(topic, TopicDefinition) and key not in [
"mapping-saturation",
"currentness",
Expand Down
28 changes: 0 additions & 28 deletions ohsome_quality_api/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,6 @@
from schema import SchemaError


class ValidationError(Exception):
def __init__(self):
self.name = ""
self.message = ""


class AttributeTopicCombinationError(ValidationError):
"""Invalid attribute topic combination error."""

def __init__(self, attribute, topic, valid_attribute_names):
self.name = "AttributeTopicCombinationError"
self.message = (
"Invalid combination of attribute and topic: {} and {}. "
"Topic '{}' supports these attributes: {}"
).format(attribute, topic, topic, valid_attribute_names)


class IndicatorTopicCombinationError(ValidationError):
"""Invalid indicator topic combination error."""

def __init__(self, indicator, topic):
self.name = "IndicatorTopicCombinationError"
self.message = "Invalid combination of indicator and topic: {} and {}".format(
indicator,
topic,
)


class OhsomeApiError(Exception):
"""Request to ohsome API failed."""

Expand Down
28 changes: 1 addition & 27 deletions ohsome_quality_api/utils/validators.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,10 @@
from geojson import Feature

from ohsome_quality_api.attributes.definitions import (
AttributeEnum,
get_attributes,
)
from ohsome_quality_api.config import get_config_value
from ohsome_quality_api.indicators.definitions import get_valid_indicators
from ohsome_quality_api.topics.definitions import TopicEnum
from ohsome_quality_api.utils.exceptions import (
AttributeTopicCombinationError,
IndicatorTopicCombinationError,
SizeRestrictionError,
)
from ohsome_quality_api.utils.exceptions import SizeRestrictionError
from ohsome_quality_api.utils.helper_geo import calculate_area


def validate_attribute_topic_combination(attribute: AttributeEnum, topic: TopicEnum):
"""As attributes are only meaningful for a certain topic,
we need to check if the given combination is valid."""

valid_attributes_for_topic = get_attributes()[topic]
valid_attribute_names = [attribute for attribute in valid_attributes_for_topic]

if attribute not in valid_attributes_for_topic:
raise AttributeTopicCombinationError(attribute, topic, valid_attribute_names)


def validate_indicator_topic_combination(indicator: str, topic: str):
if indicator not in get_valid_indicators(topic):
raise IndicatorTopicCombinationError(indicator, topic)


def validate_area(feature: Feature):
"""Check area size of feature against size limit configuration value."""
size_limit = float(get_config_value("geom_size_limit"))
Expand Down
23 changes: 6 additions & 17 deletions tests/integrationtests/api/test_indicators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
"""

import pytest
from approvaltests.approvals import verify
from schema import Optional, Or, Schema

from ohsome_quality_api.attributes.definitions import get_attributes
from tests.integrationtests.utils import oqapi_vcr
from tests.integrationtests.utils import PytestNamer, oqapi_vcr

ENDPOINT = "/indicators/"

Expand Down Expand Up @@ -134,11 +134,11 @@ def test_indicators_attribute_completeness(
assert schema.is_valid(response.json())


@pytest.mark.usefixtures("schema")
def test_indicators_attribute_completeness_without_attribute(
client,
bpolys,
headers,
schema,
):
endpoint = ENDPOINT + "attribute-completeness"
parameters = {
Expand All @@ -151,11 +151,11 @@ def test_indicators_attribute_completeness_without_attribute(
assert content["type"] == "RequestValidationError"


@pytest.mark.usefixtures("schema")
def test_indicators_attribute_completeness_with_invalid_attribute_for_topic(
client,
bpolys,
headers,
schema,
):
endpoint = ENDPOINT + "attribute-completeness"
parameters = {
Expand All @@ -167,19 +167,8 @@ def test_indicators_attribute_completeness_with_invalid_attribute_for_topic(
response = client.post(endpoint, json=parameters, headers=headers)
assert response.status_code == 422
content = response.json()

message = content["detail"][0]["msg"]
all_attributes_for_topic = [
attribute for attribute in (get_attributes()["building-count"])
]

expected = (
"Invalid combination of attribute and topic: maxspeed and building-count. "
"Topic 'building-count' supports these attributes: {}"
).format(all_attributes_for_topic)

assert message == expected
assert content["type"] == "AttributeTopicCombinationError"
assert content["type"] == "RequestValidationError"
verify(content["detail"][0]["msg"], namer=PytestNamer())


@oqapi_vcr.use_cassette
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Value error, Invalid combination of attribute maxspeed and topic building-count. Topic building-count supports these attributes: height, house-number, address-street, address-city, address-postcode, address-country, address-state, address-suburb, address-district, address-housenumber, building-levels, roof-shape, roof-levels, building-material, roof-material, roof-colour, building-colour
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Value error, Invalid combination of attribute maxspeed and topic building-count. Topic building-count supports these attributes: height, house-number, address-street, address-city, address-postcode, address-country, address-state, address-suburb, address-district, address-housenumber, building-levels, roof-shape, roof-levels, building-material, roof-material, roof-colour, building-colour
Loading

0 comments on commit de22d59

Please sign in to comment.