Skip to content

Commit

Permalink
Merge pull request #165 from athina-ai/loop-step
Browse files Browse the repository at this point in the history
Added loop step
  • Loading branch information
vivek-athina authored Jan 30, 2025
2 parents de3d23f + 2043621 commit b0ffd44
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 1 deletion.
2 changes: 2 additions & 0 deletions athina/steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from athina.steps.conditional import ConditionalStep
from athina.steps.chain import Chain
from athina.steps.iterator import Map
from athina.steps.loop import Loop
from athina.steps.llm import PromptExecution
from athina.steps.api import ApiCall
from athina.steps.extract_entities import ExtractEntities
Expand Down Expand Up @@ -42,4 +43,5 @@
"SpiderCrawl",
"ParseDocument",
"ConditionalStep",
"Loop",
]
95 changes: 95 additions & 0 deletions athina/steps/loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import asyncio
from typing import Dict, List, Any, Optional
from athina.steps.base import Step
from concurrent.futures import ThreadPoolExecutor


class Loop(Step):
loop_type: str
loop_input: Optional[str]
loop_count: Optional[int]
sequence: List[Step]
execution_mode: Optional[str]
max_workers: int = 5

async def _execute_single_step(self, step: Step, context: Dict) -> Dict:
"""Execute a single step asynchronously using ThreadPoolExecutor."""
loop = asyncio.get_running_loop()
with ThreadPoolExecutor(max_workers=1) as executor:
return await loop.run_in_executor(
executor,
step.execute,
context
)

async def _execute_sequence(self, inputs: Dict, semaphore: asyncio.Semaphore) -> Dict:
"""Execute a sequence of steps asynchronously with proper context handling."""
async with semaphore:
context = inputs.copy()
executed_steps = []
final_output = None

for step in self.sequence:
result = await self._execute_single_step(step, context)
executed_steps.append(result)
context = {
**context,
f"{step.name}": result.get("data", {}),
}
final_output = result.get("data") # Ensure final output is correctly captured

return {
"status": "success",
"data": final_output, # Ensure only final result is returned
"metadata": {"executed_steps": executed_steps}
}

async def _execute_loop(self, inputs: Dict) -> Dict:
"""Handles loop execution, managing parallelism properly."""
semaphore = asyncio.Semaphore(self.max_workers if self.execution_mode == "parallel" else 1)
results = []

if self.loop_type == "map":
items = inputs.get(self.loop_input, [])
if not isinstance(items, list):
return {"status": "error", "data": "Input not of type list", "metadata": {}}

tasks = [
self._execute_sequence(
{**inputs, "item": item, "index": idx, "count": len(items)},
semaphore
)
for idx, item in enumerate(items)
]
else:
if not isinstance(self.loop_count, int) or self.loop_count <= 0:
return {"status": "error", "data": "Invalid loop count", "metadata": {}}

tasks = [
self._execute_sequence(
{**inputs, "index": i, "count": self.loop_count},
semaphore
)
for i in range(self.loop_count)
]

results = await asyncio.gather(*tasks) # Gather results concurrently

return {
"status": "success",
"data": [r["data"] for r in results], # Ensure correct final output format
"metadata": {"executed_steps": [r["metadata"] for r in results]}
}

def execute(self, inputs: Dict) -> Dict:
"""Handles execution, avoiding issues with already running event loops."""
try:
loop = asyncio.get_event_loop()
if loop.is_running():
future = asyncio.ensure_future(self._execute_loop(inputs))
loop.run_until_complete(future)
return future.result()
else:
return asyncio.run(self._execute_loop(inputs))
except Exception as e:
return {"status": "error", "data": str(e), "metadata": {}}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "athina"
version = "1.7.9"
version = "1.7.10"
description = "Python SDK to configure and run evaluations for your LLM-based application"
authors = ["Shiv Sakhuja <[email protected]>", "Akshat Gupta <[email protected]>", "Vivek Aditya <[email protected]>", "Akhil Bisht <[email protected]>"]
readme = "README.md"
Expand Down

0 comments on commit b0ffd44

Please sign in to comment.