Skip to content

Commit

Permalink
Merge pull request #158 from athina-ai/vivek/bug-fixes-2nd-jan
Browse files Browse the repository at this point in the history
Vivek/bug fixes 2nd jan
  • Loading branch information
vivek-athina authored Jan 4, 2025
2 parents 108e94f + cbe53b0 commit eee7622
Show file tree
Hide file tree
Showing 17 changed files with 562 additions and 301 deletions.
102 changes: 52 additions & 50 deletions athina/steps/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import time
from typing import Union, Dict, Any, Optional
import aiohttp
from jinja2 import Environment
from athina.helpers.jinja_helper import PreserveUndefined
from athina.steps.base import Step
import asyncio
from jinja2 import Environment


def prepare_input_data(data: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -17,15 +16,6 @@ def prepare_input_data(data: Dict[str, Any]) -> Dict[str, Any]:
}


def create_jinja_env() -> Environment:
"""Create a Jinja2 environment with custom settings."""
return Environment(
variable_start_string="{{",
variable_end_string="}}",
undefined=PreserveUndefined,
)


def prepare_template_data(
env: Environment,
template_dict: Optional[Dict[str, str]],
Expand All @@ -51,33 +41,6 @@ def prepare_body(
return env.from_string(body_template).render(**input_data)


def process_response(
status_code: int,
response_text: str,
) -> Dict[str, Any]:
"""Process the API response and return formatted result."""
if status_code >= 400:
# If the status code is an error, return the error message
return {
"status": "error",
"data": f"Failed to make the API call.\nStatus code: {status_code}\nError:\n{response_text}",
}

try:
json_response = json.loads(response_text)
# If the response is JSON, return the JSON data
return {
"status": "success",
"data": json_response,
}
except json.JSONDecodeError:
# If the response is not JSON, return the text
return {
"status": "success",
"data": response_text,
}


class ApiCall(Step):
"""
Step that makes an external API call.
Expand All @@ -103,17 +66,52 @@ class ApiCall(Step):
class Config:
arbitrary_types_allowed = True

def process_response(
self,
status_code: int,
response_text: str,
start_time: float,
) -> Dict[str, Any]:
"""Process the API response and return formatted result."""
if status_code >= 400:
# If the status code is an error, return the error message
return self._create_step_result(
status="error",
data=f"Failed to make the API call.\nStatus code: {status_code}\nError:\n{response_text}",
start_time=start_time,
)

try:
json_response = json.loads(response_text)
# If the response is JSON, return the JSON data
return self._create_step_result(
status="success",
data=json_response,
start_time=start_time,
)
except json.JSONDecodeError:
# If the response is not JSON, return the text
return self._create_step_result(
status="success",
data=response_text,
start_time=start_time,
)

async def execute_async(self, input_data: Any) -> Union[Dict[str, Any], None]:
"""Make an async API call and return the response."""
start_time = time.perf_counter()

if input_data is None:
input_data = {}

if not isinstance(input_data, dict):
raise TypeError("Input data must be a dictionary.")

return self._create_step_result(
status="error",
data="Input data must be a dictionary.",
start_time=start_time,
)
# Prepare the environment and input data
self.env = create_jinja_env()
self.env = self._create_jinja_env()
prepared_input_data = prepare_input_data(input_data)

# Prepare request components
Expand Down Expand Up @@ -144,23 +142,27 @@ async def execute_async(self, input_data: Any) -> Union[Dict[str, Any], None]:
json=json_body,
) as response:
response_text = await response.text()
return process_response(response.status, response_text)
return self.process_response(
response.status, response_text, start_time
)

except asyncio.TimeoutError:
if attempt < self.retries - 1:
await asyncio.sleep(2)
continue
# If the request times out after multiple attempts, return an error message
return {
"status": "error",
"data": "Failed to make the API call.\nRequest timed out after multiple attempts.",
}
return self._create_step_result(
status="error",
data="Failed to make the API call.\nRequest timed out after multiple attempts.",
start_time=start_time,
)
except Exception as e:
# If an exception occurs, return the error message
return {
"status": "error",
"data": f"Failed to make the API call.\nError: {e.__class__.__name__}\nDetails:\n{str(e)}",
}
return self._create_step_result(
status="error",
data=f"Failed to make the API call.\nError: {e.__class__.__name__}\nDetails:\n{str(e)}",
start_time=start_time,
)

def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
"""Synchronous execute api call that runs the async method in an event loop."""
Expand Down
50 changes: 49 additions & 1 deletion athina/steps/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import os
import json
import logging
from typing import Dict, Any, List, Iterable, Optional, Callable
from typing import Dict, Any, List, Iterable, Optional, Callable, TypedDict, Literal
from pydantic import BaseModel
from jinja2 import Environment
from athina.helpers.jinja_helper import PreserveUndefined
from athina.helpers.json import JsonHelper, JsonExtractor
from athina.llms.abstract_llm_service import AbstractLlmService
from athina.llms.openai_service import OpenAiService
from athina.keys import OpenAiApiKey
import functools
import time


# Configure logging
Expand All @@ -22,6 +25,12 @@ class StepError(Exception):
pass


class StepResult(TypedDict):
status: Literal["success", "error"]
data: str
metadata: Dict[str, Any]


def step(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
Expand Down Expand Up @@ -81,6 +90,45 @@ def extract_input_data(self, context: Dict[str, Any]) -> Any:
input_data = context
return input_data

def _create_step_result(
self,
status: Literal["success", "error"],
data: str,
start_time: float,
metadata: Dict[str, Any] = {},
exported_vars: Optional[Dict] = None,
) -> StepResult:
"""
Create a standardized result object for step execution.
Args:
status: Step execution status ("success" or "error")
data: Output data or error message
start_time: Time when step started execution (from perf_counter)
metadata: Optional dictionary of metadata
exported_vars: Optional dictionary of exported variables
"""
if "response_time" not in metadata:
execution_time_ms = round((time.perf_counter() - start_time) * 1000)
metadata = {"response_time": execution_time_ms}

if exported_vars is not None:
metadata["exported_vars"] = exported_vars

return {"status": status, "data": data, "metadata": metadata}

def _create_jinja_env(
self,
variable_start_string: str = "{{",
variable_end_string: str = "}}",
) -> Environment:
"""Create a Jinja2 environment with custom settings."""
return Environment(
variable_start_string=variable_start_string,
variable_end_string=variable_end_string,
undefined=PreserveUndefined,
)

@step
def run(
self,
Expand Down
27 changes: 23 additions & 4 deletions athina/steps/chroma_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from jinja2 import Environment
from athina.helpers.jinja_helper import PreserveUndefined
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
import time


class AuthType(str, Enum):
Expand Down Expand Up @@ -85,12 +86,22 @@ def __init__(self, *args, **kwargs):
"""Makes a call to chromadb collection to fetch relevant chunks"""

def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
start_time = time.perf_counter()

if input_data is None or not isinstance(input_data, dict):
return {"status": "error", "data": "Input data must be a dictionary."}
return self._create_step_result(
status="error",
data="Input data must be a dictionary.",
start_time=start_time,
)

query = input_data.get(self.input_column)
if query is None:
return {"status": "error", "data": "Input column not found."}
return self._create_step_result(
status="error",
data="Input column not found.",
start_time=start_time,
)

try:
if isinstance(query, list) and isinstance(query[0], float):
Expand All @@ -106,9 +117,17 @@ def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
include=["documents", "metadatas", "distances"],
)

return {"status": "success", "data": response["documents"][0]}
return self._create_step_result(
status="success",
data=response["documents"][0],
start_time=start_time,
)
except Exception as e:
return {"status": "error", "data": str(e)}
return self._create_step_result(
status="error",
data=str(e),
start_time=start_time,
)

def close(self):
if self._client:
Expand Down
33 changes: 22 additions & 11 deletions athina/steps/classify_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Union, Dict, Any
from athina.steps import Step
import marvin
import time


class ClassifyText(Step):
Expand All @@ -22,17 +23,25 @@ class ClassifyText(Step):

def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
"""Classify the text and return the label."""
start_time = time.perf_counter()

if input_data is None:
input_data = {}

if not isinstance(input_data, dict):
raise TypeError("Input data must be a dictionary.")

return self._create_step_result(
status="error",
data="Input data must be a dictionary.",
start_time=start_time,
)
input_text = input_data.get(self.input_column, None)

if input_text is None:
return None
return self._create_step_result(
status="error",
data="Input column not found.",
start_time=start_time,
)

marvin.settings.openai.api_key = self.llm_api_key
marvin.settings.openai.chat.completions.model = self.language_model_id
Expand All @@ -42,12 +51,14 @@ def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
input_text,
labels=self.labels,
)
return {
"status": "success",
"data": result,
}
return self._create_step_result(
status="success",
data=result,
start_time=start_time,
)
except Exception as e:
return {
"status": "error",
"data": str(e),
}
return self._create_step_result(
status="error",
data=str(e),
start_time=start_time,
)
Loading

0 comments on commit eee7622

Please sign in to comment.