Skip to content

Commit

Permalink
add conftest.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ronylpatil committed Mar 29, 2024
1 parent 5124887 commit f4eef83
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
9 changes: 7 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,27 @@
X_test = test_data.drop(columns = [TARGET]).values
y = test_data[TARGET]

# testing load_data function
# fixture with params
# return the output of extract_data function
@pytest.fixture(params = [drive_])
def load_dataset(request) :
# calling extract_data method & returning its output
return extract_data(request.param)

# return output of feat_eng function
@pytest.fixture(params = [data_source])
def load(request) :
return feat_eng(load_data(request.param))

# test model training functionality
# fixture without parameters
# return output of train_model function
@pytest.fixture
def get_model() :
return train_model(X_train, Y, parameters['n_estimators'], parameters['criterion'], parameters['max_depth'], min_samples_leaf = parameters['min_samples_leaf'],
min_samples_split = parameters['min_samples_split'], random_state = parameters['random_state'], yaml_file_obj = params)

# TestClient(app) is used to create a test client object for simulating/
# HTTP requests to a FastAPI application during testing
@pytest.fixture
def client() :
return TestClient(app)
21 changes: 12 additions & 9 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
'sulphates': 0.65,
'alcohol': '11.4'}

# pass fixture as parameter
# test load_data function
def test_load_data(load_dataset) :
# load_dataset is pointing to the dataset returned by extract_data method
if isinstance(load_dataset, pd.DataFrame) :
Expand All @@ -87,55 +87,58 @@ def test_load_data(load_dataset) :
# ---- OR ----
# assert isinstance(load_dataset, pd.DataFrame), "Failure Message" # it will raise assertion error

# test build_features function
def test_build_features(load) :
if load.shape[1] == 20 :
assert load.shape[1] == 20
else :
pytest.fail('build_feature is not working properly')

# test train_model function
def test_train_model(get_model) :
if isinstance(get_model['model'], BaseEstimator) :
assert isinstance(get_model['model'], BaseEstimator)
else :
pytest.fail('not able to train the model')

# test ploting functionality
# test conf_metrix function
def test_conf_metrix(get_model) :
filename = conf_matrix(y_test = y, y_pred = get_model['model'].predict(X_test), labels = get_model['model'].classes_, path = './tests/trash', params_obj = params)
if pathlib.Path(filename).exists() :
assert pathlib.Path(filename).exists()
else :
pytest.fail('unable to generate confusion matrix')

# test response of api
# test API response
def test_response(client) :
response = client.get("/")
if response.status_code == 200 :
assert response.status_code == 200
assert response.json() == {'api status': 'up & running'}
else :
pytest.fail('api is not responding')
pytest.fail('API is not responding')

# test predict endpoint with valid input
# test api endpoint with valid input
def test_validIp(client) :
response = client.post("/predict", json = user_input1)
if response.status_code == 200 :
assert response.status_code == 200
assert 'predictions' in response.json()
assert 'probability' in response.json()
else :
pytest.fail('predict endpoint failed on valid input')
pytest.fail('"predict/" endpoint failed on valid input')

# test predict with invalid inputs
# test api endpoint with invalid inputs
def test_invalidIp(client) :
response1 = client.post("/predict", json = user_input2)
response2 = client.post("/predict", json = user_input3)
if response1.status_code == 422 and response2.status_code == 422 :
assert response1.status_code == 422
assert response2.status_code == 422
else :
pytest.fail('predict endpoint failed on invalid input')
pytest.fail('"predict/" endpoint failed on invalid input')

# test api endpoint output
def test_details(client) :
response = client.get("/details")
if response.status_code == 200 :
Expand All @@ -144,7 +147,7 @@ def test_details(client) :
assert 'model version' in response.json()
assert 'model aliases' in response.json()
else :
pytest.fails('details endpoint not responding')
pytest.fails('"details/" endpoint not responding')

'''
create & configure tox.ini file
Expand Down

0 comments on commit f4eef83

Please sign in to comment.