Skip to content

Commit

Permalink
ruff and black fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
strickvl committed May 2, 2024
1 parent 8893541 commit 3be142f
Show file tree
Hide file tree
Showing 6 changed files with 1,347 additions and 1,254 deletions.
10 changes: 4 additions & 6 deletions spacy_llm/models/langchain/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,10 @@ def register_models() -> None:


@registry.llm_queries("spacy.CallLangChain.v1")
def query_langchain() -> (
Callable[
["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]],
Iterable[Iterable[Any]],
]
):
def query_langchain() -> Callable[
["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]],
Iterable[Iterable[Any]],
]:
"""Returns query Callable for LangChain.
RETURNS (Callable[["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
executing simple prompts on the specified LangChain model.
Expand Down
19 changes: 11 additions & 8 deletions spacy_llm/models/rest/ollama/model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import os
import warnings
from enum import Enum
from typing import Any, Dict, Iterable, List, Sized

import requests # type: ignore[import]
import srsly # type: ignore[import]
from requests import HTTPError

from ..base import REST
Expand All @@ -15,12 +12,13 @@ class Endpoints(str, Enum):
EMBEDDINGS = "http://localhost:11434/api/embeddings"
TAGS = "http://localhost:11434/api/tags"


class Ollama(REST):
@property
def credentials(self) -> Dict[str, str]:
# No credentials needed for local Ollama server
return {}

def _verify_auth(self) -> None:
# Healthcheck: Verify connectivity to Ollama server
try:
Expand All @@ -46,7 +44,12 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
call_method=requests.post,
url=self._endpoint,
headers=headers,
json={**json_data, **self._config, "model": self._name, "stream": False},
json={
**json_data,
**self._config,
"model": self._name,
"stream": False,
},
timeout=self._max_request_time,
)
try:
Expand All @@ -57,15 +60,15 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
raise ValueError(
f"Request to Ollama API failed: {res_content}"
) from ex

response = r.json()

if "error" in response:
if self._strict:
raise ValueError(f"API call failed: {response['error']}.")
else:
assert isinstance(prompts_for_doc, Sized)
return {"error": [response['error']] * len(prompts_for_doc)}
return {"error": [response["error"]] * len(prompts_for_doc)}

return response

Expand Down Expand Up @@ -163,5 +166,5 @@ def _get_context_lengths() -> Dict[str, int]:
"duckdb-nsql": 4096,
"alfred": 4096,
"notus": 4096,
"snowflake-arctic-embed": 4096
"snowflake-arctic-embed": 4096,
}
Loading

0 comments on commit 3be142f

Please sign in to comment.