Skip to content

Commit

Permalink
add conftest.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ronylpatil committed Mar 26, 2024
1 parent 55fdeb3 commit 5124887
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest
import pathlib
import yaml
from fastapi.testclient import TestClient
from prod.api import app
from src.data.load_dataset import extract_data
from src.features.build_features import feat_eng
from src.data.make_dataset import load_data
from src.models.train_model import train_model

curr_dir = pathlib.Path(__file__)
home_dir = curr_dir.parent.parent
params = yaml.safe_load(open(f"{home_dir.as_posix()}/params.yaml"))
drive_ = params['load_dataset']['drive_link']
processed_data = params['make_dataset']['processed_data']
data_source = f'{home_dir.as_posix()}{processed_data}/train.csv'
parameters = params['train_model']
TARGET = params['base']['target']
train_dir = f"{home_dir.as_posix()}{params['build_features']['extended_data']}/extended_train.csv"
test_dir = f"{home_dir.as_posix()}{params['build_features']['extended_data']}/extended_test.csv"
model_dir = f"{home_dir.as_posix()}{parameters['model_dir']}"
pathlib.Path(model_dir).mkdir(parents = True, exist_ok = True)
train_data = load_data(train_dir)
test_data = load_data(test_dir)
X_train = train_data.drop(columns = [TARGET]).values
Y = train_data[TARGET]
X_test = test_data.drop(columns = [TARGET]).values
y = test_data[TARGET]

# testing load_data function
@pytest.fixture(params = [drive_])
def load_dataset(request) :
# calling extract_data method & returning its output
return extract_data(request.param)

@pytest.fixture(params = [data_source])
def load(request) :
return feat_eng(load_data(request.param))

# test model training functionality
@pytest.fixture
def get_model() :
return train_model(X_train, Y, parameters['n_estimators'], parameters['criterion'], parameters['max_depth'], min_samples_leaf = parameters['min_samples_leaf'],
min_samples_split = parameters['min_samples_split'], random_state = parameters['random_state'], yaml_file_obj = params)

@pytest.fixture
def client() :
return TestClient(app)

0 comments on commit 5124887

Please sign in to comment.