Skip to content

Commit

Permalink
Update Agent
Browse files Browse the repository at this point in the history
  • Loading branch information
ashpreetbedi committed Nov 24, 2024
1 parent aa4ecaf commit 5f13fbf
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 151 deletions.
42 changes: 42 additions & 0 deletions cookbook/agents/24_agent_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import json
import httpx
from typing import Optional

from phi.agent import Agent
from phi.utils.log import logger


def get_top_hackernews_stories(agent: Agent, num_stories: Optional[int]) -> str:
"""Use this function to get top stories from Hacker News.
Args:
num_stories (int): Number of stories to return. Defaults to 10.
Returns:
str: JSON string of top stories.
"""

logger.info(f"Agent context: {agent.context}")

# Fetch top story IDs
response = httpx.get("https://hacker-news.firebaseio.com/v0/topstories.json")
story_ids = response.json()

# Fetch story details
stories = []
for story_id in story_ids[:num_stories]:
story_response = httpx.get(f"https://hacker-news.firebaseio.com/v0/item/{story_id}.json")
story = story_response.json()
if "text" in story:
story.pop("text", None)
stories.append(story)
return json.dumps(stories)


agent = Agent(
tools=[get_top_hackernews_stories],
show_tool_calls=True,
markdown=True,
context={"id": "123"},
)
agent.print_response("Summarize the top story on hackernews?", stream=True)
213 changes: 121 additions & 92 deletions phi/agent/agent.py

Large diffs are not rendered by default.

14 changes: 8 additions & 6 deletions phi/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def get_tools_for_api(self) -> Optional[List[Dict[str, Any]]]:
tools_for_api.append(tool)
return tools_for_api

def add_tool(self, tool: Union[Tool, Toolkit, Callable, Dict, Function], structured_outputs: bool = False) -> None:
def add_tool(
self, tool: Union[Tool, Toolkit, Callable, Dict, Function], strict: bool = False, agent: Optional[Any] = None
) -> None:
if self.tools is None:
self.tools = []

Expand All @@ -133,7 +135,7 @@ def add_tool(self, tool: Union[Tool, Toolkit, Callable, Dict, Function], structu
self.tools.append(tool)
logger.debug(f"Added tool {tool} to model.")

# If the tool is a Callable or Toolkit, add its functions to the Model
# If the tool is a Callable or Toolkit, process and add to the Model
elif callable(tool) or isinstance(tool, Toolkit) or isinstance(tool, Function):
if self.functions is None:
self.functions = {}
Expand All @@ -143,15 +145,15 @@ def add_tool(self, tool: Union[Tool, Toolkit, Callable, Dict, Function], structu
for name, func in tool.functions.items():
# If the function does not exist in self.functions, add to self.tools
if name not in self.functions:
if structured_outputs and self.supports_structured_outputs:
if strict and self.supports_structured_outputs:
func.strict = True
self.functions[name] = func
self.tools.append({"type": "function", "function": func.to_dict()})
logger.debug(f"Function {name} from {tool.name} added to model.")

elif isinstance(tool, Function):
if tool.name not in self.functions:
if structured_outputs and self.supports_structured_outputs:
if strict and self.supports_structured_outputs:
tool.strict = True
self.functions[tool.name] = tool
self.tools.append({"type": "function", "function": tool.to_dict()})
Expand All @@ -161,8 +163,8 @@ def add_tool(self, tool: Union[Tool, Toolkit, Callable, Dict, Function], structu
try:
function_name = tool.__name__
if function_name not in self.functions:
func = Function.from_callable(tool)
if structured_outputs and self.supports_structured_outputs:
func = Function.from_callable(tool, agent)
if strict and self.supports_structured_outputs:
func.strict = True
self.functions[func.name] = func
self.tools.append({"type": "function", "function": func.to_dict()})
Expand Down
12 changes: 6 additions & 6 deletions phi/model/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from phi.utils.log import logger


class MessageContext(BaseModel):
"""The context added to user message for RAG"""
class MessageReferences(BaseModel):
"""The references added to user message for RAG"""

# The query used to retrieve the context.
# The query used to retrieve the references.
query: str
# Documents from the vector database.
docs: Optional[List[Dict[str, Any]]] = None
# Time taken to retrieve the context.
# Time taken to retrieve the references.
time: Optional[float] = None


Expand Down Expand Up @@ -45,8 +45,8 @@ class Message(BaseModel):
# Metrics for the message. This is not sent to the Model API.
metrics: Dict[str, Any] = Field(default_factory=dict)

# The context added to the message for RAG
context: Optional[MessageContext] = None
# The references added to the message for RAG
references: Optional[MessageReferences] = None

# The Unix timestamp the message was created.
created_at: int = Field(default_factory=lambda: int(time()))
Expand Down
4 changes: 2 additions & 2 deletions phi/run/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel, ConfigDict, Field

from phi.reasoning.step import ReasoningStep
from phi.model.message import Message, MessageContext
from phi.model.message import Message, MessageReferences


class RunEvent(str, Enum):
Expand All @@ -25,7 +25,7 @@ class RunEvent(str, Enum):


class RunResponseExtraData(BaseModel):
context: Optional[List[MessageContext]] = None
references: Optional[List[MessageReferences]] = None
add_messages: Optional[List[Message]] = None
history: Optional[List[Message]] = None
reasoning_steps: Optional[List[ReasoningStep]] = None
Expand Down
55 changes: 41 additions & 14 deletions phi/tools/function.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Dict, Optional, Callable, get_type_hints
from pydantic import BaseModel, validate_call
from typing import Any, Dict, Optional, Callable, get_type_hints, Type, TypeVar
from pydantic import BaseModel, Field, validate_call

from phi.utils.log import logger

T = TypeVar("T")


class Function(BaseModel):
"""Model for Functions"""
Expand All @@ -14,7 +16,10 @@ class Function(BaseModel):
description: Optional[str] = None
# The parameters the functions accepts, described as a JSON Schema object.
# To describe a function that accepts no parameters, provide the value {"type": "object", "properties": {}}.
parameters: Dict[str, Any] = {"type": "object", "properties": {}}
parameters: Dict[str, Any] = Field(
default_factory=lambda: {"type": "object", "properties": {}},
description="JSON Schema object describing function parameters",
)
entrypoint: Optional[Callable] = None
strict: Optional[bool] = None

Expand All @@ -25,30 +30,52 @@ def to_dict(self) -> Dict[str, Any]:
return self.model_dump(exclude_none=True, include={"name", "description", "parameters", "strict"})

@classmethod
def from_callable(cls, c: Callable) -> "Function":
from inspect import getdoc
def from_callable(cls, c: Callable, agent: Optional[Any] = None) -> "Function":
from inspect import getdoc, signature
from functools import partial
from phi.utils.json_schema import get_json_schema

function_name = c.__name__
parameters = {"type": "object", "properties": {}, "required": []}
try:
# logger.info(f"Getting type hints for {c}")
sig = signature(c)
type_hints = get_type_hints(c)
# logger.info(f"Type hints for {c}: {type_hints}")
# logger.info(f"Getting JSON schema for {type_hints}")
parameters = get_json_schema(type_hints)
# logger.info(f"JSON schema for {c}: {parameters}")
# logger.debug(f"Type hints for {c.__name__}: {type_hints}")

# If function accepts the agent parameter, create a partial with the agent
# And remove the agent parameter from the type hints
if agent is not None and "agent" in sig.parameters:
c = partial(c, agent=agent)
del type_hints["agent"]
# logger.info(f"Type hints for {function_name}: {type_hints}")

# Filter out return type and only process parameters
param_type_hints = {
name: type_hints[name] for name in sig.parameters if name in type_hints and name != "return"
}
# logger.info(f"Arguments for {function_name}: {param_type_hints}")

# Get JSON schema for parameters only
parameters = get_json_schema(param_type_hints)

# Mark a field as required if it has no default value
parameters["required"] = [
name
for name, param in sig.parameters.items()
if param.default == param.empty and name != "self" and name != "agent"
]

logger.debug(f"JSON schema for {function_name}: {parameters}")
except Exception as e:
logger.warning(f"Could not parse args for {c.__name__}: {e}")
logger.warning(f"Could not parse args for {function_name}: {e}", exc_info=True)

return cls(
name=c.__name__,
name=function_name,
description=getdoc(c),
parameters=parameters,
entrypoint=validate_call(c),
)

def get_type_name(self, t):
def get_type_name(self, t: Type[T]):
name = str(t)
if "list" in name or "dict" in name:
return name
Expand Down
94 changes: 64 additions & 30 deletions phi/utils/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,59 +8,93 @@ def get_json_type_for_py_type(arg: str) -> str:
Get the JSON schema type for a given type.
:param arg: The type to get the JSON schema type for.
:return: The JSON schema type.
See: https://json-schema.org/understanding-json-schema/reference/type.html#type-specific-keywords
"""
# logger.info(f"Getting JSON type for: {arg}")
if arg in ("int", "float"):
if arg in ("int", "float", "complex", "Decimal"):
return "number"
elif arg == "str":
elif arg in ("str", "string"):
return "string"
elif arg == "bool":
elif arg in ("bool", "boolean"):
return "boolean"
elif arg in ("NoneType", "None"):
return "null"
return arg
elif arg in ("list", "tuple", "set", "frozenset"):
return "array"
elif arg in ("dict", "mapping"):
return "object"

# If the type is not recognized, return "object"
return "object"


def get_json_schema_for_arg(t: Any) -> Optional[Any]:
def get_json_schema_for_arg(t: Any) -> Optional[Dict[str, Any]]:
# logger.info(f"Getting JSON schema for arg: {t}")
json_schema = None
type_args = get_args(t)
# logger.info(f"Type args: {type_args}")
type_origin = get_origin(t)
# logger.info(f"Type origin: {type_origin}")

if type_origin is not None:
if type_origin is list:
json_schema_for_items = get_json_schema_for_arg(type_args[0])
json_schema = {"type": "array", "items": json_schema_for_items}
if type_origin in (list, tuple, set, frozenset):
json_schema_for_items = get_json_schema_for_arg(type_args[0]) if type_args else {"type": "string"}
return {"type": "array", "items": json_schema_for_items}
elif type_origin is dict:
json_schema = {"type": "object", "properties": {}}
# Handle both key and value types for dictionaries
key_schema = get_json_schema_for_arg(type_args[0]) if type_args else {"type": "string"}
value_schema = get_json_schema_for_arg(type_args[1]) if len(type_args) > 1 else {"type": "string"}
return {"type": "object", "propertyNames": key_schema, "additionalProperties": value_schema}
elif type_origin is Union:
json_schema = {"type": [get_json_type_for_py_type(arg.__name__) for arg in type_args]}
else:
json_schema = {"type": get_json_type_for_py_type(t.__name__)}
return json_schema
types = []
for arg in type_args:
if arg is not type(None):
try:
schema = get_json_schema_for_arg(arg)
if schema:
types.append(schema)
except Exception:
continue
return {"anyOf": types} if types else None

return {"type": get_json_type_for_py_type(t.__name__)}


def get_json_schema(type_hints: Dict[str, Any]) -> Dict[str, Any]:
json_schema: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}
json_schema: Dict[str, Any] = {
"type": "object",
"properties": {},
"additionalProperties": False, # Prevent additional properties by default
}

for k, v in type_hints.items():
# logger.info(f"Parsing arg: {k} | {v}")
if k == "return":
continue

# Check if type is Optional (Union with NoneType)
type_origin = get_origin(v)
type_args = get_args(v)
is_optional = type_origin is Union and len(type_args) == 2 and type(None) in type_args
if not is_optional:
json_schema["required"].append(k)

arg_json_schema = get_json_schema_for_arg(v)
if arg_json_schema is not None:
# logger.info(f"json_schema: {arg_json_schema}")
json_schema["properties"][k] = arg_json_schema
else:
logger.warning(f"Could not parse argument {k} of type {v}")
try:
# Check if type is Optional (Union with NoneType)
type_origin = get_origin(v)
type_args = get_args(v)
is_optional = type_origin is Union and len(type_args) == 2 and any(arg is type(None) for arg in type_args)

# Get the actual type if it's Optional
if is_optional:
v = next(arg for arg in type_args if arg is not type(None))

arg_json_schema = get_json_schema_for_arg(v)
if arg_json_schema is not None:
if is_optional:
# Handle null type for optional fields
if isinstance(arg_json_schema["type"], list):
arg_json_schema["type"].append("null")
else:
arg_json_schema["type"] = [arg_json_schema["type"], "null"]

json_schema["properties"][k] = arg_json_schema

else:
logger.warning(f"Could not parse argument {k} of type {v}")
except Exception as e:
logger.error(f"Error processing argument {k}: {str(e)}")
continue

return json_schema
4 changes: 3 additions & 1 deletion scripts/create_venv.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ main() {
print_info "Creating python3 venv: ${VENV_DIR}"
python3 -m venv "${VENV_DIR}"

# Activate the venv
source "${VENV_DIR}/bin/activate"

print_info "Installing base python packages"
pip3 install --upgrade pip pip-tools twine build

# Install workspace
source "${VENV_DIR}/bin/activate"
source "${CURR_DIR}/install.sh"

print_heading "Activate using: source ${VENV_DIR}/bin/activate"
Expand Down

0 comments on commit 5f13fbf

Please sign in to comment.