Merge pull request #162 from athina-ai/vivek/7th-jan-bug-fixes
Vivek/7th jan bug fixes
vivek-athina authored Jan 11, 2025
2 parents b3c7b2f + 4e1525b commit 7cb484e
Showing 6 changed files with 146 additions and 85 deletions.
62 changes: 40 additions & 22 deletions athina/steps/api.py
@@ -8,14 +8,6 @@
from jinja2 import Environment


def prepare_input_data(data: Dict[str, Any]) -> Dict[str, Any]:
"""Prepare input data by converting complex types to JSON strings."""
return {
key: json.dumps(value) if isinstance(value, (list, dict)) else value
for key, value in data.items()
}


def prepare_template_data(
env: Environment,
template_dict: Optional[Dict[str, str]],
@@ -31,6 +23,19 @@ def prepare_template_data(
return prepared_dict


def debug_json_structure(body_str: str, error: json.JSONDecodeError) -> dict:
"""Analyze JSON structure and identify problematic keys."""
lines = body_str.split("\n")
error_line_num = error.lineno - 1

return {
"original_body": body_str,
"problematic_line": (
lines[error_line_num] if error_line_num < len(lines) else None
),
}


def prepare_body(
env: Environment, body_template: Optional[str], input_data: Dict[str, Any]
) -> Optional[str]:
@@ -112,31 +117,44 @@ async def execute_async(self, input_data: Any) -> Union[Dict[str, Any], None]:
)
# Prepare the environment and input data
self.env = self._create_jinja_env()
prepared_input_data = prepare_input_data(input_data)

# Prepare request components
prepared_body = prepare_body(self.env, self.body, prepared_input_data)
prepared_headers = prepare_template_data(
self.env, self.headers, prepared_input_data
)
prepared_params = prepare_template_data(
self.env, self.params, prepared_input_data
)
prepared_body = prepare_body(self.env, self.body, input_data)
prepared_headers = prepare_template_data(self.env, self.headers, input_data)
prepared_params = prepare_template_data(self.env, self.params, input_data)
# Prepare the URL by rendering the template
prepared_url = self.env.from_string(self.url).render(**input_data)

timeout = aiohttp.ClientTimeout(total=self.timeout)

for attempt in range(self.retries):
try:
async with aiohttp.ClientSession(timeout=timeout) as session:
json_body = (
json.loads(prepared_body, strict=False)
if prepared_body
else None
)
try:
json_body = (
json.loads(prepared_body, strict=False)
if prepared_body
else None
)
except json.JSONDecodeError as e:
debug_info = debug_json_structure(prepared_body, e)
return self._create_step_result(
status="error",
data=json.dumps(
{
"message": f"Failed to parse request body as JSON",
"error_type": "JSONDecodeError",
"error_details": str(e),
"debug_info": debug_info,
},
indent=2,
),
start_time=start_time,
)

async with session.request(
method=self.method,
url=self.url,
url=prepared_url,
headers=prepared_headers,
params=prepared_params,
json=json_body,
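The api.py change drops the prepare_input_data pre-serialization step, renders the URL as a Jinja template, and wraps JSON parsing of the rendered body in a try/except that returns a structured error built by the new debug_json_structure helper. Below is a minimal, self-contained sketch of that helper's behavior; the sample body and its fields are made up for illustration.

import json

def debug_json_structure(body_str: str, error: json.JSONDecodeError) -> dict:
    """Copy of the helper added in this PR: report the line where JSON parsing failed."""
    lines = body_str.split("\n")
    error_line_num = error.lineno - 1
    return {
        "original_body": body_str,
        "problematic_line": (
            lines[error_line_num] if error_line_num < len(lines) else None
        ),
    }

# Hypothetical rendered request body with a trailing comma (invalid JSON).
bad_body = '{\n  "query": "hello",\n  "top_k": 5,\n}'

try:
    json.loads(bad_body, strict=False)
except json.JSONDecodeError as e:
    info = debug_json_structure(bad_body, e)
    print(info["problematic_line"])  # '}' -- the line where the parser gave up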
46 changes: 19 additions & 27 deletions athina/steps/chroma_retrieval.py
@@ -26,7 +26,7 @@ class ChromaRetrieval(Step):
port (int): The port of the Chroma server.
collection_name (str): The name of the Chroma collection.
limit (int): The maximum number of results to fetch.
input_column (str): The column name in the input data.
user_query (str): the query which will be sent to chroma.
openai_api_key (str): The OpenAI API key.
auth_type (str): The authentication type for the Chroma server (e.g., "token" or "basic").
auth_credentials (str): The authentication credentials for the Chroma server.
@@ -35,9 +35,8 @@
host: str
port: int
collection_name: str
key: str
limit: int
input_column: str
user_query: str
openai_api_key: str
auth_type: Optional[AuthType] = None
auth_credentials: Optional[str] = None
@@ -76,12 +75,6 @@ def __init__(self, *args, **kwargs):
self._collection = self._client.get_collection(
name=self.collection_name, embedding_function=self._embedding_function
)
# Create a custom Jinja2 environment with double curly brace delimiters and PreserveUndefined
self.env = Environment(
variable_start_string="{{",
variable_end_string="}}",
undefined=PreserveUndefined,
)

"""Makes a call to chromadb collection to fetch relevant chunks"""

@@ -95,31 +88,30 @@ def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
start_time=start_time,
)

query = input_data.get(self.input_column)
if query is None:
self.env = self._create_jinja_env()

query_text = self.env.from_string(self.user_query).render(**input_data)

if query_text is None:
return self._create_step_result(
status="error",
data="Input column not found.",
start_time=start_time,
status="error", data="Query text is Empty.", start_time=start_time
)

try:
if isinstance(query, list) and isinstance(query[0], float):
response = self._collection.query(
query_embeddings=[query],
n_results=self.limit,
include=["documents", "metadatas", "distances"],
)
else:
response = self._collection.query(
query_texts=[query],
n_results=self.limit,
include=["documents", "metadatas", "distances"],
response = self._collection.query(
query_texts=[query_text],
n_results=self.limit,
include=["documents", "metadatas", "distances"],
)
result = [
{"text": text, "score": distance}
for text, distance in zip(
response["documents"][0], response["distances"][0]
)

]
return self._create_step_result(
status="success",
data=response["documents"][0],
data=result,
start_time=start_time,
)
except Exception as e:
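In chroma_retrieval.py the step now renders a user_query Jinja template against the input row (instead of reading a fixed input_column), always queries by text, and returns {"text", "score"} pairs rather than bare documents. A small sketch of the new query rendering and result shaping follows; the input_data fields are assumptions, and the response dict stands in for the collection's query output.

from jinja2 import Environment

input_data = {"question": "What is retrieval-augmented generation?"}  # hypothetical row
user_query = "{{question}}"  # the step's user_query template

env = Environment(variable_start_string="{{", variable_end_string="}}")
query_text = env.from_string(user_query).render(**input_data)

# Stand-in for self._collection.query(query_texts=[query_text], ...) output.
response = {
    "documents": [["chunk A", "chunk B"]],
    "distances": [[0.12, 0.34]],
}
result = [
    {"text": text, "score": distance}
    for text, distance in zip(response["documents"][0], response["distances"][0])
]
print(query_text)  # What is retrieval-augmented generation?
print(result)      # [{'text': 'chunk A', 'score': 0.12}, {'text': 'chunk B', 'score': 0.34}]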
48 changes: 35 additions & 13 deletions athina/steps/pinecone_retrieval.py
@@ -1,5 +1,3 @@
# Step to make a call to pinecone index to fetch relevent chunks
import pinecone
from typing import Optional, Union, Dict, Any

from pydantic import Field, PrivateAttr
@@ -9,6 +7,7 @@
from llama_index.core import VectorStoreIndex
from llama_index.core.retrievers import VectorIndexRetriever
import time
import traceback


class PineconeRetrieval(Step):
@@ -22,36 +21,48 @@ class PineconeRetrieval(Step):
metadata_filters: filters to apply to metadata.
environment: pinecone environment.
api_key: api key for the pinecone server
input_column: column name in the input data
user_query: the query which will be sent to pinecone
env: jinja environment
"""

index_name: str
top_k: int
api_key: str
input_column: str
user_query: str
env: Environment = None
metadata_filters: Optional[Dict[str, Any]] = None
namespace: Optional[str] = None
environment: Optional[str] = None
text_key: Optional[str] = None # Optional parameter for text key
_vector_store: PineconeVectorStore = PrivateAttr()
_vector_index: VectorStoreIndex = PrivateAttr()
_retriever: VectorIndexRetriever = PrivateAttr()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Initialize base vector store arguments
vector_store_args = {"api_key": self.api_key, "index_name": self.index_name}
# Add text_key only if specified by user
if self.text_key:
vector_store_args["text_key"] = self.text_key

# Only add environment if it's provided
if self.environment is not None:
vector_store_args["environment"] = self.environment

if self.namespace is not None:
# Only add namespace if it's provided and not None
if self.namespace:
vector_store_args["namespace"] = self.namespace

# Initialize vector store with filtered arguments
self._vector_store = PineconeVectorStore(**vector_store_args)

# Create vector index from store
self._vector_index = VectorStoreIndex.from_vector_store(
vector_store=self._vector_store
)

# Initialize retriever with specified top_k
self._retriever = VectorIndexRetriever(
index=self._vector_index, similarity_top_k=self.top_k
)
@@ -60,9 +71,10 @@ class Config:
arbitrary_types_allowed = True

def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
"""makes a call to pinecone index to fetch relevent chunks"""
"""Makes a call to pinecone index to fetch relevant chunks"""
start_time = time.perf_counter()

# Validate input data
if input_data is None:
input_data = {}

@@ -73,26 +85,36 @@ def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
start_time=start_time,
)

input_text = input_data.get(self.input_column, None)
# Create Jinja environment and render query
self.env = self._create_jinja_env()
query_text = self.env.from_string(self.user_query).render(**input_data)

if input_text is None:
if not query_text:
return self._create_step_result(
status="error",
data="Input column not found.",
data="Query text is Empty.",
start_time=start_time,
)

try:
response = self._retriever.retrieve(input_text)
result = [node.get_content() for node in response]
# Perform retrieval
response = self._retriever.retrieve(query_text)
result = [
{
"text": node.get_content(),
"score": node.get_score(),
}
for node in response
]
return self._create_step_result(
status="success", data=result, start_time=start_time
)
return self._create_step_result(
status="success",
data=result,
start_time=start_time,
)
except Exception as e:
import traceback

traceback.print_exc()
print(f"Error during retrieval: {str(e)}")
return self._create_step_result(
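pinecone_retrieval.py gets the same user_query template and {"text", "score"} result shape, plus an optional text_key argument and a stricter namespace check (truthiness instead of is not None) when assembling the PineconeVectorStore arguments. A sketch of that argument filtering with placeholder values:

# Placeholder values; this only illustrates which keys get forwarded.
api_key = "PINECONE_API_KEY"
index_name = "my-index"
text_key = None       # forwarded only when the user sets it
environment = None    # forwarded only when provided
namespace = ""        # empty/None namespaces are now skipped (was `is not None`)

vector_store_args = {"api_key": api_key, "index_name": index_name}
if text_key:
    vector_store_args["text_key"] = text_key
if environment is not None:
    vector_store_args["environment"] = environment
if namespace:
    vector_store_args["namespace"] = namespace

print(vector_store_args)  # {'api_key': 'PINECONE_API_KEY', 'index_name': 'my-index'}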
16 changes: 11 additions & 5 deletions athina/steps/qdrant_retrieval.py
@@ -20,15 +20,15 @@ class QdrantRetrieval(Step):
url: url of the qdrant server
top_k: How many chunks to fetch.
api_key: api key for the qdrant server
input_column: the query which will be sent to qdrant
user_query: the query which will be sent to qdrant
env: jinja environment
"""

collection_name: str
url: str
top_k: int
api_key: str
input_column: str
user_query: str
env: Environment = None
_qdrant_client: qdrant_client.QdrantClient = PrivateAttr()
_vector_store: QdrantVectorStore = PrivateAttr()
@@ -70,11 +70,11 @@ def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:

self.env = self._create_jinja_env()

query_text = self.env.from_string(self.input_column).render(**input_data)
query_text = self.env.from_string(self.user_query).render(**input_data)

if query_text is None:
return self._create_step_result(
status="error", data="Query text not found.", start_time=start_time
status="error", data="Query text is Empty.", start_time=start_time
)

try:
@@ -84,7 +84,13 @@ def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
return self._create_step_result(
status="success", data=[], start_time=start_time
)
result = [node.get_content() for node in response]
result = [
{
"text": node.get_content(),
"score": node.get_score(),
}
for node in response
]
return self._create_step_result(
status="success", data=result, start_time=start_time
)
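qdrant_retrieval.py was already rendering a Jinja template for its query; this change renames the field from input_column to user_query and returns {"text", "score"} entries built from each node's get_content()/get_score(). A sketch with a stub node in place of the real llama_index retriever response:

from dataclasses import dataclass

@dataclass
class StubNode:
    """Stand-in for the NodeWithScore-like objects the retriever returns."""
    content: str
    similarity: float

    def get_content(self) -> str:
        return self.content

    def get_score(self) -> float:
        return self.similarity

response = [StubNode("Rotate keys from the settings page.", 0.87)]  # retriever.retrieve(query_text)
result = [
    {"text": node.get_content(), "score": node.get_score()}
    for node in response
]
print(result)  # [{'text': 'Rotate keys from the settings page.', 'score': 0.87}]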
(Diffs for the remaining two changed files did not load.)
