-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[EAGLE-5342] Added Model Upload Tests (#495)
* created temp dummy_models_path * created temp dummy_models_path * created temp dummy_models_path in run_locally * Added hf model run locally tests * Added hf_mbart_model dummy model for tests * remove xformers * Added model upload tests * Fix minor status_code_pb2 issue * Fix issues * Fix minor status_code_pb2 issue * Fix minor issue * Fix minor issue * reduce transformers version * fix requirements version * fix path for windows * fix tests for windows * use python builtin tar function for taring
- Loading branch information
1 parent
4a7a506
commit b85883f
Showing
6 changed files
with
414 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import os | ||
from typing import Iterator | ||
|
||
import torch | ||
from clarifai_grpc.grpc.api import resources_pb2, service_pb2 | ||
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2 | ||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | ||
|
||
from clarifai.runners.models.model_runner import ModelRunner | ||
from clarifai.utils.logging import logger | ||
|
||
NUM_GPUS = 1 | ||
|
||
|
||
def set_output(texts: list):
    """Wrap each text in a successful ``resources_pb2.Output`` message.

    Args:
        texts: List of strings; each one becomes the raw text of one output.

    Returns:
        A list of ``resources_pb2.Output`` messages, one per entry of
        ``texts`` and in the same order, each carrying status code
        ``SUCCESS``.

    Raises:
        AssertionError: If ``texts`` is not a ``list``.
    """
    assert isinstance(texts, list)
    return [
        resources_pb2.Output(
            data=resources_pb2.Data(text=resources_pb2.Text(raw=item)),
            status=status_pb2.Status(code=status_code_pb2.SUCCESS),
        )
        for item in texts
    ]
|
||
|
||
class MyRunner(ModelRunner):
    """A custom runner that loads a seq2seq model with Hugging Face
    ``transformers`` and generates text from the input texts.
    """

    def load_model(self):
        """Select the device and load the tokenizer and model from ``checkpoints``.

        The ``checkpoints`` directory is expected to sit next to this file.
        Every file found under it is logged before loading, which helps
        diagnose missing or incomplete checkpoint downloads.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Running on device: {self.device}")
        checkpoints = os.path.join(os.path.dirname(__file__), "checkpoints")

        # Log the full path of each checkpoint file that is present.
        for root, _dirs, files in os.walk(checkpoints):
            for f in files:
                logger.info(os.path.join(root, f))

        # if checkpoints section is in config.yaml file then checkpoints will be
        # downloaded at this path during model upload time.
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoints)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            checkpoints, torch_dtype="auto", device_map=self.device)

    def predict(self, request: service_pb2.PostModelOutputsRequest
               ) -> service_pb2.MultiOutputResponse:
        """Generate one output text for every input in ``request``.

        Each input's raw text is tokenized, moved to the selected device and
        passed to ``self.model.generate``; the first returned sequence is
        decoded with the tokenizer's default settings.

        Args:
            request: Request whose ``inputs`` each carry ``data.text.raw``.

        Returns:
            A single ``MultiOutputResponse`` holding one output per input, in
            input order.
        """
        texts = [inp.data.text.raw for inp in request.inputs]

        raw_texts = []
        # Inputs are processed one at a time, so no padding/batching is involved.
        for t in texts:
            inputs = self.tokenizer.encode(t, return_tensors="pt").to(self.device)
            outputs = self.model.generate(inputs)
            raw_texts.append(self.tokenizer.decode(outputs[0]))
        output_protos = set_output(raw_texts)

        return service_pb2.MultiOutputResponse(outputs=output_protos)

    def generate(self, request: service_pb2.PostModelOutputsRequest
                ) -> Iterator[service_pb2.MultiOutputResponse]:
        """Streamed-output prediction; not implemented for this model.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("This method is not implemented yet.")

    def stream(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
              ) -> Iterator[service_pb2.MultiOutputResponse]:
        """Streamed-input prediction; not implemented for this model.

        Raises explicitly, like ``generate``, instead of silently returning
        ``None``, which would violate the declared ``Iterator`` return type.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("This method is not implemented yet.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Config file for the Hugging Face mBART test model runner | ||
|
||
model: | ||
id: "hf-mbart-model" | ||
user_id: "user_id" | ||
app_id: "app_id" | ||
model_type_id: "text-to-text" | ||
|
||
build_info: | ||
python_version: "3.12" | ||
|
||
inference_compute_info: | ||
cpu_limit: "500m" | ||
cpu_memory: "500Mi" | ||
num_accelerators: 0 | ||
|
||
checkpoints: | ||
type: "huggingface" | ||
repo_id: "sshleifer/tiny-mbart" | ||
hf_token: "" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
torch==2.4.0 | ||
tokenizers>=0.19.0 | ||
transformers>=4.44 | ||
accelerate>=1.0.1 | ||
optimum>=1.20.0 | ||
sentencepiece==0.2.0 | ||
requests==2.23.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.