Skip to content

Commit

Permalink
bug fixes and improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
ronylpatil committed Mar 26, 2024
1 parent a160cd7 commit 55fdeb3
Showing 1 changed file with 31 additions and 56 deletions.
87 changes: 31 additions & 56 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from src.models.train_model import train_model
from src.visualization.visualize import conf_matrix


curr_dir = pathlib.Path(__file__)
home_dir = curr_dir.parent.parent

Expand All @@ -20,72 +21,19 @@
data_source = f'{home_dir.as_posix()}{processed_data}/train.csv'

drive_ = params['load_dataset']['drive_link']

# test data loading functionality
@pytest.fixture(params = [drive_])
def load_dataset(request) :
return extract_data(request.param)

def test_load_data(load_dataset) :
if isinstance(load_dataset, pd.DataFrame) :
assert isinstance(load_dataset, pd.DataFrame)
else :
pytest.fail('Unable to fetch data from remote server')

# ---- OR ----
# assert isinstance(load_dataset, pd.DataFrame), "Failure Message" # it will raise assertion error

# test feature generation functionality
# 'request' is a fixture that provides information about the currently executing test.
# It allows us to access various attributes and methods to get information about the test
# environment, access fixtures, and more.
# here we are accessing current parameters value using request.param
@pytest.fixture(params = [data_source])
def load(request) :
return feat_eng(load_data(request.param))

def test_build_features(load) :
if load.shape[1] == 20 :
assert load.shape[1] == 20
else :
pytest.fail('build_feature is not working properly')

parameters = params['train_model']
TARGET = params['base']['target']

train_dir = f"{home_dir.as_posix()}{params['build_features']['extended_data']}/extended_train.csv"
test_dir = f"{home_dir.as_posix()}{params['build_features']['extended_data']}/extended_test.csv"
model_dir = f"{home_dir.as_posix()}{parameters['model_dir']}"
pathlib.Path(model_dir).mkdir(parents = True, exist_ok = True)

train_data = load_data(train_dir)
test_data = load_data(test_dir)
X_train = train_data.drop(columns = [TARGET]).values
Y = train_data[TARGET]

X_test = test_data.drop(columns = [TARGET]).values
y = test_data[TARGET]

# test model training functionality
@pytest.fixture
def get_model() :
return train_model(X_train, Y, parameters['n_estimators'], parameters['criterion'], parameters['max_depth'], min_samples_leaf = parameters['min_samples_leaf'],
min_samples_split = parameters['min_samples_split'], random_state = parameters['random_state'], yaml_file_obj = params)

def test_train_model(get_model) :
if isinstance(get_model['model'], BaseEstimator) :
assert isinstance(get_model['model'], BaseEstimator)
else :
pytest.fail('not able to train the model')

# test ploting functionality
def test_conf_metrix(get_model) :
filename = conf_matrix(y_test = y, y_pred = get_model['model'].predict(X_test), labels = get_model['model'].classes_, path = './tests/trash', params_obj = params)
if pathlib.Path(filename).exists() :
assert pathlib.Path(filename).exists()
else :
pytest.fail('unable to generate confusion matrix')

# test api endpoint
# valid input
user_input1 = {
Expand Down Expand Up @@ -128,9 +76,36 @@ def test_conf_metrix(get_model) :
'sulphates': 0.65,
'alcohol': '11.4'}

@pytest.fixture
def client() :
return TestClient(app)
# pass fixture as parameter
def test_load_data(load_dataset) :
# load_dataset is pointing to the dataset returned by extract_data method
if isinstance(load_dataset, pd.DataFrame) :
assert isinstance(load_dataset, pd.DataFrame) # checking load_dataset is instance of pandas dataframe, if yes testcase will pass
else :
pytest.fail('Unable to fetch data from remote server') # else test case will fail and return the message

# ---- OR ----
# assert isinstance(load_dataset, pd.DataFrame), "Failure Message" # it will raise assertion error

def test_build_features(load) :
if load.shape[1] == 20 :
assert load.shape[1] == 20
else :
pytest.fail('build_feature is not working properly')

def test_train_model(get_model) :
if isinstance(get_model['model'], BaseEstimator) :
assert isinstance(get_model['model'], BaseEstimator)
else :
pytest.fail('not able to train the model')

# test ploting functionality
def test_conf_metrix(get_model) :
filename = conf_matrix(y_test = y, y_pred = get_model['model'].predict(X_test), labels = get_model['model'].classes_, path = './tests/trash', params_obj = params)
if pathlib.Path(filename).exists() :
assert pathlib.Path(filename).exists()
else :
pytest.fail('unable to generate confusion matrix')

# test response of api
def test_response(client) :
Expand Down

0 comments on commit 55fdeb3

Please sign in to comment.