Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kylechard authored Oct 11, 2024
1 parent 5d1aed6 commit f04fdaf
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions part1/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,17 @@ def train_mnist_model(epochs):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model-path', type=str,
required=False, default='/tmp/mnist_model/')
required=False, default='/tmp/mnist_model/mnist_model.keras')
parser.add_argument('--epochs', type=int, required=False, default=2)
args = parser.parse_args()

model_dir = os.path.dirname(args.model_path)
if not os.path.exists(model_dir):
os.makedirs(model_dir)

print('Training MNIST model')
model = train_mnist_model(epochs=args.epochs)

print('Saving model to', args.model_path)
tf.saved_model.save(model, args.model_path)
model.save(args.model_path)

0 comments on commit f04fdaf

Please sign in to comment.