Skip to content

Commit

Permalink
Update train_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rcervinoucm authored Feb 27, 2025
1 parent 6abd223 commit 4f24364
Showing 1 changed file with 38 additions and 15 deletions.
53 changes: 38 additions & 15 deletions ctlearn/tools/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,32 @@ class TrainCTLearnModel(Tool):
allow_none=False,
help="Set whether to save model in an ONNX file.",
).tag(config=True)

early_stopping = Bool(
default_value=False,
allow_none=True,
help="Set whether to have aerly stopping",
).tag(config=True)

early_stopping_patience = Int(
default_value=4,
allow_none=True,
help="EarlyStopping patience",
).tag(config=True)

early_stopping_metric = CaselessStrEnum(
['loss', 'val_loss', 'accuracy', 'val_accuracy', 'precision', 'val_precision', 'recall', 'val_recall', 'auc', 'val_auc', 'mae', 'val_mae', 'mse', 'val_mse', 'rmse', 'val_rmse', 'cosine_proximity', 'val_cosine_proximity', 'logcosh', 'val_logcosh'],
default_value="val_loss",
allow_none=True,
help="EarlyStopping monitor metric",
).tag(config=True)

early_stopping_restore_best = Bool(
default_value=True,
allow_none=True,
help="Set whether to save the best on early stopping or not",
).tag(config=True)


overwrite = Bool(help="Overwrite output dir if it exists").tag(config=True)

Expand Down Expand Up @@ -350,6 +376,17 @@ def setup(self):
filename=f"{self.output_dir}/training_log.csv", append=True
)
self.callbacks = [model_checkpoint_callback, tensorboard_callback, csv_logger_callback]

if self.early_stopping is True:
# EarlyStopping callback
early_stopping_callback = keras.callbacks.EarlyStopping(
monitor=self.early_stopping_metric,
patience=self.early_stopping_patience,
verbose=1,
restore_best_weights=self.early_stopping_restore_best
)
self.callbacks.append(early_stopping_callback)

# Learning rate reducing callback
if self.lr_reducing is not None:
# Validate the learning rate reducing parameters
Expand All @@ -366,6 +403,7 @@ def setup(self):
self.callbacks.append(lr_reducing_callback)



def start(self):

# Open a strategy scope.
Expand Down Expand Up @@ -517,18 +555,3 @@ def main():

if __name__ == "main":
main()















0 comments on commit 4f24364

Please sign in to comment.