Generate with Tensorflow demo
+
+ + Use this page to test your implementation with Tensorflow. Enter + text and receive the model output as a response. +
+From ac04605508be5c4dbb08a44e7f99132f8fa9b1fd Mon Sep 17 00:00:00 2001 From: loks0n <22452787+loks0n@users.noreply.github.com> Date: Wed, 29 May 2024 12:19:30 +0100 Subject: [PATCH 1/4] feat: python generate with tensorflow --- python/generate_with_tensorflow/.gitignore | 160 ++++++++++++++++++ python/generate_with_tensorflow/README.md | 63 +++++++ .../generate_with_tensorflow/requirements.txt | 2 + python/generate_with_tensorflow/src/main.py | 48 ++++++ python/generate_with_tensorflow/src/train.py | 75 ++++++++ python/generate_with_tensorflow/src/utils.py | 35 ++++ .../static/index.html | 91 ++++++++++ 7 files changed, 474 insertions(+) create mode 100644 python/generate_with_tensorflow/.gitignore create mode 100644 python/generate_with_tensorflow/README.md create mode 100644 python/generate_with_tensorflow/requirements.txt create mode 100644 python/generate_with_tensorflow/src/main.py create mode 100644 python/generate_with_tensorflow/src/train.py create mode 100644 python/generate_with_tensorflow/src/utils.py create mode 100644 python/generate_with_tensorflow/static/index.html diff --git a/python/generate_with_tensorflow/.gitignore b/python/generate_with_tensorflow/.gitignore new file mode 100644 index 00000000..68bc17f9 --- /dev/null +++ b/python/generate_with_tensorflow/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/python/generate_with_tensorflow/README.md b/python/generate_with_tensorflow/README.md new file mode 100644 index 00000000..bab0ace7 --- /dev/null +++ b/python/generate_with_tensorflow/README.md @@ -0,0 +1,63 @@ +# 🤖 Python Generate with TensorFlow Function + +Generate text using a TensorFlow-based RNN model. + +## 🧰 Usage + +### GET / + +HTML form for interacting with the function. + +### POST / + +Query the model for a text generation completion. + +**Parameters** + +| Name | Description | Location | Type | Sample Value | +| ------------ | ------------------------------------ | -------- | ------------------ | ------------------ | +| Content-Type | The content type of the request body | Header | `application/json` | N/A | +| prompt | Text to prompt the model | Body | String | `Once upon a time` | + +Sample `200` Response: + +Response from the model. + +```json +{ + "ok": true, + "completion": "Once upon a time, in a land far, far away, there lived a wise old owl." +} +``` + +Sample `400` Response: + +Response when the request body is missing. + +```json +{ + "ok": false, + "error": "Missing body with a prompt." +} +``` + +Sample `500` Response: + +Response when the model fails to respond. + +```json +{ + "ok": false, + "error": "Failed to query model." +} +``` + +## ⚙️ Configuration + +| Setting | Value | +| ----------------- | -------------------------------------------------------- | +| Runtime | Python ML (3.11) | +| Entrypoint | `src/main.py` | +| Build Commands | `pip install -r requirements.txt && python src/train.py` | +| Permissions | `any` | +| Timeout (Seconds) | 30 | diff --git a/python/generate_with_tensorflow/requirements.txt b/python/generate_with_tensorflow/requirements.txt new file mode 100644 index 00000000..95d00589 --- /dev/null +++ b/python/generate_with_tensorflow/requirements.txt @@ -0,0 +1,2 @@ +tensorflow +numpy \ No newline at end of file diff --git a/python/generate_with_tensorflow/src/main.py b/python/generate_with_tensorflow/src/main.py new file mode 100644 index 00000000..2c861b5f --- /dev/null +++ b/python/generate_with_tensorflow/src/main.py @@ -0,0 +1,48 @@ +import tensorflow as tf +import numpy as np +from .utils import get_static_file, throw_if_missing + + +def main(context): + if context.req.method == "GET": + return context.res.send( + get_static_file("index.html"), + 200, + {"content-type": "text/html; charset=utf-8"}, + ) + + try: + throw_if_missing(context.req.body, ["prompt"]) + except ValueError as err: + return context.res.json({"ok": False, "error": err.message}, 400) + + prompt = context.req.body["prompt"] + generated_text = generate_text(prompt) + return context.res.json({"ok": True, "completion": generated_text}, 200) + + +def generate_text(prompt): + # Load the trained model and tokenizer + model = tf.keras.models.load_model("text_generation_model.h5") + char2idx = np.load("char2idx.npy", allow_pickle=True).item() + idx2char = np.load("idx2char.npy", allow_pickle=True) + + # Vectorize the prompt + input_eval = [char2idx[s] for s in prompt] + input_eval = tf.expand_dims(input_eval, 0) + + # Generate text + text_generated = [] + temperature = 1.0 + + model.reset_states() + for _ in range(1000): + predictions = model(input_eval) + predictions = tf.squeeze(predictions, 0) + predictions = predictions / temperature + predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy() + + input_eval = tf.expand_dims([predicted_id], 0) + text_generated.append(idx2char[predicted_id]) + + return prompt + "".join(text_generated) diff --git a/python/generate_with_tensorflow/src/train.py b/python/generate_with_tensorflow/src/train.py new file mode 100644 index 00000000..54e65806 --- /dev/null +++ b/python/generate_with_tensorflow/src/train.py @@ -0,0 +1,75 @@ +import tensorflow as tf +import numpy as np +import os + + +def main(): + path_to_file = tf.keras.utils.get_file( + "shakespeare.txt", + "https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt", + ) + text = open(path_to_file, "rb").read().decode(encoding="utf-8") + vocab = sorted(set(text)) + char2idx = {u: i for i, u in enumerate(vocab)} + idx2char = np.array(vocab) + + text_as_int = np.array([char2idx[c] for c in text]) + seq_length = 100 + char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int) + sequences = char_dataset.batch(seq_length + 1, drop_remainder=True) + + def split_input_target(chunk): + input_text = chunk[:-1] + target_text = chunk[1:] + return input_text, target_text + + dataset = sequences.map(split_input_target) + BATCH_SIZE = 64 + BUFFER_SIZE = 10000 + dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True) + + vocab_size = len(vocab) + embedding_dim = 256 + rnn_units = 1024 + + model = tf.keras.Sequential( + [ + tf.keras.layers.Embedding( + vocab_size, embedding_dim, batch_input_shape=[BATCH_SIZE, None] + ), + tf.keras.layers.GRU( + rnn_units, + return_sequences=True, + stateful=True, + recurrent_initializer="glorot_uniform", + ), + tf.keras.layers.Dense(vocab_size), + ] + ) + + def loss(labels, logits): + return tf.keras.losses.sparse_categorical_crossentropy( + labels, logits, from_logits=True + ) + + model.compile(optimizer="adam", loss=loss) + + EPOCHS = 10 + checkpoint_dir = "./training_checkpoints" + checkpoint_prefix = f"{checkpoint_dir}/ckpt_{{epoch}}" + + checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( + filepath=checkpoint_prefix, save_weights_only=True + ) + + model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback]) + + model.save("text_generation_model.h5") + np.save("char2idx.npy", char2idx) + np.save("idx2char.npy", idx2char) + + os.remove(path_to_file) + + +if __name__ == "__main__": + main() diff --git a/python/generate_with_tensorflow/src/utils.py b/python/generate_with_tensorflow/src/utils.py new file mode 100644 index 00000000..f5fae37d --- /dev/null +++ b/python/generate_with_tensorflow/src/utils.py @@ -0,0 +1,35 @@ +import os + +__dirname = os.path.dirname(os.path.abspath(__file__)) +static_folder = os.path.join(__dirname, "../static") + + +def get_static_file(file_name: str) -> str: + """ + Returns the contents of a file in the static folder + + Parameters: + file_name (str): Name of the file to read + + Returns: + (str): Contents of static/{file_name} + """ + file_path = os.path.join(static_folder, file_name) + with open(file_path, "r") as file: + return file.read() + + +def throw_if_missing(obj: object, keys: list[str]) -> None: + """ + Throws an error if any of the keys are missing from the object + + Parameters: + obj (object): Object to check + keys (list[str]): List of keys to check + + Raises: + ValueError: If any keys are missing + """ + missing = [key for key in keys if key not in obj or not obj[key]] + if missing: + raise ValueError(f"Missing required fields: {', '.join(missing)}") diff --git a/python/generate_with_tensorflow/static/index.html b/python/generate_with_tensorflow/static/index.html new file mode 100644 index 00000000..27a35f60 --- /dev/null +++ b/python/generate_with_tensorflow/static/index.html @@ -0,0 +1,91 @@ + + +
+ + + +
+ + Use this page to test your implementation with Tensorflow. Enter + text and receive the model output as a response. +
+