Skip to content

Commit

Permalink
fix predict_data
Browse files Browse the repository at this point in the history
after backbone/head split the model output is not longer a dict for regression task
  • Loading branch information
TjarkMiener committed Feb 12, 2025
1 parent 4d56346 commit e2a4c8a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ctlearn/tools/predict_LST1.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
ReconstructedGeometryContainer,
ReconstructedEnergyContainer,
)
from ctapipe.coordinates import EngineeringCameraFrame
from ctapipe.coordinates import CameraFrame, EngineeringCameraFrame
from ctapipe.core import Tool
from ctapipe.core.tool import ToolConfigurationError
from ctapipe.core.traits import (
Expand Down Expand Up @@ -487,13 +487,13 @@ def start(self):
energy_feature_vectors = self.backbone_energy.predict_on_batch(input_data)
energy_fvs.extend(energy_feature_vectors)
predict_data = self.head_energy.predict_on_batch(energy_feature_vectors)
energy.extend(predict_data["energy"])
energy.extend(predict_data.T[0])
if self.load_direction_model_from is not None:
direction_feature_vectors = self.backbone_direction.predict_on_batch(input_data)
direction_fvs.extend(direction_feature_vectors)
predict_data = self.head_direction.predict_on_batch(direction_feature_vectors)
az.extend(predict_data["direction"].T[0])
alt.extend(predict_data["direction"].T[1])
az.extend(predict_data.T[0])
alt.extend(predict_data.T[1])

# Create the prediction tables
example_identifiers = Table(
Expand Down

0 comments on commit e2a4c8a

Please sign in to comment.