diff --git a/ctlearn/tools/train_model.py b/ctlearn/tools/train_model.py index c4e17b5..035bd37 100644 --- a/ctlearn/tools/train_model.py +++ b/ctlearn/tools/train_model.py @@ -200,6 +200,32 @@ class TrainCTLearnModel(Tool): allow_none=False, help="Set whether to save model in an ONNX file.", ).tag(config=True) + + early_stopping = Bool( + default_value=False, + allow_none=True, + help="Set whether to have aerly stopping", + ).tag(config=True) + + early_stopping_patience = Int( + default_value=4, + allow_none=True, + help="EarlyStopping patience", + ).tag(config=True) + + early_stopping_metric = CaselessStrEnum( + ['loss', 'val_loss', 'accuracy', 'val_accuracy', 'precision', 'val_precision', 'recall', 'val_recall', 'auc', 'val_auc', 'mae', 'val_mae', 'mse', 'val_mse', 'rmse', 'val_rmse', 'cosine_proximity', 'val_cosine_proximity', 'logcosh', 'val_logcosh'], + default_value="val_loss", + allow_none=True, + help="EarlyStopping monitor metric", + ).tag(config=True) + + early_stopping_restore_best = Bool( + default_value=True, + allow_none=True, + help="Set whether to save the best on early stopping or not", + ).tag(config=True) + overwrite = Bool(help="Overwrite output dir if it exists").tag(config=True) @@ -350,6 +376,17 @@ def setup(self): filename=f"{self.output_dir}/training_log.csv", append=True ) self.callbacks = [model_checkpoint_callback, tensorboard_callback, csv_logger_callback] + + if self.early_stopping is True: + # EarlyStopping callback + early_stopping_callback = keras.callbacks.EarlyStopping( + monitor=self.early_stopping_metric, + patience=self.early_stopping_patience, + verbose=1, + restore_best_weights=self.early_stopping_restore_best + ) + self.callbacks.append(early_stopping_callback) + # Learning rate reducing callback if self.lr_reducing is not None: # Validate the learning rate reducing parameters @@ -366,6 +403,7 @@ def setup(self): self.callbacks.append(lr_reducing_callback) + def start(self): # Open a strategy scope. @@ -517,18 +555,3 @@ def main(): if __name__ == "main": main() - - - - - - - - - - - - - - -