-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #45 from athina-ai/feature/develop_flow_refactor
Feature/develop flow refactor
- Loading branch information
Showing
6 changed files
with
242 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from athina.datasets.dataset import Dataset | ||
|
||
__all__ = [ | ||
'Dataset' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import json | ||
from typing import Any, List, Optional | ||
import requests | ||
from dataclasses import dataclass, field, asdict | ||
from athina.services.athina_api_service import AthinaApiService | ||
|
||
@dataclass | ||
class DatasetRow: | ||
query: Optional[str] = None | ||
context: Optional[List[str]] = None | ||
response: Optional[str] = None | ||
expected_response: Optional[str] = None | ||
|
||
@dataclass | ||
class Dataset: | ||
id: str | ||
source: str | ||
name: str | ||
description: Optional[str] = None | ||
language_model_id: Optional[str] = None | ||
prompt_template: Optional[Any] = None | ||
rows: List[DatasetRow] = field(default_factory=list) | ||
|
||
@staticmethod | ||
def create(name: str, description: Optional[str] = None, language_model_id: Optional[str] = None, prompt_template: Optional[Any] = None, rows: List[DatasetRow] = None): | ||
""" | ||
Creates a new dataset with the specified properties. | ||
Parameters: | ||
- name (str): The name of the dataset. This is a required field. | ||
- description (Optional[str]): An optional textual description of the dataset, providing additional context. | ||
- language_model_id (Optional[str]): An optional identifier for the language model associated with this dataset. | ||
- prompt_template (Optional[Any]): An optional template for prompts used in this dataset. | ||
Returns: | ||
The newly created dataset object | ||
Raises: | ||
- Exception: If the dataset could not be created due to an error like invalid parameters, database errors, etc. | ||
""" | ||
dataset_data = { | ||
"source": "dev_sdk", | ||
"name": name, | ||
"description": description, | ||
"language_model_id": language_model_id, | ||
"prompt_template": prompt_template, | ||
"dataset_rows": rows or [] | ||
} | ||
|
||
try: | ||
created_dataset_data = AthinaApiService.create_dataset(dataset_data) | ||
except Exception as e: | ||
raise | ||
dataset = Dataset(id=created_dataset_data['id'], source=created_dataset_data['source'], name=created_dataset_data['name'], description=created_dataset_data['description'], language_model_id=created_dataset_data['language_model_id'], prompt_template=created_dataset_data['prompt_template']) | ||
return dataset | ||
|
||
@staticmethod | ||
def add_rows(dataset_id: str, rows: List[DatasetRow]): | ||
""" | ||
Adds rows to a dataset in batches of 100. | ||
Parameters: | ||
- dataset_id (str): The ID of the dataset to add rows to. | ||
- rows (List[DatasetRow]): The rows to add to the dataset. | ||
Raises: | ||
- Exception: If the API returns an error or the limit of 1000 rows is exceeded. | ||
""" | ||
batch_size = 100 | ||
for i in range(0, len(rows), batch_size): | ||
batch = rows[i:i+batch_size] | ||
try: | ||
AthinaApiService.add_dataset_rows(dataset_id, batch) | ||
except Exception as e: | ||
raise | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"from athina.datasets import Dataset\n", | ||
"from athina.keys import AthinaApiKey\n", | ||
"api_key = os.getenv('ATHINA_API_KEY')\n", | ||
"if not api_key:\n", | ||
" raise ValueError(\"ATHINA_API_KEY environment variable is not set.\")\n", | ||
"AthinaApiKey.set_key(api_key)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"try:\n", | ||
" dataset = Dataset.create(\n", | ||
" name='test_dataset_15',\n", | ||
" description='This is a test dataset',\n", | ||
" language_model_id='gpt-4',\n", | ||
" rows=[\n", | ||
" {\n", | ||
" 'query': 'What is the capital of Greece?',\n", | ||
" 'context': ['Greece is a country in southeastern Europe.', 'Athens is the capital of Greece.'],\n", | ||
" 'response': 'Athens',\n", | ||
" 'expected_response': 'Athens'\n", | ||
" }\n", | ||
" ]\n", | ||
" )\n", | ||
"except Exception as e:\n", | ||
" print(f\"Failed to create dataset: {e}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"try:\n", | ||
" Dataset.add_rows(\n", | ||
" dataset_id=dataset.id,\n", | ||
" rows=[\n", | ||
" {\n", | ||
" 'query': 'What is the capital of France?',\n", | ||
" 'context': ['France is a country in Western Europe.', 'Paris is the capital of France.'],\n", | ||
" 'response': 'Paris',\n", | ||
" 'expected_response': 'Paris'\n", | ||
" },\n", | ||
" ]\n", | ||
" )\n", | ||
"except Exception as e:\n", | ||
" print(f\"Failed to add rows more than 1000: {e}\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters