diff --git a/part1/train.py b/part1/train.py index cbd3c1d..0b5c35c 100644 --- a/part1/train.py +++ b/part1/train.py @@ -40,12 +40,17 @@ def train_mnist_model(epochs): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--model-path', type=str, - required=False, default='/tmp/mnist_model/') + required=False, default='/tmp/mnist_model/mnist_model.keras') parser.add_argument('--epochs', type=int, required=False, default=2) args = parser.parse_args() + model_dir = os.path.dirname(args.model_path) + if not os.path.exists(model_dir): + os.makedirs(model_dir) + print('Training MNIST model') model = train_mnist_model(epochs=args.epochs) print('Saving model to', args.model_path) - tf.saved_model.save(model, args.model_path) + model.save(args.model_path) +