Merge pull request #485 from MohTahsin/main
Inference Profile Cost Tracing End-to-End Solution - initial commit - WIP
claumazz authored Feb 14, 2025
2 parents 9780419 + e9e236d commit 02298c5
Showing 21 changed files with 1,602 additions and 0 deletions.
@@ -0,0 +1,35 @@
# AWS Inference Profile Cost Tracing

This project automates the creation and setup of AWS Inference Profiles with cost tracing and monitoring capabilities. It uses tags and custom CloudWatch dashboards so customers can monitor the usage and costs of invoking large language models (LLMs), such as Anthropic's Claude models, on Amazon Bedrock.

## Project Overview

The project operates based on a configuration file (`config.json`) that defines the AWS resources to be created, such as Inference Profiles, IAM roles, CloudWatch dashboards, and SNS topics for alerts. Each Inference Profile contains a set of tags that represent attributes like the customer account, application ID, model name, and environment.
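
For reference, here is a trimmed excerpt of the sample `config.json` included in this commit (one profile shown; see the full `config.json` for all four sample profiles):

```json
{
  "aws_region": "us-west-2",
  "model_id": "anthropic.claude-3-haiku-20240307-v1:0",
  "dashboard_name": "BedrockInvocationDashboard",
  "cost_alarm_threshold": 1,
  "profiles": [
    {
      "name": "customer1_websearch",
      "description": "For Customer-1 using Websearch Bot",
      "tags": [
        {"key": "TenantID", "value": "Customer-1"},
        {"key": "ApplicationID", "value": "Web-Search-Bot"},
        {"key": "ModelName", "value": "Claude-Haiku"},
        {"key": "Environment", "value": "Dev"}
      ]
    }
  ]
}
```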

When an LLM is invoked through the deployed API Gateway, the request is automatically associated with the appropriate Inference Profile based on the provided tags. Metrics including token counts and costs are then published to CloudWatch, enabling cost tracking and monitoring at a granular level.

## Getting Started

1. Clone the repository to your local machine.
2. Install the required dependencies (the AWS CLI and the Python packages listed in `requirements.txt`).
3. Configure your AWS credentials and region.
4. Modify the `config.json` file to suit your requirements (e.g., Inference Profile tags, cost thresholds, SNS email).
5. Run the `setup.py` script to create and deploy all necessary AWS resources.

```
python setup.py
```

6. After the setup is complete, you can invoke the LLM through the deployed API Gateway, passing the required headers (e.g., `inference-profile-id`, `region`, `tags`).
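
For example, here is a minimal sketch using the `requests` package. The endpoint URL is a placeholder; the header names and the Converse-style message body match what the Lambda handler expects:

```python
import requests

# Placeholder endpoint; substitute the invoke URL produced by setup.py.
API_URL = "https://<api-id>.execute-api.us-west-2.amazonaws.com/dev"

response = requests.post(
    API_URL,
    headers={
        "inference-profile-id": "<your-inference-profile-id>",
        "region": "us-west-2",
    },
    # Converse-style message list, forwarded by the Lambda to Bedrock.
    json=[{"role": "user", "content": [{"text": "Hello, Claude!"}]}],
)
print(response.status_code, response.json())
```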

## Monitoring and Alerting

The project creates a custom CloudWatch dashboard named `BedrockInvocationDashboard` to visualize the metrics related to LLM invocations and costs. Additionally, it sets up an SNS topic (`BedrockInvocationAlarms`) to receive email alerts based on configurable thresholds for cost, token usage, and request counts.
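
Outside the dashboard, the same metrics can be queried directly. Here is a minimal sketch with boto3 (the namespace and metric names come from the Lambda code in this commit; the profile name is one of the samples from `config.json`):

```python
import boto3
from datetime import datetime, timedelta, timezone

cloudwatch = boto3.client("cloudwatch", region_name="us-west-2")

# Sum the input-token cost for one inference profile over the last 24 hours.
stats = cloudwatch.get_metric_statistics(
    Namespace="BedrockInvocationTracing",
    MetricName="InputTokensCost",
    Dimensions=[{"Name": "InferenceProfile", "Value": "customer1_websearch"}],
    StartTime=datetime.now(timezone.utc) - timedelta(days=1),
    EndTime=datetime.now(timezone.utc),
    Period=3600,
    Statistics=["Sum"],
)
for point in sorted(stats["Datapoints"], key=lambda p: p["Timestamp"]):
    print(point["Timestamp"], point["Sum"])
```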

## Customization

You can easily extend or modify the project to suit your specific needs. For example, you could add support for additional LLM providers, customize the dashboard layout, or integrate with other monitoring and alerting systems.

## Contributing

Contributions to this project are welcome. If you encounter any issues or have ideas for improvements, please open an issue or submit a pull request.
@@ -0,0 +1 @@
{"aws_region": "us-west-2", "model_id": "anthropic.claude-3-haiku-20240307-v1:0", "sns_topic_name": "BedrockInvocationAlarms", "sns_email": "[email protected]", "lambda_function_name": "InvokeBedrockFunction", "lambda_role_arn": "arn:aws:iam::123456789:role/YOUR-ROLE", "api_name": "BedrockInvocationAPI", "api_stage": "dev", "dashboard_name": "BedrockInvocationDashboard", "cost_alarm_threshold": 1, "token_alarm_threshold": 2000, "request_alarm_threshold": 10, "profiles": [{"name": "customer1_websearch", "description": "For Customer-1 using Websearch Bot", "tags": [{"key": "CreatedBy", "value": "Dev-Account"}, {"key": "ApplicationID", "value": "Web-Search-Bot"}, {"key": "TenantID", "value": "Customer-1"}, {"key": "CustomerAccountID", "value": "123987456"}, {"key": "ModelProvider", "value": "Anthropic"}, {"key": "ModelName", "value": "Claude-Haiku"}, {"key": "Environment", "value": "Dev"}]}, {"name": "customer1_codingbot", "description": "For Customer-1 using Coding Assistant Bot", "tags": [{"key": "CreatedBy", "value": "Dev-Account"}, {"key": "ApplicationID", "value": "Coding-Assistant-Bot"}, {"key": "TenantID", "value": "Customer-1"}, {"key": "CustomerAccountID", "value": "456987123"}, {"key": "ModelProvider", "value": "Anthropic"}, {"key": "ModelName", "value": "Claude-Haiku"}, {"key": "Environment", "value": "Dev"}]}, {"name": "customer2_websearch", "description": "For Customer-2 using Websearch Bot", "tags": [{"key": "CreatedBy", "value": "Dev-Account"}, {"key": "ApplicationID", "value": "Web-Search-Bot"}, {"key": "TenantID", "value": "Customer-2"}, {"key": "CustomerAccountID", "value": "123456789"}, {"key": "ModelProvider", "value": "Anthropic"}, {"key": "ModelName", "value": "Claude-Haiku"}, {"key": "Environment", "value": "Dev"}]}, {"name": "customer2_codingbot", "description": "For Customer-2 using Coding Assistant Bot", "tags": [{"key": "CreatedBy", "value": "Dev-Account"}, {"key": "ApplicationID", "value": "Coding-Assistant-Bot"}, {"key": "TenantID", "value": "Customer-2"}, {"key": "CustomerAccountID", "value": "987654321"}, {"key": "ModelProvider", "value": "Anthropic"}, {"key": "ModelName", "value": "Claude-Haiku"}, {"key": "Environment", "value": "Dev"}]}]}


@@ -0,0 +1,292 @@
import json
import boto3
from botocore.exceptions import ClientError
import os

def get_s3_file_content(bucket_name, object_key):
"""
Retrieve the content of a file from an S3 bucket.
Args:
bucket_name (str): The name of the S3 bucket.
object_key (str): The key of the object in the S3 bucket.
profile_name (str): The AWS profile name to use. Defaults to 'cost-tracing'.
Returns:
str: The content of the file.
Raises:
Exception: If there's an error retrieving the file.
"""
try:
# Create an S3 client
s3 = boto3.client('s3')

# Get the object
response = s3.get_object(Bucket=bucket_name, Key=object_key)

# Read the file content
file_content = response['Body'].read().decode('utf-8')

return file_content

except ClientError as e:
if e.response['Error']['Code'] == 'NoSuchKey':
raise Exception(f"The object {object_key} does not exist in bucket {bucket_name}")
elif e.response['Error']['Code'] == 'NoSuchBucket':
raise Exception(f"The bucket {bucket_name} does not exist")
else:
raise Exception(f"An error occurred: {e}")
except Exception as e:
raise Exception(f"An unexpected error occurred: {e}")


def has_matching_keys(original_list, comparison_list):
    """
    Check if all dictionaries in comparison_list have the same key-value elements as original_list.
    Parameters:
    - original_list (list of dict): The reference list of dictionaries.
    - comparison_list (list of dict): The list to compare against the original list.
    Returns:
    - bool: True if all dictionaries in comparison_list match original_list, False otherwise.
    """
    # Compare element-wise; lists of unequal length can never fully match.
    if len(original_list) != len(comparison_list):
        return False
    return all(x == y for x, y in zip(original_list, comparison_list))


def profile_lookup(payload_tags):
    """Return the inference profile ID whose tags match payload_tags, or None."""
    bucket_name = 'inference-cost-tracing'
    config_file = 'config/config.json'
    config = json.loads(get_s3_file_content(bucket_name, config_file))
    for profile in config['profiles']:
        if has_matching_keys(profile['tags'], payload_tags):
            # 'profile_ids' is assumed to be a list of {profile_name: profile_id}
            # entries written back to config.json during setup (not shown in this diff).
            for _id in config['profile_ids']:
                if profile['name'] == next(iter(_id)):
                    return _id[profile['name']]
    return None


def lambda_handler(event, context):

    headers = event.get('headers', {})  # extract headers from the API Gateway call
    if not headers:
        raise Exception("No headers")

    inference_profile_id = headers.get('inference-profile-id', None)
    region = headers.get('region', None)
    if not inference_profile_id:
        # API Gateway header values are strings, so the tags header is expected
        # to carry a JSON-encoded list of {"key": ..., "value": ...} dicts.
        tags_header = headers.get('tags', '')
        if not tags_header:
            inference_profile_id = None
        else:
            inference_profile_id = profile_lookup(json.loads(tags_header))

bucket_name = 'inference-cost-tracing'
cost_file = 'config/models.json'
cost = get_s3_file_content(bucket_name, cost_file)
    message = event.get('body', [])  # extract the input messages from the request body
    if not message:
        raise Exception("No message")
    if isinstance(message, str):
        # API Gateway delivers the body as a JSON string; Converse needs a message list.
        message = json.loads(message)

bedrock_client = boto3.client('bedrock', region_name=region)
inference_client = boto3.client("bedrock-runtime", region_name=region)
cloudwatch = boto3.client('cloudwatch', region_name=region)
cost_mapping = json.loads(cost)
region_cost_mapping = cost_mapping[region]
if inference_profile_id:
# Get model ID from the inference profile tags
inference_profile = bedrock_client.get_inference_profile(
inferenceProfileIdentifier=inference_profile_id
)
profile_name = inference_profile.get('inferenceProfileName', '')
profile_arn = inference_profile.get('inferenceProfileArn', '')
        try:
            models = inference_profile.get('models', [])
            model_id_tag = models[0].get('modelArn', '').split('/')[-1]
        except (IndexError, AttributeError):
            raise Exception(f"Model ARN not found in Inference Profile '{profile_name}'.")

try:
tags_list = bedrock_client.list_tags_for_resource(resourceARN=profile_arn).get('tags', [])
tags = {d['key']: d['value'] for d in tags_list}

model_token_cost = region_cost_mapping['text'].get(model_id_tag, {})

model_id = model_id_tag # Use the ModelName tag value as model ID
# Invoke the model using invoke_model
response = inference_client.converse(
modelId=model_id,
messages=message
)
# Read the AI response
result = response['output']
input_token_cost = model_token_cost['input_cost'] * (response['usage']['inputTokens'] / 1000000)
output_token_cost = model_token_cost['output_cost'] * (response['usage']['outputTokens'] / 1000000)

# Publish counts to CloudWatch
cloudwatch.put_metric_data(
Namespace='BedrockInvocationTracing',
MetricData=[
### invocation data
{
'MetricName': 'InputTokens',
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': response['usage']['inputTokens'],
'Unit': 'Count'
},
{
'MetricName': 'OutputTokens',
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': response['usage']['outputTokens'],
'Unit': 'Count'
},
#### cost
{
'MetricName': 'InputTokensCost',
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': input_token_cost,
'Unit': 'Count'
},
{
'MetricName': 'OutputTokensCost',
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': output_token_cost,
'Unit': 'Count'
},
####
{
'MetricName': 'InvocationSuccess',
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': 1,
'Unit': 'Count'
},
### Inference Profile data
{
'MetricName': tags['ModelName'],
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': 1,
'Unit': 'Count'
},
{
'MetricName': tags['TenantID'],
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': 1,
'Unit': 'Count'
},
{
'MetricName': tags['CreatedBy'],
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': 1,
'Unit': 'Count'
},
{
'MetricName': tags['ModelProvider'],
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': 1,
'Unit': 'Count'
},
{
'MetricName': tags['CustomerAccountID'],
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': 1,
'Unit': 'Count'
},
{
'MetricName': tags['Environment'],
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': 1,
'Unit': 'Count'
},
{
'MetricName': tags['ApplicationID'],
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': 1,
'Unit': 'Count'
},
#####
]
)
# Return successful response
return {
'statusCode': 200,
'body': json.dumps(result)
}
except Exception as e:
error_message = str(e)
status_code = 500
# Set token counts to zero on error
input_token_count = 0
output_token_count = 0
print(error_message)
# Publish failure metric and token counts to CloudWatch
cloudwatch.put_metric_data(
Namespace='BedrockInvocationTracing',
MetricData=[
{
'MetricName': 'InvocationFailure',
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': 1,
'Unit': 'Count'
},
{
'MetricName': 'InputTokens',
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': input_token_count,
'Unit': 'Count'
},
{
'MetricName': 'OutputTokens',
'Dimensions': [
{'Name': 'InferenceProfile', 'Value': profile_name},
],
'Value': output_token_count,
'Unit': 'Count'
},
]
)

            # If an error occurred, return an error response
return {
'statusCode': status_code,
'body': json.dumps({'error': error_message})
}
    else:
        # No inference profile was resolved: fall back to a model_id supplied
        # directly on the event and invoke without publishing metrics.
        model_id = event['model_id']
response = inference_client.converse(
modelId=model_id,
messages=message
)
# Return successful response
return {
'statusCode': 200,
'body': json.dumps(response['output'])
}
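
# For local testing (not part of the Lambda deployment): a minimal sketch of an
# event this handler accepts. The profile ID is a placeholder, and 'body' uses
# the Converse message format that the handler forwards to Bedrock.
if __name__ == "__main__":
    sample_event = {
        "headers": {
            "inference-profile-id": "<your-inference-profile-id>",
            "region": "us-west-2",
        },
        "body": [{"role": "user", "content": [{"text": "Hello, Claude!"}]}],
    }
    print(lambda_handler(sample_event, None))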
@@ -0,0 +1,15 @@
boto3==1.35.73
botocore==1.35.73
certifi==2024.8.30
charset-normalizer==3.4.0
idna==3.10
jmespath==1.0.1
python-dateutil==2.9.0.post0
PyYAML==6.0.2
requests==2.32.3
s3transfer==0.10.4
six==1.16.0
urllib3==2.2.3
streamlit
streamlit-authenticator

@@ -0,0 +1,2 @@
s3_bucket_name = 'inference-cost-tracing'
s3_config_file = 'config/config.json'
