Skip to content

Commit

Permalink
add pointing interpolator
Browse files Browse the repository at this point in the history
  • Loading branch information
TjarkMiener committed Jul 17, 2024
1 parent f5eddf3 commit 6d94cbc
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions ctlearn/output_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from astropy.table import Table
from astropy.coordinates import SkyCoord
import astropy.units as u

from ctapipe.io.pointing import PointingInterpolator

def write_output(h5file, data, rest_data, reader, predictions, tasks):
prediction_dir = h5file.replace(f'{h5file.split("/")[-1]}', "")
Expand Down Expand Up @@ -144,13 +144,28 @@ def write_output(h5file, data, rest_data, reader, predictions, tasks):
# so there is only one tel_id in the file.
# For stereo we should fix this directly in the ctapipe plugin,
# which is currently under development.
for tel_id in reader.telescope_pointings:
tel_pointing = reader.telescope_pointings[tel_id]
pointing_interpolator = PointingInterpolator()
tel_id_int = None
for tel_id, pointing_table in reader.telescope_pointings.items():
tel_id_int = int(tel_id.replace("tel_", ""))
pointing_interpolator.add_table(tel_id_int, pointing_table)
trigger_info = reader.tel_trigger_table[reader.tel_trigger_table["tel_id"]== tel_id_int]
# Check if the number of predictions and trigger info match
# Actually this check is redundant since the dl1dh do not allow quality cuts when processing real data
# However, it is still good to have it here in case table are not properly filled.
if len(predictions[:, 0]) != len(trigger_info):
raise ValueError(
f"The number of predictions ({len(predictions[:, 0])}) and trigger info ({len(trigger_info)}) do not match."
)
event_id, obs_id, tel_id = [], [], []
reco_az, reco_alt = [], []
pointing_az, pointing_alt, time = [], [], []
for i, (az_off, alt_off) in enumerate(zip(predictions[:, 0], predictions[:, 1])):
tel_az = tel_pointing[i]['azimuth'] * u.rad
tel_alt = tel_pointing[i]['altitude'] * u.rad
event_id = trigger_info[i]['event_id']
obs_id = trigger_info[i]['obs_id']
tel_id = trigger_info[i]['tel_id']
time = trigger_info[i]['time']
tel_alt, tel_az = pointing_interpolator(tel_id_int, time)
pointing = SkyCoord(
tel_az.to_value(data.drc_unit),
tel_alt.to_value(data.drc_unit),
Expand All @@ -165,12 +180,14 @@ def write_output(h5file, data, rest_data, reader, predictions, tasks):
reco_alt.append(reco_direction.alt.to_value(data.drc_unit))
pointing_az.append(tel_az.to_value(u.deg))
pointing_alt.append(tel_alt.to_value(u.deg))
time.append(tel_pointing[i]['time'])
reco["event_id"] = np.array(event_id)
reco["obs_id"] = np.array(obs_id)
reco["tel_id"] = np.array(tel_id)
reco["time"] = np.array(time)
reco["reco_az"] = np.array(reco_az)
reco["reco_alt"] = np.array(reco_alt)
reco["pointing_az"] = np.array(pointing_az)
reco["pointing_alt"] = np.array(pointing_alt)
reco["time"] = np.array(time)
reco["reco_sep"] = np.array(predictions[:, 2])
else:
reco_az, reco_alt = [], []
Expand Down

0 comments on commit 6d94cbc

Please sign in to comment.