Skip to content

Commit

Permalink
Refactor the model names map into one map per model provider
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesmurdza committed Jan 21, 2025
1 parent a2eccb6 commit eeb2707
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 21 deletions.
5 changes: 4 additions & 1 deletion os_computer_use/llm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@ class LLMProvider:
# Connection settings; concrete provider subclasses must override these.
base_url = None
api_key = None

# Mapping of short model aliases to provider-specific model identifiers.
# Must be a dict: __init__ resolves names via `self.aliases.get(model, model)`,
# so the empty default is `{}` (a list would raise AttributeError on `.get`).
aliases = {}

def __init__(self, model):
# Validate base URL and API key
if not self.base_url:
raise ValueError("No base URL provided.")
if not self.api_key:
raise ValueError("No API key provided.")
self.model = model
self.model = self.aliases.get(model, model)
print(f"Using {self.__class__.__name__} with {self.model}")
# Initialize OpenAI client
self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)
Expand Down
30 changes: 10 additions & 20 deletions os_computer_use/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,28 @@
# Load environment variables (the provider API keys read below) from a .env file.
load_dotenv()

# Model names can vary from provider to provider, and are standardized here:
# maps provider key -> {short alias -> provider-specific model identifier}.
model_names = {
    # Llama API uses bare model names.
    "llama": {"llama3.2": "llama3.2-90b-vision", "llama3.3": "llama3.3-70b"},
    # OpenRouter prefixes the publishing organization.
    "openrouter": {
        "llama3.2": "meta-llama/llama-3.2-90b-vision-instruct",
    },
    # Fireworks uses fully-qualified account/model paths.
    "fireworks": {
        "llama3.2": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
        "llama3.3": "accounts/fireworks/models/llama-v3p3-70b-instruct",
    },
    # DeepSeek's public name already matches its API name.
    "deepseek": {"deepseek-chat": "deepseek-chat"},
    # Gemini 2.0 Flash is only available under its experimental identifier.
    "gemini": {
        "gemini-1.5-flash": "gemini-1.5-flash",
        "gemini-2.0-flash": "gemini-2.0-flash-exp",
    },
}


# LLM providers use the OpenAI specification and require a base URL:


class LlamaProvider(LLMProvider):
    """Llama API provider; uses the OpenAI-compatible protocol via a base URL."""

    base_url = "https://api.llama-api.com"
    # Read at import time; requires LLAMA_API_KEY in the environment / .env file.
    api_key = os.getenv("LLAMA_API_KEY")
    # Short alias -> this provider's model identifier (resolved in LLMProvider).
    aliases = {"llama3.2": "llama3.2-90b-vision", "llama3.3": "llama3.3-70b"}


class OpenRouterProvider(LLMProvider):
    """OpenRouter provider; uses the OpenAI-compatible protocol via a base URL."""

    base_url = "https://openrouter.ai/api/v1"
    # Read at import time; requires OPENROUTER_API_KEY in the environment / .env file.
    api_key = os.getenv("OPENROUTER_API_KEY")
    # Short alias -> this provider's model identifier (resolved in LLMProvider).
    aliases = {"llama3.2": "meta-llama/llama-3.2-90b-vision-instruct"}


class FireworksProvider(LLMProvider):
    """Fireworks AI provider; uses the OpenAI-compatible protocol via a base URL."""

    base_url = "https://api.fireworks.ai/inference/v1"
    # Read at import time; requires FIREWORKS_API_KEY in the environment / .env file.
    api_key = os.getenv("FIREWORKS_API_KEY")
    # Short alias -> fully-qualified Fireworks model path (resolved in LLMProvider).
    aliases = {
        "llama3.2": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
        "llama3.3": "accounts/fireworks/models/llama-v3p3-70b-instruct",
    }


class DeepSeekProvider(LLMProvider):
Expand All @@ -57,5 +45,7 @@ class GeminiProvider(LLMProvider):

# Model instances used by the agent. ShowUI is kept as a commented-out
# alternative grounding model.
# grounding_model = ShowUIProvider()
grounding_model = OSAtlasProvider()

# Pass short aliases directly; each provider resolves them through its own
# `aliases` map in LLMProvider.__init__, so no shared lookup table is needed.
# (The superseded assignments that indexed the removed `model_names` table
# are dropped — they would raise NameError and were immediately shadowed.)
vision_model = FireworksProvider("llama3.2")
action_model = FireworksProvider("llama3.3")

0 comments on commit eeb2707

Please sign in to comment.