diff --git a/ctlearn/core/tests/test_loader.py b/ctlearn/core/tests/test_loader.py index 55e71a9..97f8717 100644 --- a/ctlearn/core/tests/test_loader.py +++ b/ctlearn/core/tests/test_loader.py @@ -4,8 +4,9 @@ from dl1_data_handler.reader import DLImageReader from ctlearn.core.loader import DLDataLoader + def test_data_loader(dl1_tmp_path, dl1_gamma_file): - """check """ + """check""" # Create a configuration suitable for the test config = Config( { @@ -20,12 +21,17 @@ def test_data_loader(dl1_tmp_path, dl1_gamma_file): dl1_loader = DLDataLoader( DLDataReader=dl1_reader, indices=[0], - tasks=["type", "energy", "direction"], - batch_size=1 + tasks=["type", "energy", "cameradirection", "skydirection"], + batch_size=1, ) # Get the features and labels fgrom the data loader for one batch features, labels = dl1_loader[0] # Check that all the correct labels are present - assert "type" in labels and "energy" in labels and "direction" in labels + assert ( + "type" in labels + and "energy" in labels + and "cameradirection" in labels + and "skydirection" in labels + ) # Check the shape of the features - assert features["input"].shape == (1, 110, 110, 2) \ No newline at end of file + assert features["input"].shape == (1, 110, 110, 2)