diff --git a/ctlearn/output_handler.py b/ctlearn/output_handler.py index 84b1af62..0808b121 100644 --- a/ctlearn/output_handler.py +++ b/ctlearn/output_handler.py @@ -5,7 +5,7 @@ from astropy.table import Table from astropy.coordinates import SkyCoord import astropy.units as u - +from ctapipe.io.pointing import PointingInterpolator def write_output(h5file, data, rest_data, reader, predictions, tasks): prediction_dir = h5file.replace(f'{h5file.split("/")[-1]}', "") @@ -144,13 +144,28 @@ def write_output(h5file, data, rest_data, reader, predictions, tasks): # so there is only one tel_id in the file. # For stereo we should fix this directly in the ctapipe plugin, # which is currently under development. - for tel_id in reader.telescope_pointings: - tel_pointing = reader.telescope_pointings[tel_id] + pointing_interpolator = PointingInterpolator() + tel_id_int = None + for tel_id, pointing_table in reader.telescope_pointings.items(): + tel_id_int = int(tel_id.replace("tel_", "")) + pointing_interpolator.add_table(tel_id_int, pointing_table) + trigger_info = reader.tel_trigger_table[reader.tel_trigger_table["tel_id"]== tel_id_int] + # Check if the number of predictions and trigger info match + # Actually this check is redundant since the dl1dh do not allow quality cuts when processing real data + # However, it is still good to have it here in case table are not properly filled. + if len(predictions[:, 0]) != len(trigger_info): + raise ValueError( + f"The number of predictions ({len(predictions[:, 0])}) and trigger info ({len(trigger_info)}) do not match." + ) + event_id, obs_id, tel_id = [], [], [] reco_az, reco_alt = [], [] pointing_az, pointing_alt, time = [], [], [] for i, (az_off, alt_off) in enumerate(zip(predictions[:, 0], predictions[:, 1])): - tel_az = tel_pointing[i]['azimuth'] * u.rad - tel_alt = tel_pointing[i]['altitude'] * u.rad + event_id = trigger_info[i]['event_id'] + obs_id = trigger_info[i]['obs_id'] + tel_id = trigger_info[i]['tel_id'] + time = trigger_info[i]['time'] + tel_alt, tel_az = pointing_interpolator(tel_id_int, time) pointing = SkyCoord( tel_az.to_value(data.drc_unit), tel_alt.to_value(data.drc_unit), @@ -165,12 +180,14 @@ def write_output(h5file, data, rest_data, reader, predictions, tasks): reco_alt.append(reco_direction.alt.to_value(data.drc_unit)) pointing_az.append(tel_az.to_value(u.deg)) pointing_alt.append(tel_alt.to_value(u.deg)) - time.append(tel_pointing[i]['time']) + reco["event_id"] = np.array(event_id) + reco["obs_id"] = np.array(obs_id) + reco["tel_id"] = np.array(tel_id) + reco["time"] = np.array(time) reco["reco_az"] = np.array(reco_az) reco["reco_alt"] = np.array(reco_alt) reco["pointing_az"] = np.array(pointing_az) reco["pointing_alt"] = np.array(pointing_alt) - reco["time"] = np.array(time) reco["reco_sep"] = np.array(predictions[:, 2]) else: reco_az, reco_alt = [], []