Skip to content

Commit

Permalink
Merge pull request #292 from appwrite/feat-generate-with-tensorflow
Browse files Browse the repository at this point in the history
feat: python generate with tensorflow
  • Loading branch information
christyjacob4 authored Jun 11, 2024
2 parents 882f67b + cdd4c06 commit bb73e15
Show file tree
Hide file tree
Showing 7 changed files with 474 additions and 0 deletions.
160 changes: 160 additions & 0 deletions python-ml/generate_with_tensorflow/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
63 changes: 63 additions & 0 deletions python-ml/generate_with_tensorflow/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# 🤖 Python Generate with TensorFlow Function

Generate text using a TensorFlow-based RNN model.

## 🧰 Usage

### GET /

HTML form for interacting with the function.

### POST /

Query the model for a text generation completion.

**Parameters**

| Name | Description | Location | Type | Sample Value |
| ------------ | ------------------------------------ | -------- | ------------------ | ------------------ |
| Content-Type | The content type of the request body | Header | `application/json` | N/A |
| prompt | Text to prompt the model | Body | String | `Once upon a time` |

Sample `200` Response:

Response from the model.

```json
{
"ok": true,
"completion": "Once upon a time, in a land far, far away, there lived a wise old owl."
}
```

Sample `400` Response:

Response when the request body is missing or does not contain a non-empty `prompt`.

```json
{
"ok": false,
"error": "Missing required fields: prompt"
}
```

Sample `500` Response:

Response when the model fails to respond.

```json
{
"ok": false,
"error": "Failed to query model."
}
```

## ⚙️ Configuration

| Setting | Value |
| ----------------- | -------------------------------------------------------- |
| Runtime | Python ML (3.11) |
| Entrypoint | `src/main.py` |
| Build Commands | `pip install -r requirements.txt && python src/train.py` |
| Permissions | `any` |
| Timeout (Seconds) | 30 |
2 changes: 2 additions & 0 deletions python-ml/generate_with_tensorflow/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
tensorflow
numpy
48 changes: 48 additions & 0 deletions python-ml/generate_with_tensorflow/src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import tensorflow as tf
import numpy as np
from .utils import get_static_file, throw_if_missing


def main(context):
    """Handle an incoming function request.

    ``GET`` requests receive the static HTML form. Every other request is
    treated as a completion query: the body must contain a non-empty
    ``prompt``, which is passed to ``generate_text``.

    Parameters:
        context: Runtime-provided request context. This function reads
            ``context.req.method`` and ``context.req.body`` and replies through
            ``context.res.send`` / ``context.res.json``.

    Returns:
        The value returned by ``context.res``:
        the HTML page with status 200 for ``GET``;
        ``{"ok": False, "error": ...}`` with status 400 when ``prompt`` is
        missing; the same shape with status 500 when text generation fails;
        ``{"ok": True, "completion": ...}`` with status 200 otherwise.
    """
    if context.req.method == "GET":
        return context.res.send(
            get_static_file("index.html"),
            200,
            {"content-type": "text/html; charset=utf-8"},
        )

    try:
        throw_if_missing(context.req.body, ["prompt"])
    except ValueError as err:
        return context.res.json({"ok": False, "error": str(err)}, 400)

    prompt = context.req.body["prompt"]

    # Top-level boundary: any failure while loading or running the model is
    # reported as the documented 500 response instead of an unhandled crash.
    try:
        generated_text = generate_text(prompt)
    except Exception as err:
        # Keep the real cause in the function logs; the client only gets a
        # generic message.
        context.error(f"Text generation failed: {err!r}")
        return context.res.json(
            {"ok": False, "error": "Failed to query model."}, 500
        )

    return context.res.json({"ok": True, "completion": generated_text}, 200)


def generate_text(prompt, num_generate=1000, temperature=1.0):
    """Continue ``prompt`` one character at a time using the trained model.

    The model and the character lookup tables are loaded from
    ``text_generation_model.h5``, ``char2idx.npy`` and ``idx2char.npy`` in the
    current working directory on every call.

    Parameters:
        prompt (str): Seed text. Every character must be present in the saved
            ``char2idx`` table.
        num_generate (int): Number of characters to sample. Defaults to 1000.
        temperature (float): Divisor applied to the logits before sampling;
            must be greater than 0. Defaults to 1.0.

    Returns:
        (str): ``prompt`` followed by the sampled characters.

    Raises:
        ValueError: If ``prompt`` is empty, ``temperature`` is not positive, or
            ``prompt`` contains a character missing from ``char2idx``.
    """
    if not prompt:
        raise ValueError("Prompt must not be empty.")
    if temperature <= 0:
        raise ValueError("Temperature must be greater than 0.")

    # NOTE(review): allow_pickle=True unpickles arbitrary objects on load. That
    # is only safe while these files are artifacts produced by the project's
    # own build step; never point this at untrusted files.
    char2idx = np.load("char2idx.npy", allow_pickle=True).item()
    idx2char = np.load("idx2char.npy", allow_pickle=True)

    # Validate before the (expensive) model load so a bad prompt fails fast
    # with a readable message rather than a bare KeyError.
    unknown = sorted({ch for ch in prompt if ch not in char2idx})
    if unknown:
        listed = ", ".join(repr(ch) for ch in unknown)
        raise ValueError(f"Prompt contains unsupported characters: {listed}")

    # Load the trained model
    model = tf.keras.models.load_model("text_generation_model.h5")

    # Vectorize the prompt into a batch containing a single sequence
    input_eval = [char2idx[ch] for ch in prompt]
    input_eval = tf.expand_dims(input_eval, 0)

    text_generated = []

    # NOTE(review): the model is stateful, so it presumably only accepts the
    # batch size it was saved with. This loop feeds a batch of one; confirm the
    # saved model was built for batch size 1.
    model.reset_states()
    for _ in range(num_generate):
        predictions = model(input_eval)
        # Drop the batch dimension
        predictions = tf.squeeze(predictions, 0)
        predictions = predictions / temperature
        # Sample from the logits and keep the draw for the last time step
        predicted_id = tf.random.categorical(predictions, num_samples=1)[
            -1, 0
        ].numpy()

        # Feed only the sampled character back in as the next input
        input_eval = tf.expand_dims([predicted_id], 0)
        text_generated.append(idx2char[predicted_id])

    return prompt + "".join(text_generated)
75 changes: 75 additions & 0 deletions python-ml/generate_with_tensorflow/src/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import tensorflow as tf
import numpy as np
import os


def _build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    """Build the Embedding -> stateful GRU -> Dense model for a fixed batch size."""
    return tf.keras.Sequential(
        [
            tf.keras.layers.Embedding(
                vocab_size, embedding_dim, batch_input_shape=[batch_size, None]
            ),
            tf.keras.layers.GRU(
                rnn_units,
                return_sequences=True,
                stateful=True,
                recurrent_initializer="glorot_uniform",
            ),
            tf.keras.layers.Dense(vocab_size),
        ]
    )


def main():
    """Train the character-level text generation model and save its artifacts.

    Downloads the Shakespeare corpus, builds a character vocabulary, trains the
    model for ``EPOCHS`` epochs (writing weight checkpoints to
    ``./training_checkpoints``), then writes ``text_generation_model.h5``,
    ``char2idx.npy`` and ``idx2char.npy`` to the current working directory and
    deletes the downloaded corpus file.
    """
    path_to_file = tf.keras.utils.get_file(
        "shakespeare.txt",
        "https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt",
    )
    # Use a context manager so the file handle is closed deterministically.
    with open(path_to_file, "rb") as corpus_file:
        text = corpus_file.read().decode(encoding="utf-8")

    # Character vocabulary and the two lookup tables between chars and ids
    vocab = sorted(set(text))
    char2idx = {u: i for i, u in enumerate(vocab)}
    idx2char = np.array(vocab)

    text_as_int = np.array([char2idx[c] for c in text])
    seq_length = 100
    char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
    # seq_length + 1 so each chunk yields an input and a target of seq_length
    sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)

    def split_input_target(chunk):
        # The target is the input shifted one character to the right
        input_text = chunk[:-1]
        target_text = chunk[1:]
        return input_text, target_text

    dataset = sequences.map(split_input_target)
    BATCH_SIZE = 64
    BUFFER_SIZE = 10000
    dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

    vocab_size = len(vocab)
    embedding_dim = 256
    rnn_units = 1024

    model = _build_model(vocab_size, embedding_dim, rnn_units, BATCH_SIZE)

    model.compile(
        optimizer="adam", loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True)
    )

    EPOCHS = 10
    checkpoint_dir = "./training_checkpoints"

    os.makedirs(checkpoint_dir, exist_ok=True)

    # The doubled braces leave a literal "{epoch}" placeholder in the path
    checkpoint_prefix = f"{checkpoint_dir}/ckpt_{{epoch}}"

    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix, save_weights_only=True
    )

    model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])

    # The trained model is stateful and was built for batches of BATCH_SIZE,
    # whereas generation feeds one sequence at a time. Save a copy rebuilt for
    # batch size 1 that carries the trained weights.
    inference_model = _build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
    inference_model.set_weights(model.get_weights())

    inference_model.save("text_generation_model.h5")
    np.save("char2idx.npy", char2idx)
    np.save("idx2char.npy", idx2char)

    os.remove(path_to_file)


if __name__ == "__main__":
    main()
35 changes: 35 additions & 0 deletions python-ml/generate_with_tensorflow/src/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os

# Absolute path of the directory that contains this module.
__dirname = os.path.dirname(os.path.abspath(__file__))
# The "static" directory one level above this module's directory.
static_folder = os.path.join(__dirname, "../static")


def get_static_file(file_name: str) -> str:
    """
    Returns the contents of a file in the static folder
    Parameters:
        file_name (str): Name of the file to read
    Returns:
        (str): Contents of static/{file_name}
    Raises:
        OSError: If the file cannot be opened
        UnicodeDecodeError: If the file is not valid UTF-8
    """
    file_path = os.path.join(static_folder, file_name)
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # encoding, so the text is read the same way on every host.
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()


def throw_if_missing(obj: object, keys: list[str]) -> None:
    """
    Throws an error if any of the keys are missing from the object

    A key counts as missing when it is absent or when its value is falsy
    (for example an empty string). An object that does not support key
    lookup at all (``None``, a raw string or bytes body, a number) is treated
    as missing every key instead of raising ``TypeError``.

    Parameters:
        obj (object): Object to check
        keys (list[str]): List of keys to check
    Raises:
        ValueError: If any keys are missing
    """
    missing = []
    for key in keys:
        try:
            present = key in obj and bool(obj[key])
        except TypeError:
            # obj is not a container, or cannot be indexed by a string key
            present = False
        if not present:
            missing.append(key)
    if missing:
        raise ValueError(f"Missing required fields: {', '.join(missing)}")
Loading

0 comments on commit bb73e15

Please sign in to comment.