Skip to content

Commit

Permalink
Merge pull request #240 from VRSEN/fix/bug-restarting-active-run
Browse files Browse the repository at this point in the history
Fix active run restart, zero top_p and other issues
  • Loading branch information
bonk1t authored Feb 26, 2025
2 parents 349e7b2 + b746634 commit e1c896c
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 48 deletions.
41 changes: 18 additions & 23 deletions agency_swarm/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
self,
id: str = None,
name: str = None,
description: str = None,
description: str = "",
instructions: str = "",
tools: List[
Union[
Expand All @@ -85,7 +85,7 @@ def __init__(
] = None,
tool_resources: ToolResources = None,
temperature: float = None,
top_p: float = None,
top_p: float = 1.0,
response_format: Union[str, dict, type] = "auto",
tools_folder: str = None,
files_folder: Union[List[str], str] = None,
Expand All @@ -111,12 +111,12 @@ def __init__(
Parameters:
id (str, optional): Loads the assistant from OpenAI assistant ID. Assistant will be created or loaded from settings if ID is not provided. Defaults to None.
name (str, optional): Name of the agent. Defaults to the class name if not provided.
description (str, optional): A brief description of the agent's purpose. Defaults to None.
description (str, optional): A brief description of the agent's purpose. Defaults to empty string.
instructions (str, optional): Path to a file containing specific instructions for the agent. Defaults to an empty string.
tools (List[Union[Type[BaseTool], Type[Retrieval], Type[CodeInterpreter]]], optional): A list of tools (as classes) that the agent can use. Defaults to an empty list.
tool_resources (ToolResources, optional): A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the code_interpreter tool requires a list of file IDs, while the file_search tool requires a list of vector store IDs. Defaults to None.
temperature (float, optional): The temperature parameter for the OpenAI API. Defaults to None.
top_p (float, optional): The top_p parameter for the OpenAI API. Defaults to None.
top_p (float, optional): The top_p parameter for the OpenAI API. Defaults to 1.0.
response_format (Union[str, Dict, type], optional): The response format for the OpenAI API. If BaseModel is provided, it will be converted to a response format. Defaults to None.
tools_folder (str, optional): Path to a directory containing tools associated with the agent. Each tool must be defined in a separate file. File must be named as the class name of the tool. Defaults to None.
files_folder (Union[List[str], str], optional): Path or list of paths to directories containing files associated with the agent. Defaults to None.
Expand Down Expand Up @@ -223,7 +223,7 @@ def init_oai(self):
if self.temperature is None
else self.temperature
)
self.top_p = self.top_p or self.assistant.top_p
self.top_p = self.top_p if self.top_p is not None else self.assistant.top_p
self.response_format = (
self.response_format or self.assistant.response_format
)
Expand Down Expand Up @@ -315,15 +315,12 @@ def _create_assistant(self):
extra_body = {}

# o-series models
if params['model'].startswith('o'):
params['temperature'] = None
params['top_p'] = None
extra_body['reasoning_effort'] = self.reasoning_effort

return self.client.beta.assistants.create(
**params,
extra_body=extra_body
)
if params["model"].startswith("o"):
params["temperature"] = None
params["top_p"] = None
extra_body["reasoning_effort"] = self.reasoning_effort

return self.client.beta.assistants.create(**params, extra_body=extra_body)

if self.assistant.tool_resources:
self.tool_resources = self.assistant.tool_resources.model_dump()
Expand Down Expand Up @@ -360,21 +357,19 @@ def _update_assistant(self):
"metadata": self.metadata,
"model": self.model,
}

extra_body = {}

# o-series models
if params['model'].startswith('o'):
params['temperature'] = None
params['top_p'] = None
extra_body['reasoning_effort'] = self.reasoning_effort
if params["model"].startswith("o"):
params["temperature"] = None
params["top_p"] = None
extra_body["reasoning_effort"] = self.reasoning_effort

self.assistant = self.client.beta.assistants.update(
self.id,
**params,
extra_body=extra_body
self.id, **params, extra_body=extra_body
)

self._update_settings()

def _upload_files(self):
Expand Down
34 changes: 25 additions & 9 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def _create_run(
tool_choice,
temperature=None,
response_format: Optional[dict] = None,
):
) -> None:
try:
if event_handler:
with self.client.beta.threads.runs.stream(
Expand Down Expand Up @@ -503,19 +503,29 @@ def _create_run(
run_id=match.groups()[1],
check_status=False,
)
# Reattempt creating a new run after cancellation.
return self._create_run(
recipient_agent,
additional_instructions,
event_handler,
tool_choice,
temperature=temperature,
response_format=response_format,
)
elif (
"The server had an error processing your request" in e.message
and self._num_run_retries < 3
):
time.sleep(1)
self._create_run(
self._num_run_retries += 1
return self._create_run(
recipient_agent,
additional_instructions,
event_handler,
tool_choice,
temperature=temperature,
response_format=response_format,
)
self._num_run_retries += 1
else:
raise e

Expand Down Expand Up @@ -555,14 +565,20 @@ def _cancel_run(self, thread_id=None, run_id=None, check_status=True):
return

try:
actual_thread_id = thread_id or self.id
actual_run_id = run_id or (self._run.id if self._run else None)

if not actual_run_id:
return # Can't cancel without a run ID

self._run = self.client.beta.threads.runs.cancel(
thread_id=self.id, run_id=self._run.id
thread_id=actual_thread_id, run_id=actual_run_id
)
except BadRequestError as e:
if "Cannot cancel run with status" in e.message:
self._run = self.client.beta.threads.runs.poll(
thread_id=thread_id or self.id,
run_id=run_id or self._run.id,
thread_id=actual_thread_id,
run_id=actual_run_id,
poll_interval_ms=500,
)
else:
Expand Down Expand Up @@ -634,7 +650,7 @@ def execute_tool(

if not tool:
return (
f"Error: Function {tool_call.function.name} not found. Available functions: {[func.__name__ for func in funcs]}",
f"Error: Function {tool_name} not found. Available functions: {[func.__name__ for func in funcs]}",
False,
)

Expand All @@ -645,8 +661,8 @@ def execute_tool(
tool = tool(**args)

# check if the tool is already called
for tool_name in [name for name, _ in tool_outputs_and_names]:
if tool_name == tool_name and (
for output_tool_name in [name for name, _ in tool_outputs_and_names]:
if output_tool_name == tool_name and (
hasattr(tool, "ToolConfig")
and hasattr(tool.ToolConfig, "one_call_at_a_time")
and tool.ToolConfig.one_call_at_a_time
Expand Down
6 changes: 3 additions & 3 deletions tests/test_agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,14 @@ def test_2_load_agent(self):
agent3.tools = self.__class__.agent1.tools
agent3.top_p = self.__class__.agency.top_p
agent3.file_search = self.__class__.agent1.file_search
agent3.temperature = self.__class__.agent1.temperature
agent3.model = self.__class__.agent1.model
agent3 = agent3.init_oai()

print("agent3", agent3.assistant.model_dump())
print("agent1", self.__class__.agent1.assistant.model_dump())

self.assertTrue(self.__class__.agent1.id == agent3.id)

# check that assistant settings match
# Check that the agents have the same settings
self.assertTrue(
agent3._check_parameters(self.__class__.agent1.assistant.model_dump())
)
Expand Down
8 changes: 5 additions & 3 deletions tests/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ def run(self):

def test_send_message_swarm(self):
response = self.agency.get_completion(
"Hello, can you send me to customer support? If tool responds says that you have NOT been rerouted, or if there is another error, please say 'error'"
"Hello, can you send me to customer support? If tool responds says that you have NOT been rerouted, or if there is another error, please say 'error'",
yield_messages=False,
)
self.assertFalse(
"error" in response.lower(), self.agency.main_thread.thread_url
)
response = self.agency.get_completion("Who are you?")
response = self.agency.get_completion("Who are you?", yield_messages=False)
self.assertTrue(
"customer support" in response.lower(), self.agency.main_thread.thread_url
)
Expand All @@ -62,8 +63,9 @@ def test_send_message_swarm(self):
self.assertEqual(main_thread.recipient_agent, self.customer_support)

# check if all messages in the same thread (this is how Swarm works)
messages = main_thread.get_messages()
self.assertTrue(
len(main_thread.get_messages()) >= 4
len(messages) >= 4
) # sometimes run does not cancel immediately, so there might be 5 messages

def test_send_message_double_recepient_error(self):
Expand Down
20 changes: 10 additions & 10 deletions tests/test_tool_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ def setUp(self):

def test_move_file_tool(self):
tool = ToolFactory.from_langchain_tool(MoveFileTool())
print(json.dumps(tool.openai_schema, indent=4))
# print(json.dumps(tool.openai_schema, indent=4))
print(tool)

tool = tool(
destination_path="Move a file from one folder to another",
source_path="Move a file from one folder to another",
)

print(tool.model_dump())
# print(tool.model_dump())

tool.run()

Expand Down Expand Up @@ -77,15 +77,15 @@ class UserRelationships(BaseTool):
title="Relationship Type",
)

print("schema", json.dumps(UserRelationships.openai_schema, indent=4))
# print("schema", json.dumps(UserRelationships.openai_schema, indent=4))

# print("ref", json.dumps(reference_schema(deref_schema), indent=4))

tool = ToolFactory.from_openai_schema(
UserRelationships.openai_schema, lambda x: x
)

print(json.dumps(tool.openai_schema, indent=4))
# print(json.dumps(tool.openai_schema, indent=4))
user_detail_instance = {
"id": 1,
"age": 20,
Expand Down Expand Up @@ -114,9 +114,9 @@ def remove_empty_fields(d):

cleaned_schema = remove_empty_fields(user_relationships_schema)

print("clean schema", json.dumps(cleaned_schema, indent=4))
# print("clean schema", json.dumps(cleaned_schema, indent=4))

print("tool schema", json.dumps(tool.openai_schema, indent=4))
# print("tool schema", json.dumps(tool.openai_schema, indent=4))

tool_schema = tool.openai_schema

Expand Down Expand Up @@ -149,7 +149,7 @@ def test_custom_tool(self):

tool2 = ToolFactory.from_openai_schema(schema, lambda x: x)

print(json.dumps(tool.openai_schema, indent=4))
# print(json.dumps(tool.openai_schema, indent=4))

tool = tool(query="John Doe")

Expand All @@ -167,15 +167,15 @@ def test_get_weather_openapi(self):

self.assertFalse(tools[0].openai_schema.get("strict", False))

print(json.dumps(tools[0].openai_schema, indent=4))
# print(json.dumps(tools[0].openai_schema, indent=4))

def test_relevance_openapi_schema(self):
with open("./data/schemas/relevance.json", "r") as f:
tools = ToolFactory.from_openapi_schema(
f.read(), {"Authorization": os.environ.get("TEST_SCHEMA_API_KEY")}
)

print(json.dumps(tools[0].openai_schema, indent=4))
# print(json.dumps(tools[0].openai_schema, indent=4))

async def gather_output():
output = await tools[0](requestBody={"text": "test"}).run()
Expand Down Expand Up @@ -209,7 +209,7 @@ def test_ga4_openapi_schema(self):
with open("./data/schemas/ga4.json", "r") as f:
tools = ToolFactory.from_openapi_schema(f.read(), {})

print(json.dumps(tools[0].openai_schema, indent=4))
# print(json.dumps(tools[0].openai_schema, indent=4))

def test_import_from_file(self):
tool = ToolFactory.from_file("./data/tools/ExampleTool1.py")
Expand Down

0 comments on commit e1c896c

Please sign in to comment.