Merge pull request #45 from athina-ai/feature/develop_flow_refactor
Feature/develop flow refactor
akshat-g authored Apr 7, 2024
2 parents dc042f0 + 8398266 commit fc955e3
Showing 6 changed files with 242 additions and 9 deletions.
5 changes: 5 additions & 0 deletions athina/datasets/__init__.py
@@ -0,0 +1,5 @@
from athina.datasets.dataset import Dataset

__all__ = [
    'Dataset'
]
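
With this export in place, the class is importable straight from the package (the example notebook added below uses the same path):

from athina.datasets import Dataset
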
77 changes: 77 additions & 0 deletions athina/datasets/dataset.py
@@ -0,0 +1,77 @@
from typing import Any, List, Optional
from dataclasses import dataclass, field, asdict
from athina.services.athina_api_service import AthinaApiService


@dataclass
class DatasetRow:
    query: Optional[str] = None
    context: Optional[List[str]] = None
    response: Optional[str] = None
    expected_response: Optional[str] = None


@dataclass
class Dataset:
    id: str
    source: str
    name: str
    description: Optional[str] = None
    language_model_id: Optional[str] = None
    prompt_template: Optional[Any] = None
    rows: List[DatasetRow] = field(default_factory=list)

    @staticmethod
    def create(
        name: str,
        description: Optional[str] = None,
        language_model_id: Optional[str] = None,
        prompt_template: Optional[Any] = None,
        rows: Optional[List[DatasetRow]] = None,
    ):
        """
        Creates a new dataset with the specified properties.

        Parameters:
        - name (str): The name of the dataset. This is a required field.
        - description (Optional[str]): An optional textual description of the dataset, providing additional context.
        - language_model_id (Optional[str]): An optional identifier for the language model associated with this dataset.
        - prompt_template (Optional[Any]): An optional template for prompts used in this dataset.
        - rows (Optional[List[DatasetRow]]): Optional initial rows for the dataset.

        Returns:
        The newly created Dataset object.

        Raises:
        - Exception: If the dataset could not be created, e.g. due to invalid parameters or an API error.
        """
        dataset_data = {
            "source": "dev_sdk",
            "name": name,
            "description": description,
            "language_model_id": language_model_id,
            "prompt_template": prompt_template,
            # Normalize rows to plain dicts so the JSON request body serializes cleanly.
            "dataset_rows": [
                asdict(row) if isinstance(row, DatasetRow) else row
                for row in (rows or [])
            ],
        }
        created_dataset_data = AthinaApiService.create_dataset(dataset_data)
        return Dataset(
            id=created_dataset_data['id'],
            source=created_dataset_data['source'],
            name=created_dataset_data['name'],
            description=created_dataset_data['description'],
            language_model_id=created_dataset_data['language_model_id'],
            prompt_template=created_dataset_data['prompt_template'],
        )

    @staticmethod
    def add_rows(dataset_id: str, rows: List[DatasetRow]):
        """
        Adds rows to a dataset in batches of 100.

        Parameters:
        - dataset_id (str): The ID of the dataset to add rows to.
        - rows (List[DatasetRow]): The rows to add to the dataset.

        Raises:
        - Exception: If the API returns an error, e.g. when the 1000-row dataset limit is exceeded.
        """
        batch_size = 100
        for i in range(0, len(rows), batch_size):
            # Normalize each batch to plain dicts before sending to the API.
            batch = [
                asdict(row) if isinstance(row, DatasetRow) else row
                for row in rows[i:i + batch_size]
            ]
            AthinaApiService.add_dataset_rows(dataset_id, batch)
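
Taken together, create and add_rows cover the end-to-end flow. A minimal usage sketch (assuming an Athina API key has already been configured, as in the example notebook added below; the dataset name and row contents are illustrative):

from athina.datasets.dataset import Dataset, DatasetRow

# Rows may be DatasetRow instances; they are normalized to plain dicts before upload.
dataset = Dataset.create(
    name="capitals_qa",  # illustrative name
    description="Question-answer pairs about capital cities",
    rows=[
        DatasetRow(
            query="What is the capital of Greece?",
            context=["Athens is the capital of Greece."],
            response="Athens",
            expected_response="Athens",
        )
    ],
)

# Later additions go through add_rows, which uploads in batches of 100.
Dataset.add_rows(
    dataset_id=dataset.id,
    rows=[DatasetRow(query="What is the capital of France?", expected_response="Paris")],
)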



3 changes: 1 addition & 2 deletions athina/evals/eval_type.py
@@ -49,8 +49,7 @@ class FunctionEvalTypeId(Enum):
    NOT_GIBBERISH_TEXT = "NotGibberishText"
    CONTAINS_NO_SENSITIVE_TOPICS = "ContainsNoSensitiveTopics"
    OPENAI_CONTENT_MODERATION = "OpenAiContentModeration"



class GroundedEvalTypeId(Enum):
    ANSWER_SIMILARITY = "AnswerSimilarity"
    CONTEXT_SIMILARITY = "ContextSimilarity"
68 changes: 67 additions & 1 deletion athina/services/athina_api_service.py
@@ -1,7 +1,8 @@
import pkg_resources
import requests
from dataclasses import asdict
from retrying import retry
from typing import List, Optional
from typing import List, Optional, Dict
from athina.errors.exceptions import NoAthinaApiKeyException
from athina.interfaces.athina import (
    AthinaFilters,
@@ -119,6 +120,71 @@ def log_eval_results(
            str(e),
        )
        raise

    @staticmethod
    def create_dataset(dataset: Dict):
        """
        Creates a dataset by calling the Athina API.
        """
        endpoint = f"{API_BASE_URL}/api/v1/dataset"
        response = requests.post(
            endpoint,
            headers=AthinaApiService._headers(),
            json=dataset,
        )
        if response.status_code == 401:
            response_json = response.json()
            error_message = response_json.get('error', 'Unknown Error')
            details_message = 'please check your athina api key and try again'
            raise CustomException(error_message, details_message)
        elif response.status_code not in (200, 201):
            response_json = response.json()
            error_message = response_json.get('error', 'Unknown Error')
            details_message = response_json.get('details', {}).get('message', 'No Details')
            raise CustomException(error_message, details_message)
        return response.json()['data']['dataset']

    @staticmethod
    def add_dataset_rows(dataset_id: str, rows: List[Dict]):
        """
        Adds rows to a dataset by calling the Athina API.

        Parameters:
        - dataset_id (str): The ID of the dataset to which rows are added.
        - rows (List[Dict]): A list of rows to add to the dataset, where each row is represented as a dictionary.

        Returns:
        The API response data for the dataset after adding the rows.

        Raises:
        - CustomException: If the API call fails or returns an error.
        """
        endpoint = f"{API_BASE_URL}/api/v1/dataset/{dataset_id}/add-rows"
        response = requests.post(
            endpoint,
            headers=AthinaApiService._headers(),
            json={"dataset_rows": rows},
        )
        if response.status_code == 401:
            response_json = response.json()
            error_message = response_json.get('error', 'Unknown Error')
            details_message = 'please check your athina api key and try again'
            raise CustomException(error_message, details_message)
        elif response.status_code not in (200, 201):
            response_json = response.json()
            error_message = response_json.get('error', 'Unknown Error')
            details_message = response_json.get('details', {}).get('message', 'No Details')
            raise CustomException(error_message, details_message)
        return response.json()['data']

    @staticmethod
    def create_eval_request(
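
Both new endpoints parse the same error envelope on non-2xx responses. Judging from the .get() calls above, the body the client expects is shaped roughly as follows (the field names come from the client code; the values here are illustrative, not real API output):

# Hypothetical error body matching the client-side parsing above.
error_body = {
    "error": "ValidationError",
    "details": {"message": "dataset_rows exceeds the 1000 row limit"},
}
error_message = error_body.get('error', 'Unknown Error')
details_message = error_body.get('details', {}).get('message', 'No Details')
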
86 changes: 86 additions & 0 deletions examples/dataset_creation.ipynb
@@ -0,0 +1,86 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from athina.datasets import Dataset\n",
"from athina.keys import AthinaApiKey\n",
"api_key = os.getenv('ATHINA_API_KEY')\n",
"if not api_key:\n",
" raise ValueError(\"ATHINA_API_KEY environment variable is not set.\")\n",
"AthinaApiKey.set_key(api_key)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" dataset = Dataset.create(\n",
" name='test_dataset_15',\n",
" description='This is a test dataset',\n",
" language_model_id='gpt-4',\n",
" rows=[\n",
" {\n",
" 'query': 'What is the capital of Greece?',\n",
" 'context': ['Greece is a country in southeastern Europe.', 'Athens is the capital of Greece.'],\n",
" 'response': 'Athens',\n",
" 'expected_response': 'Athens'\n",
" }\n",
" ]\n",
" )\n",
"except Exception as e:\n",
" print(f\"Failed to create dataset: {e}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" Dataset.add_rows(\n",
" dataset_id=dataset.id,\n",
" rows=[\n",
" {\n",
" 'query': 'What is the capital of France?',\n",
" 'context': ['France is a country in Western Europe.', 'Paris is the capital of France.'],\n",
" 'response': 'Paris',\n",
" 'expected_response': 'Paris'\n",
" },\n",
" ]\n",
" )\n",
"except Exception as e:\n",
" print(f\"Failed to add rows more than 1000: {e}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
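
For context on the client-side batching in Dataset.add_rows, here is a quick, API-free sketch of how a larger row list gets chunked (the row contents are made up):

# Pure illustration of the slicing used by Dataset.add_rows; no API calls are made.
rows = [{"query": f"question {i}"} for i in range(250)]
batch_size = 100
batches = [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]
print([len(b) for b in batches])  # [100, 100, 50]
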
12 changes: 6 additions & 6 deletions examples/guard.ipynb
@@ -78,8 +78,8 @@
"Guarding query\n",
"-----------------------\n",
"\n",
"OpenAI Content Moderation: Passed in 454ms - The text was not flagged\n",
"Prompt Injection: Failed in 734ms - Prompt injection detected with a score of 0.9999991655349731.\n",
"OpenAI Content Moderation: Passed in 472ms - The text was not flagged\n",
"Prompt Injection: Failed in 576ms - Prompt injection detected with a score of 0.9999991655349731.\n",
"\n",
"ERROR: Detected a bad query. Allowing the query, but sent an alert on Slack.\n"
]
@@ -108,7 +108,7 @@
"Guarding query\n",
"-----------------------\n",
"\n",
"OpenAI Content Moderation: Failed in 353ms - The text was flagged in these categories: hate, harassment, hate/threatening, harassment/threatening, violence\n",
"OpenAI Content Moderation: Failed in 301ms - The text was flagged in these categories: hate, harassment, hate/threatening, harassment/threatening, violence\n",
"\n",
"ERROR: Detected a bad query. Allowing the query, but sent an alert on Slack.\n"
]
@@ -162,7 +162,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -197,7 +197,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -216,7 +216,7 @@
"\n",
"\n",
"Response should not mention competitors: Passed in 0ms - No keywords found in output\n",
"PII Detection: Failed in 1023ms - ['FIRSTNAME detected: Alt', 'FIRSTNAME detected: man', 'MASKEDNUMBER detected: 0x34932942984194912488439']\n",
"PII Detection: Failed in 1096ms - ['FIRSTNAME detected: Alt', 'FIRSTNAME detected: man', 'MASKEDNUMBER detected: 0x34932942984194912488439']\n",
"\n",
"ERROR: Detected a bad response. Fallback strategy initiated.\n",
"Safe response: I'm sorry, I can't help with that.\n"
