Add onnxruntime genai example to the iterative scheduling tutorial #93

Open · wants to merge 4 commits into base: main
14 changes: 11 additions & 3 deletions Conceptual_Guide/Part_7-iterative_scheduling/client/client.py
@@ -52,7 +52,7 @@ def client2_callback(display, event, result, error):
event.set()


def run_inferences(url, model_name, display, max_tokens):
def run_inferences(url, model_name, display, max_tokens, framework):
# Create clients
client1 = grpcclient.InferenceServerClient(url)
client2 = grpcclient.InferenceServerClient(url)
@@ -79,7 +79,12 @@ def run_inferences(url, model_name, display, max_tokens):

# Setup the display initially with the prompts
display.clear()
parameters = {"ignore_eos": True, "max_tokens": max_tokens}

if framework == "ort":
# Ignore EOS is not supported in onnxruntime-genai
parameters = {"max_tokens": max_tokens}
else:
parameters = {"ignore_eos": True, "max_tokens": max_tokens}

client1.async_stream_infer(
model_name=model_name,
@@ -108,7 +113,10 @@ def run_inferences(url, model_name, display, max_tokens):
parser.add_argument("--url", type=str, default="localhost:8001")
parser.add_argument("--model", type=str, default="simple-gpt2")
parser.add_argument("--max-tokens", type=int, default=128)
parser.add_argument(
"--framework", type=str, default="huggingface", choices=["ort", "huggingface"]
)
args = parser.parse_args()
display = Display(args.max_tokens)

run_inferences(args.url, args.model, display, args.max_tokens)
run_inferences(args.url, args.model, display, args.max_tokens, args.framework)
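
For reference, the framework-dependent parameter handling added above can be viewed in isolation. The following is a minimal sketch of the same branching logic; the build_parameters helper is illustrative and not part of this change.

# Minimal sketch of the parameter selection introduced in client.py.
# build_parameters is a hypothetical helper, not part of the PR.
def build_parameters(framework: str, max_tokens: int) -> dict:
    if framework == "ort":
        # onnxruntime-genai does not support ignore_eos, so only max_tokens is sent.
        return {"max_tokens": max_tokens}
    # The huggingface path keeps ignore_eos to force generation up to max_tokens.
    return {"ignore_eos": True, "max_tokens": max_tokens}

assert build_parameters("ort", 128) == {"max_tokens": 128}
assert build_parameters("huggingface", 128) == {"ignore_eos": True, "max_tokens": 128}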
200 changes: 200 additions & 0 deletions Conceptual_Guide/Part_7-iterative_scheduling/client/profiler.py
@@ -0,0 +1,200 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import time
from queue import SimpleQueue
from threading import Event

import prettytable
import tqdm
from tritonserver import InferenceRequest, Server


class TimeStampedQueue(SimpleQueue):
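    """Response queue that records the arrival time of every response and
    sets ``event`` once the final response for the request arrives."""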
def __init__(self, event) -> None:
super().__init__()
self.event = event

def put(self, item):
current_time = time.time()
super().put((current_time, item))
if item.final:
self.event.set()


class PerfAnalyzer:
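    """Drives concurrent streaming requests against a model through the
    in-process tritonserver API and aggregates per-response timing statistics."""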
def __init__(self, server, model_name, concurrency, input_data):
self._concurrency = concurrency
self._input_data = input_data
self._requests = []

self._measurement_interval_seconds = 5
self._number_of_intervals = 5
self._model = server.model(model_name)
self._prepare_requests()
self._queues = []
self._request_timestamps = []

self._is_profiler_thread_running = True

def _prepare_requests(self):
for i in range(self._concurrency):
input_data = self._input_data[i % len(self._input_data)]
request = InferenceRequest(
self._model,
inputs=input_data["inputs"],
parameters=input_data["parameters"],
)
self._requests.append(request)

def profile(self):
self._queues = []
self._request_timestamps = []
        # Run 20 measurement rounds; each round issues one request per
        # concurrency slot and waits for all of them to finish streaming.
        for _ in tqdm.tqdm(range(20)):
            results = []
            current_queues = []
            for i in range(len(self._requests)):
input_data = self._input_data[i % len(self._input_data)]
current_queue = TimeStampedQueue(Event())
self._queues.append(current_queue)
current_queues.append(current_queue)
request = InferenceRequest(
self._model,
inputs=input_data["inputs"],
parameters=input_data["parameters"],
response_queue=current_queue,
)
time.sleep(0.05)
self._request_timestamps.append(time.time())
results.append(self._model.infer(request))

for queue in current_queues:
queue.event.wait()

def _calculate_response_throughput(self, timestamp_lists):
timestamps = []
for timestamp_list in timestamp_lists:
for timestamp in timestamp_list:
timestamps.append(timestamp)

start_time = min(timestamps)
end_time = max(timestamps)

total_seconds = end_time - start_time
return len(timestamps) / total_seconds

def _calculate_time_to_last_response(self, timestamp_lists):
time_to_last_response = []
for i, timestamp_list in enumerate(timestamp_lists):
time_to_last_response.append(
timestamp_list[-1] - self._request_timestamps[i]
)

return (sum(time_to_last_response) / len(time_to_last_response)) * 1000

    def _calculate_inter_token_latency(self, timestamp_lists):
        # Average gap between consecutive responses of the same request, in ms.
        inter_token_latencies = []
        for timestamp_list in timestamp_lists:
            before = None
            for timestamp in timestamp_list:
                if before is not None:
                    inter_token_latencies.append(timestamp - before)
                before = timestamp
        return (sum(inter_token_latencies) / len(inter_token_latencies)) * 1000

def get_stats(self):
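        """Drain all response queues and return throughput and latency metrics."""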
timestamp_lists = []
for queue in self._queues:
timestamp_lists.append([])
while queue.qsize() > 0:
timestamp, _ = queue.get_nowait()
timestamp_lists[-1].append(timestamp)

return {
"response_throughput": self._calculate_response_throughput(timestamp_lists),
"time_to_last_response": self._calculate_time_to_last_response(
timestamp_lists
),
"inter_token_latency": self._calculate_inter_token_latency(timestamp_lists),
}


if __name__ == "__main__":
argument_parser = argparse.ArgumentParser(description="Profile a model")
argument_parser.add_argument(
"--model-name",
type=str,
required=True,
help="Name of the model to profile",
action="extend",
nargs="+",
)
argument_parser.add_argument(
"--model-repository",
type=str,
required=True,
help="Path to the model repository",
)
argument_parser.add_argument(
"--concurrency", type=int, default=10, help="Number of concurrent requests"
)
args = argument_parser.parse_args()
input_data = [
{
"inputs": {"text_input": [["Hello, how are you?"]]},
"parameters": {"ignore_eos": True, "max_tokens": 32},
}
]

data = []
server = Server(model_repository=args.model_repository, log_error=True)
server.start(wait_until_ready=True)
for model_name in args.model_name:
perf_analyzer = PerfAnalyzer(server, model_name, args.concurrency, input_data)
perf_analyzer.profile()
stats = perf_analyzer.get_stats()
data.append(stats)

table = prettytable.PrettyTable(
[
"Model Name",
"Tokens/sec",
"Time to last token [TTLT] (ms)",
"Inter token latency [ITL] (ms)",
]
)
for i, entry in enumerate(data):
table.add_row(
[
args.model_name[i],
f"{entry['response_throughput']:.2f}",
f"{entry['time_to_last_response']:.3f}",
f"{entry['inter_token_latency']:.3f}",
]
)
print(table)
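
To make the metric definitions used by PerfAnalyzer explicit, the standalone sketch below recomputes them from raw timestamps. The summarize function and its argument names are illustrative, not part of the PR; it assumes one send timestamp and one list of response timestamps per request.

# Illustrative recomputation of the three metrics reported by profiler.py.
# summarize is a hypothetical helper, not part of the PR.
from statistics import mean


def summarize(request_send_times, response_times_per_request):
    # Tokens/sec: total responses divided by the span from first to last response.
    all_responses = [t for times in response_times_per_request for t in times]
    tokens_per_sec = len(all_responses) / (max(all_responses) - min(all_responses))

    # Time to last response (ms): request send to final response, averaged.
    ttlt_ms = 1000 * mean(
        times[-1] - sent
        for sent, times in zip(request_send_times, response_times_per_request)
    )

    # Inter-token latency (ms): gap between consecutive responses of a request.
    gaps = [
        later - earlier
        for times in response_times_per_request
        for earlier, later in zip(times, times[1:])
    ]
    itl_ms = 1000 * mean(gaps)

    return tokens_per_sec, ttlt_ms, itl_ms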
@@ -0,0 +1,158 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
from pathlib import Path

import numpy as np
import onnxruntime_genai as og
import triton_python_backend_utils as pb_utils


class State:
def __init__(self):
self.prompt_tokens_len = 0
self.tokens = []
self.max_tokens = 0
self.generator_params = None


class TritonPythonModel:
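    """Python-backend model that runs one onnxruntime-genai decoding step per
    execute() call, using Triton request rescheduling to continue generation
    until EOS or the requested max_tokens is reached."""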
def initialize(self, args):
self.state = {}
self.model_path = str(Path(args["model_repository"]) / args["model_version"])
self.model = og.Model(self.model_path)
self.tokenizer = og.Tokenizer(self.model)

def create_batch(self, requests):
"""
Create a batch of requests to be processed by the model.

Args:
requests (list): A list of Triton requests to process.

Returns:
og.GeneratorParams: A generator parameters object for the model.
"""
generator_params = og.GeneratorParams(self.model)

input_ids = []
for request in requests:
input_tensor = str(
pb_utils.get_input_tensor_by_name(request, "text_input")
.as_numpy()
.item(),
encoding="utf-8",
)
correlation_id = (
pb_utils.get_input_tensor_by_name(request, "correlation_id")
.as_numpy()
.item()
)
start = (
pb_utils.get_input_tensor_by_name(request, "start").as_numpy().item()
)
if start:
state = State()
state.tokens = self.tokenizer.encode(input_tensor)
state.prompt_tokens_len = len(state.tokens)

# Store the parameters
parameters = json.loads(request.parameters())
state.max_tokens = parameters["max_tokens"]

self.state[correlation_id] = state
state = self.state[correlation_id]
input_ids.append(state.tokens)

# Find the max sequence length
max_len = max([len(x) for x in input_ids])
input_ids = [
[generator_params.pad_token_id] * (max_len - len(x)) + x for x in input_ids
]
generator_params.input_ids = np.asarray(input_ids)

return generator_params

def send_responses(self, requests, outputs, generator_params):
"""
Send responses for each request based on the model outputs and update
the state of each request.

Args:
requests (list): A list of Triton requests to process.
outputs (list): A list of generated tokens from the model for each request.
generator_params (og.GeneratorParams): Parameters used for generating responses.
"""
for i, request in enumerate(requests):
correlation_id = (
pb_utils.get_input_tensor_by_name(request, "correlation_id")
.as_numpy()
.item()
)
response_sender = request.get_response_sender()
generated_token = outputs[i]

            # Maximum total sequence length: prompt tokens plus max new tokens
max_tokens = (
self.state[correlation_id].max_tokens
+ self.state[correlation_id].prompt_tokens_len
)

            self.state[correlation_id].tokens.append(generated_token)
            if (generated_token == generator_params.eos_token_id) or len(
self.state[correlation_id].tokens
) >= max_tokens:
flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
request.set_release_flags(pb_utils.TRITONSERVER_REQUEST_RELEASE_ALL)
del self.state[correlation_id]
else:
request.set_release_flags(
pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE
)
flags = 0

output_decoded = self.tokenizer.decode(generated_token)
response = pb_utils.InferenceResponse(
output_tensors=[
pb_utils.Tensor(
"text_output", np.array([output_decoded], dtype=np.object_)
)
]
)
response_sender.send(response, flags=flags)

def execute(self, requests):
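        """Generate exactly one token for every request in the batch; requests
        that are not finished are released with the RESCHEDULE flag so they are
        batched again on the next scheduler iteration."""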
generator_params = self.create_batch(requests)

generator = og.Generator(self.model, generator_params)

# Compute the logits and generate the next token
generator.compute_logits()
generator.generate_next_token()
outputs = generator.get_next_tokens()

# Send the responses for every request
self.send_responses(requests, outputs, generator_params)
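
Outside of Triton, the onnxruntime-genai calls used in execute() are normally driven by an ordinary decode loop; the model above instead performs a single step per execute() call and relies on request rescheduling to continue generation. The sketch below shows that plain loop for contrast, using only the API calls that already appear in the file; the model path, prompt, and token budget are illustrative assumptions.

# Standalone sketch: the per-step work of execute(), written as a plain
# onnxruntime-genai decode loop outside Triton. Path and prompt are illustrative.
import numpy as np
import onnxruntime_genai as og

model = og.Model("/path/to/onnxruntime-genai/model")  # illustrative path
tokenizer = og.Tokenizer(model)

params = og.GeneratorParams(model)
params.input_ids = np.asarray([tokenizer.encode("Hello, how are you?")])

generator = og.Generator(model, params)
generated = []
for _ in range(32):  # stand-in for max_tokens
    generator.compute_logits()
    generator.generate_next_token()
    next_token = generator.get_next_tokens()[0]
    if next_token == params.eos_token_id:
        break
    generated.append(next_token)

# Decode token-by-token, mirroring how send_responses() decodes above.
print("".join(tokenizer.decode(t) for t in generated))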