Skip to content

Commit

Permalink
dict -> tool name
Browse files Browse the repository at this point in the history
  • Loading branch information
ccurme committed Aug 19, 2024
1 parent a68fa35 commit 7202cdd
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def chat_model_params(self) -> dict:
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return "dict"
return "tool_name"

@pytest.mark.xfail(reason="Not yet supported.")
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import base64
import json
from typing import List, Optional, Union
from typing import List, Optional

import httpx
import pytest
Expand Down Expand Up @@ -170,11 +170,8 @@ def test_stop_sequence(self, model: BaseChatModel) -> None:
def test_tool_calling(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
if self.tool_choice_value == "dict":
tool_choice: Union[dict, str, None] = {
"type": "function",
"function": {"name": "magic_function"},
}
if self.tool_choice_value == "tool_name":
tool_choice: Optional[str] = "magic_function"
else:
tool_choice = self.tool_choice_value
model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice)
Expand All @@ -195,11 +192,8 @@ def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")

if self.tool_choice_value == "dict":
tool_choice: Union[dict, str, None] = {
"type": "function",
"function": {"name": "magic_function_no_args"},
}
if self.tool_choice_value == "tool_name":
tool_choice: Optional[str] = "magic_function_no_args"
else:
tool_choice = self.tool_choice_value
model_with_tools = model.bind_tools(
Expand Down Expand Up @@ -228,11 +222,8 @@ def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None:
name="greeting_generator",
description="Generate a greeting in a particular style of speaking.",
)
if self.tool_choice_value == "dict":
tool_choice: Union[dict, str, None] = {
"type": "function",
"function": {"name": "greeting_generator"},
}
if self.tool_choice_value == "tool_name":
tool_choice: Optional[str] = "greeting_generator"
else:
tool_choice = self.tool_choice_value
model_with_tools = model.bind_tools([tool_], tool_choice=tool_choice)
Expand Down

0 comments on commit 7202cdd

Please sign in to comment.