From 55fdeb3b8f0b526cdcdb4e5576dab6def768840d Mon Sep 17 00:00:00 2001 From: ronilpatil Date: Wed, 27 Mar 2024 01:17:53 +0530 Subject: [PATCH] bug fixes and improvement --- tests/test_main.py | 87 +++++++++++++++++----------------------------- 1 file changed, 31 insertions(+), 56 deletions(-) diff --git a/tests/test_main.py b/tests/test_main.py index 2957930..fc94848 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -11,6 +11,7 @@ from src.models.train_model import train_model from src.visualization.visualize import conf_matrix + curr_dir = pathlib.Path(__file__) home_dir = curr_dir.parent.parent @@ -20,72 +21,19 @@ data_source = f'{home_dir.as_posix()}{processed_data}/train.csv' drive_ = params['load_dataset']['drive_link'] - -# test data loading functionality -@pytest.fixture(params = [drive_]) -def load_dataset(request) : - return extract_data(request.param) - -def test_load_data(load_dataset) : - if isinstance(load_dataset, pd.DataFrame) : - assert isinstance(load_dataset, pd.DataFrame) - else : - pytest.fail('Unable to fetch data from remote server') - - # ---- OR ---- - # assert isinstance(load_dataset, pd.DataFrame), "Failure Message" # it will raise assertion error - -# test feature generation functionality -# 'request' is a fixture that provides information about the currently executing test. -# It allows us to access various attributes and methods to get information about the test -# environment, access fixtures, and more. -# here we are accessing current parameters value using request.param -@pytest.fixture(params = [data_source]) -def load(request) : - return feat_eng(load_data(request.param)) - -def test_build_features(load) : - if load.shape[1] == 20 : - assert load.shape[1] == 20 - else : - pytest.fail('build_feature is not working properly') - parameters = params['train_model'] TARGET = params['base']['target'] - train_dir = f"{home_dir.as_posix()}{params['build_features']['extended_data']}/extended_train.csv" test_dir = f"{home_dir.as_posix()}{params['build_features']['extended_data']}/extended_test.csv" model_dir = f"{home_dir.as_posix()}{parameters['model_dir']}" pathlib.Path(model_dir).mkdir(parents = True, exist_ok = True) - train_data = load_data(train_dir) test_data = load_data(test_dir) X_train = train_data.drop(columns = [TARGET]).values Y = train_data[TARGET] - X_test = test_data.drop(columns = [TARGET]).values y = test_data[TARGET] -# test model training functionality -@pytest.fixture -def get_model() : - return train_model(X_train, Y, parameters['n_estimators'], parameters['criterion'], parameters['max_depth'], min_samples_leaf = parameters['min_samples_leaf'], - min_samples_split = parameters['min_samples_split'], random_state = parameters['random_state'], yaml_file_obj = params) - -def test_train_model(get_model) : - if isinstance(get_model['model'], BaseEstimator) : - assert isinstance(get_model['model'], BaseEstimator) - else : - pytest.fail('not able to train the model') - -# test ploting functionality -def test_conf_metrix(get_model) : - filename = conf_matrix(y_test = y, y_pred = get_model['model'].predict(X_test), labels = get_model['model'].classes_, path = './tests/trash', params_obj = params) - if pathlib.Path(filename).exists() : - assert pathlib.Path(filename).exists() - else : - pytest.fail('unable to generate confusion matrix') - # test api endpoint # valid input user_input1 = { @@ -128,9 +76,36 @@ def test_conf_metrix(get_model) : 'sulphates': 0.65, 'alcohol': '11.4'} -@pytest.fixture -def client() : - return TestClient(app) +# pass fixture as parameter +def test_load_data(load_dataset) : + # load_dataset is pointing to the dataset returned by extract_data method + if isinstance(load_dataset, pd.DataFrame) : + assert isinstance(load_dataset, pd.DataFrame) # checking load_dataset is instance of pandas dataframe, if yes testcase will pass + else : + pytest.fail('Unable to fetch data from remote server') # else test case will fail and return the message + + # ---- OR ---- + # assert isinstance(load_dataset, pd.DataFrame), "Failure Message" # it will raise assertion error + +def test_build_features(load) : + if load.shape[1] == 20 : + assert load.shape[1] == 20 + else : + pytest.fail('build_feature is not working properly') + +def test_train_model(get_model) : + if isinstance(get_model['model'], BaseEstimator) : + assert isinstance(get_model['model'], BaseEstimator) + else : + pytest.fail('not able to train the model') + +# test ploting functionality +def test_conf_metrix(get_model) : + filename = conf_matrix(y_test = y, y_pred = get_model['model'].predict(X_test), labels = get_model['model'].classes_, path = './tests/trash', params_obj = params) + if pathlib.Path(filename).exists() : + assert pathlib.Path(filename).exists() + else : + pytest.fail('unable to generate confusion matrix') # test response of api def test_response(client) :