Skip to content

Commit

Permalink
[EAGLE-5536]: 3 options to download checkpoints (#515)
Browse files Browse the repository at this point in the history
* introduce 3 options for when to download checkpoints

* updated one test

* updated another test

* updated two more tests

* remove any and fix windows

* passing tests for windows too

* stage fixes

* fix test

---------

Co-authored-by: luv-bansal <[email protected]>
  • Loading branch information
zeiler and luv-bansal authored Feb 12, 2025
1 parent 9bfc8dc commit bf16a69
Show file tree
Hide file tree
Showing 11 changed files with 116 additions and 81 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
## [[11.1.3]](https://github.com/Clarifai/clarifai-python/releases/tag/11.1.3) - [PyPI](https://pypi.org/project/clarifai/11.1.2/) - 2025-02-11
## [[11.1.4]](https://github.com/Clarifai/clarifai-python/releases/tag/11.1.4) - [PyPI](https://pypi.org/project/clarifai/11.1.4/) - 2025-02-12

### Changed

- Introduce 3 times when you can download checkpoints [(#515)](https://github.com/Clarifai/clarifai-python/pull/515)

## [[11.1.3]](https://github.com/Clarifai/clarifai-python/releases/tag/11.1.3) - [PyPI](https://pypi.org/project/clarifai/11.1.3/) - 2025-02-11

### Changed

Expand Down
2 changes: 1 addition & 1 deletion clarifai/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "11.1.3"
__version__ = "11.1.4"
28 changes: 20 additions & 8 deletions clarifai/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,24 @@ def model():
required=True,
help='Path to the model directory.')
@click.option(
'--download_checkpoints',
is_flag=True,
'--stage',
required=False,
type=click.Choice(['runtime', 'build', 'upload'], case_sensitive=True),
default="upload",
show_default=True,
help=
'Flag to download checkpoints before uploading and including them in the tar file that is uploaded. Defaults to False, which will attempt to download them at docker build time.',
'The stage we are calling download checkpoints from. Typically this would "upload" and will download checkpoints if config.yaml checkpoints section has when set to "upload". Other options include "runtime" to be used in load_model or "upload" to be used during model upload. Set this stage to whatever you have in config.yaml to force downloading now.'
)
@click.option(
'--skip_dockerfile',
is_flag=True,
help=
'Flag to skip generating a dockerfile so that you can manually edit an already created dockerfile.',
)
def upload(model_path, download_checkpoints, skip_dockerfile):
def upload(model_path, stage, skip_dockerfile):
"""Upload a model to Clarifai."""
from clarifai.runners.models.model_builder import upload_model
upload_model(model_path, download_checkpoints, skip_dockerfile)
upload_model(model_path, stage, skip_dockerfile)


@model.command()
Expand All @@ -44,14 +47,23 @@ def upload(model_path, download_checkpoints, skip_dockerfile):
required=False,
default=None,
help=
'Option path to write the checkpoints to. This will place them in {out_path}/ If not provided it will default to {model_path}/1/checkpoints where the config.yaml is read..'
'Option path to write the checkpoints to. This will place them in {out_path}/1/checkpoints If not provided it will default to {model_path}/1/checkpoints where the config.yaml is read.'
)
@click.option(
'--stage',
required=False,
type=click.Choice(['runtime', 'build', 'upload'], case_sensitive=True),
default="build",
show_default=True,
help=
'The stage we are calling download checkpoints from. Typically this would be in the build stage which is the default. Other options include "runtime" to be used in load_model or "upload" to be used during model upload. Set this stage to whatever you have in config.yaml to force downloading now.'
)
def download_checkpoints(model_path, out_path):
def download_checkpoints(model_path, out_path, stage):
"""Download checkpoints from external source to local model_path"""

from clarifai.runners.models.model_builder import ModelBuilder
builder = ModelBuilder(model_path, download_validation_only=True)
builder.download_checkpoints(out_path)
builder.download_checkpoints(stage=stage, checkpoint_path_override=out_path)


@model.command()
Expand Down
37 changes: 3 additions & 34 deletions clarifai/runners/dockerfile_template/Dockerfile.template
Original file line number Diff line number Diff line change
@@ -1,30 +1,11 @@
# syntax=docker/dockerfile:1.13-labs
#############################
# User specific requirements installed in the pip_packages
#############################
FROM --platform=$TARGETPLATFORM ${FINAL_IMAGE} as pip_packages
FROM --platform=$TARGETPLATFORM ${FINAL_IMAGE} as final

COPY --link requirements.txt /home/nonroot/requirements.txt

# Update clarifai package so we always have latest protocol to the API. Everything should land in /venv
RUN ["pip", "install", "--no-cache-dir", "-r", "/home/nonroot/requirements.txt"]
RUN ["pip", "show", "clarifai"]
#############################

#############################
# Downloader dependencies image
#############################
FROM --platform=$TARGETPLATFORM ${DOWNLOADER_IMAGE} as downloader

# make sure we have the latest clarifai package. This version is filled in by SDK.
RUN ["pip", "install", "clarifai==${CLARIFAI_VERSION}"]
#####


#############################
# Final runtime image
#############################
FROM --platform=$TARGETPLATFORM ${FINAL_IMAGE} as final

# Set the NUMBA cache dir to /tmp
# Set the TORCHINDUCTOR cache dir to /tmp
Expand All @@ -34,28 +15,16 @@ ENV NUMBA_CACHE_DIR=/tmp/numba_cache \
HOME=/tmp \
DEBIAN_FRONTEND=noninteractive

#####
# Copy the python requirements needed to download checkpoints
#####
COPY --link=true --from=downloader /venv /venv
#####

#####
# Copy the files needed to download
#####
# This creates the directory that HF downloader will populate and with nonroot:nonroot permissions up.
COPY --chown=nonroot:nonroot downloader/unused.yaml /home/nonroot/main/1/checkpoints/.cache/unused.yaml

#####
# Download checkpoints
# Download checkpoints if config.yaml has checkpoints.when = "build"
COPY --link=true config.yaml /home/nonroot/main/
RUN ["python", "-m", "clarifai.cli", "model", "download-checkpoints", "--model_path", "/home/nonroot/main", "--out_path", "/home/nonroot/main"]
#####


#####
# Copy the python packages from the builder stage.
COPY --link=true --from=pip_packages /venv /venv
RUN ["python", "-m", "clarifai.cli", "model", "download-checkpoints", "--model_path", "/home/nonroot/main", "--out_path", "/home/nonroot/main/1/checkpoints", "--stage", "build"]
#####

# Copy in the actual files like config.yaml, requirements.txt, and most importantly 1/model.py
Expand Down
103 changes: 71 additions & 32 deletions clarifai/runners/models/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

from clarifai.client import BaseClient
from clarifai.runners.models.model_class import ModelClass
from clarifai.runners.utils.const import (AVAILABLE_PYTHON_IMAGES, AVAILABLE_TORCH_IMAGES,
CONCEPTS_REQUIRED_MODEL_TYPE, DEFAULT_PYTHON_VERSION,
PYTHON_BASE_IMAGE, TORCH_BASE_IMAGE)
from clarifai.runners.utils.const import (
AVAILABLE_PYTHON_IMAGES, AVAILABLE_TORCH_IMAGES, CONCEPTS_REQUIRED_MODEL_TYPE,
DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, DEFAULT_PYTHON_VERSION, DEFAULT_RUNTIME_DOWNLOAD_PATH,
PYTHON_BASE_IMAGE, TORCH_BASE_IMAGE)
from clarifai.runners.utils.loader import HuggingFaceLoader
from clarifai.urls.helper import ClarifaiUrlHelper
from clarifai.utils.logging import logger
Expand Down Expand Up @@ -145,19 +146,32 @@ def _validate_config_checkpoints(self):
:return: repo_id location of checkpoint.
:return: hf_token token to access checkpoint.
"""
if "checkpoints" not in self.config:
return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN
assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
loader_type = self.config.get("checkpoints").get("type")
if not loader_type:
logger.info("No loader type specified in the config file for checkpoints")
return None, None, None
checkpoints = self.config.get("checkpoints")
if 'when' not in checkpoints:
logger.warn(
f"No 'when' specified in the config file for checkpoints, defaulting to download at {DEFAULT_DOWNLOAD_CHECKPOINT_WHEN}"
)
when = checkpoints.get("when", DEFAULT_DOWNLOAD_CHECKPOINT_WHEN)
assert when in [
"upload",
"build",
"runtime",
], "Invalid value for when in the checkpoint loader when, needs to be one of ['upload', 'build', 'runtime']"
assert loader_type == "huggingface", "Only huggingface loader supported for now"
if loader_type == "huggingface":
assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file"
repo_id = self.config.get("checkpoints").get("repo_id")

# get from config.yaml otherwise fall back to HF_TOKEN env var.
hf_token = self.config.get("checkpoints").get("hf_token", os.environ.get("HF_TOKEN", None))
return loader_type, repo_id, hf_token
return loader_type, repo_id, hf_token, when

def _check_app_exists(self):
resp = self.client.STUB.GetApp(service_pb2.GetAppRequest(user_app_id=self.client.user_app_id))
Expand Down Expand Up @@ -202,7 +216,7 @@ def _validate_config(self):
assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, f"Model type {model_type_id} not supported for concepts"

if self.config.get("checkpoints"):
loader_type, _, hf_token = self._validate_config_checkpoints()
loader_type, _, hf_token, _ = self._validate_config_checkpoints()

if loader_type == "huggingface" and hf_token:
is_valid_token = HuggingFaceLoader.validate_hftoken(hf_token)
Expand Down Expand Up @@ -428,31 +442,51 @@ def _checkpoint_path(self, folder):

@property
def checkpoint_suffix(self):
return '1/checkpoints'
return os.path.join('1', 'checkpoints')

@property
def tar_file(self):
return f"{self.folder}.tar.gz"

def download_checkpoints(self, checkpoint_path_override: str = None):
def default_runtime_checkpoint_path(self):
return DEFAULT_RUNTIME_DOWNLOAD_PATH

def download_checkpoints(self,
stage: str = DEFAULT_DOWNLOAD_CHECKPOINT_WHEN,
checkpoint_path_override: str = None):
"""
Downloads the checkpoints specified in the config file.
:param checkpoint_path_override: The path to download the checkpoints to. If not provided, the
default path is used based on the folder ModelUploader was initialized with. The
checkpoint_suffix will be appended to the path.
    :param stage: The stage of the build process. This is used to determine when to download the
    checkpoints. The stage can be one of ['build', 'upload', 'runtime']. If you want to force
    downloading now then set stage to match the when field of the checkpoints section of your config.yaml.
:param checkpoint_path_override: The path to download the checkpoints to (with 1/checkpoints added as suffix). If not provided, the
default path is used based on the folder ModelUploader was initialized with. The checkpoint_suffix will be appended to the path.
If stage is 'runtime' and checkpoint_path_override is None, the default runtime path will be used.
:return: The path to the downloaded checkpoints. Even if it doesn't download anything, it will return the default path.
"""
path = self.checkpoint_path # default checkpoint path.
if not self.config.get("checkpoints"):
logger.info("No checkpoints specified in the config file")
return True
return path

loader_type, repo_id, hf_token = self._validate_config_checkpoints()
loader_type, repo_id, hf_token, when = self._validate_config_checkpoints()
if stage not in ["build", "upload", "runtime"]:
raise Exception("Invalid stage provided, must be one of ['build', 'upload', 'runtime']")
if when != stage:
logger.info(
f"Skipping downloading checkpoints for stage {stage} since config.yaml says to download them at stage {when}"
)
return path

success = True
if loader_type == "huggingface":
loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
path = self._checkpoint_path(
checkpoint_path_override) if checkpoint_path_override else self.checkpoint_path
# for runtime default to /tmp path
if stage == "runtime" and checkpoint_path_override is None:
checkpoint_path_override = self.default_runtime_checkpoint_path()
path = checkpoint_path_override if checkpoint_path_override else self.checkpoint_path
success = loader.download_checkpoints(path)

if loader_type:
Expand All @@ -461,7 +495,7 @@ def download_checkpoints(self, checkpoint_path_override: str = None):
sys.exit(1)
else:
logger.info(f"Downloaded checkpoints for model {repo_id}")
return success
return path

def _concepts_protos_from_concepts(self, concepts):
concept_protos = []
Expand Down Expand Up @@ -520,11 +554,12 @@ def get_model_version_proto(self):
self._concepts_protos_from_concepts(labels))
return model_version_proto

def upload_model_version(self, download_checkpoints):
def upload_model_version(self):
file_path = f"{self.folder}.tar.gz"
logger.debug(f"Will tar it into file: {file_path}")

model_type_id = self.config.get('model').get('model_type_id')
loader_type, repo_id, hf_token, when = self._validate_config_checkpoints()

if (model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE) and 'concepts' not in self.config:
logger.info(
Expand All @@ -534,15 +569,13 @@ def upload_model_version(self, download_checkpoints):
logger.info(
"Checkpoints specified in the config.yaml file, will download the HF model's config.json file to infer the concepts."
)

if not download_checkpoints and not HuggingFaceLoader.validate_config(
self.checkpoint_path):

input(
"Press Enter to download the HuggingFace model's config.json file to infer the concepts and continue..."
)
loader_type, repo_id, hf_token = self._validate_config_checkpoints()
if loader_type == "huggingface":
# If we don't already have the concepts, download the config.json file from HuggingFace
if loader_type == "huggingface":
# If the config.yaml says we'll download in the future (build time or runtime) then we need to get this config now.
if when != "upload" and not HuggingFaceLoader.validate_config(self.checkpoint_path):
input(
"Press Enter to download the HuggingFace model's config.json file to infer the concepts and continue..."
)
loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
loader.download_config(self.checkpoint_path)

Expand All @@ -557,7 +590,7 @@ def upload_model_version(self, download_checkpoints):
def filter_func(tarinfo):
name = tarinfo.name
exclude = [self.tar_file, "*~"]
if not download_checkpoints:
if when != "upload":
exclude.append(self.checkpoint_suffix)
return None if any(name.endswith(ex) for ex in exclude) else tarinfo

Expand All @@ -569,12 +602,12 @@ def filter_func(tarinfo):
logger.debug(f"Size of the tar is: {file_size} bytes")

self.storage_request_size = self._get_tar_file_content_size(file_path)
if not download_checkpoints and self.config.get("checkpoints"):
if when != "upload" and self.config.get("checkpoints"):
# Get the checkpoint size to add to the storage request.
# First check for the env variable, then try querying huggingface. If all else fails, use the default.
checkpoint_size = os.environ.get('CHECKPOINT_SIZE_BYTES', 0)
if not checkpoint_size:
_, repo_id, _ = self._validate_config_checkpoints()
_, repo_id, _, _ = self._validate_config_checkpoints()
checkpoint_size = HuggingFaceLoader.get_huggingface_checkpoint_total_size(repo_id)
if not checkpoint_size:
checkpoint_size = self.DEFAULT_CHECKPOINT_SIZE
Expand Down Expand Up @@ -702,10 +735,16 @@ def monitor_model_build(self):
return False


def upload_model(folder, download_checkpoints, skip_dockerfile):
def upload_model(folder, stage, skip_dockerfile):
"""
Uploads a model to Clarifai.
:param folder: The folder containing the model files.
:param stage: The stage we are calling download checkpoints from. Typically this would "upload" and will download checkpoints if config.yaml checkpoints section has when set to "upload". Other options include "runtime" to be used in load_model or "upload" to be used during model upload. Set this stage to whatever you have in config.yaml to force downloading now.
:param skip_dockerfile: If True, will not create a Dockerfile.
"""
builder = ModelBuilder(folder)
if download_checkpoints:
builder.download_checkpoints()
builder.download_checkpoints(stage=stage)
if not skip_dockerfile:
builder.create_dockerfile()
exists = builder.check_model_exists()
Expand All @@ -717,4 +756,4 @@ def upload_model(folder, download_checkpoints, skip_dockerfile):
logger.info(f"New model will be created at {builder.model_url} with it's first version.")

input("Press Enter to continue...")
builder.upload_model_version(download_checkpoints)
builder.upload_model_version()
3 changes: 2 additions & 1 deletion clarifai/runners/models/model_run_locally.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,8 @@ def main(model_path,
)
sys.exit(1)
manager = ModelRunLocally(model_path)
manager.builder.download_checkpoints()
  # stage="any" forces downloading now regardless of config.yaml
manager.builder.download_checkpoints(stage="any")
if inside_container:
if not manager.is_docker_installed():
sys.exit(1)
Expand Down
6 changes: 6 additions & 0 deletions clarifai/runners/utils/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@

DEFAULT_PYTHON_VERSION = 3.12

# By default we download at runtime.
DEFAULT_DOWNLOAD_CHECKPOINT_WHEN = "runtime"

# Folder for downloading checkpoints at runtime.
DEFAULT_RUNTIME_DOWNLOAD_PATH = "/tmp/.cache"

# List of available torch images
# Keep sorted by most recent cuda version.
AVAILABLE_TORCH_IMAGES = [
Expand Down
1 change: 1 addition & 0 deletions tests/runners/hf_mbart_model/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ inference_compute_info:
checkpoints:
type: "huggingface"
repo_id: "sshleifer/tiny-mbart"
when: "build"
5 changes: 3 additions & 2 deletions tests/runners/test_download_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,6 @@ def test_validate_download(checkpoint_dir):
def test_download_checkpoints(dummy_runner_models_dir):
model_folder_path = os.path.join(os.path.dirname(__file__), "dummy_runner_models")
model_builder = ModelBuilder(model_folder_path, download_validation_only=True)
isdownloaded = model_builder.download_checkpoints()
assert isdownloaded is True
  checkpoint_dir = model_builder.download_checkpoints(stage="build")  # stage matches config.yaml's when, forcing the download now.
assert checkpoint_dir == os.path.join(
os.path.dirname(__file__), "dummy_runner_models", "1", "checkpoints")
2 changes: 1 addition & 1 deletion tests/runners/test_model_run_locally.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_hf_test_model_success(hf_model_run_locally):
Test that test_model succeeds with the dummy model.
This calls the script's test_model method, which runs a subprocess.
"""
hf_model_run_locally.builder.download_checkpoints()
hf_model_run_locally.builder.download_checkpoints(stage="build")
hf_model_run_locally.create_temp_venv()
hf_model_run_locally.install_requirements()

Expand Down
Loading

0 comments on commit bf16a69

Please sign in to comment.