Skip to content

Commit

Permalink
Add Cohere in Bedrock
Browse files Browse the repository at this point in the history
  • Loading branch information
viveksilimkhan1 committed Oct 31, 2023
1 parent 6102c35 commit fe38a95
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
16 changes: 15 additions & 1 deletion spacy_llm/models/rest/bedrock/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class Models(str, Enum):
TITAN_LITE = "amazon.titan-text-lite-v1"
AI21_JURASSIC_ULTRA = "ai21.j2-ultra-v1"
AI21_JURASSIC_MID = "ai21.j2-mid-v1"
COHERE_COMMAND = "cohere.command-text-v14"


TITAN_PARAMS = ["maxTokenCount", "stopSequences", "temperature", "topP"]
Expand All @@ -24,6 +25,7 @@ class Models(str, Enum):
"presencePenalty",
"frequencyPenalty",
]
COHERE_PARAMS = ["max_tokens", "temperature"]


class Bedrock(REST):
Expand All @@ -45,6 +47,8 @@ def __init__(
config_params = TITAN_PARAMS
if self._model_id in [Models.AI21_JURASSIC_ULTRA, Models.AI21_JURASSIC_MID]:
config_params = AI21_JURASSIC_PARAMS
if self._model_id in [Models.COHERE_COMMAND]:
config_params = COHERE_PARAMS

for i in config_params:
self._config[i] = config[i]
Expand Down Expand Up @@ -141,6 +145,10 @@ def _request(json_data: str) -> str:
responses = json.loads(r["body"].read().decode())["completions"][0][
"data"
]["text"]
elif self._model_id in [Models.COHERE_COMMAND]:
responses = json.loads(r["body"].read().decode())["generations"][0][
"text"
]

return responses

Expand All @@ -151,7 +159,12 @@ def _request(json_data: str) -> str:
{"inputText": prompt, "textGenerationConfig": self._config}
)
)
if self._model_id in [Models.AI21_JURASSIC_ULTRA, Models.AI21_JURASSIC_MID]:
elif self._model_id in [
Models.AI21_JURASSIC_ULTRA,
Models.AI21_JURASSIC_MID,
]:
responses = _request(json.dumps({"prompt": prompt, **self._config}))
elif self._model_id in [Models.COHERE_COMMAND]:
responses = _request(json.dumps({"prompt": prompt, **self._config}))

api_responses.append(responses)
Expand Down Expand Up @@ -181,4 +194,5 @@ def get_model_names(self) -> Tuple[str, ...]:
"amazon.titan-text-lite-v1",
"ai21.j2-ultra-v1",
"ai21.j2-mid-v1",
"cohere.command-text-v14",
)
2 changes: 2 additions & 0 deletions spacy_llm/models/rest/bedrock/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def bedrock(
presencePenalty=_DEFAULT_PRESENCE_PENALTY,
frequencyPenalty=_DEFAULT_FREQUENCY_PENALTY,
stop_sequences=_DEFAULT_STOP_SEQUENCES,
# Params for Cohere models
max_tokens=_DEFAULT_MAX_TOKEN_COUNT,
),
max_tries: int = _DEFAULT_RETRIES,
) -> Callable[[Iterable[str]], Iterable[str]]:
Expand Down

0 comments on commit fe38a95

Please sign in to comment.