diff --git a/spacy_llm/models/rest/bedrock/model.py b/spacy_llm/models/rest/bedrock/model.py
index 149e292a..38663483 100644
--- a/spacy_llm/models/rest/bedrock/model.py
+++ b/spacy_llm/models/rest/bedrock/model.py
@@ -14,6 +14,8 @@ class Models(str, Enum):
     AI21_JURASSIC_ULTRA = "ai21.j2-ultra-v1"
     AI21_JURASSIC_MID = "ai21.j2-mid-v1"
     COHERE_COMMAND = "cohere.command-text-v14"
+    ANTHROPIC_CLAUDE = "anthropic.claude-v2"
+    ANTHROPIC_CLAUDE_INSTANT = "anthropic.claude-instant-v1"
 
 
 TITAN_PARAMS = ["maxTokenCount", "stopSequences", "temperature", "topP"]
@@ -26,6 +28,13 @@ class Models(str, Enum):
     "frequencyPenalty",
 ]
 COHERE_PARAMS = ["max_tokens", "temperature"]
+ANTHROPIC_PARAMS = [
+    "max_tokens_to_sample",
+    "temperature",
+    "top_k",
+    "top_p",
+    "stop_sequences",
+]
 
 
 class Bedrock(REST):
@@ -49,6 +58,8 @@ def __init__(
             config_params = AI21_JURASSIC_PARAMS
         if self._model_id in [Models.COHERE_COMMAND]:
             config_params = COHERE_PARAMS
+        if self._model_id in [Models.ANTHROPIC_CLAUDE_INSTANT, Models.ANTHROPIC_CLAUDE]:
+            config_params = ANTHROPIC_PARAMS
 
         for i in config_params:
             self._config[i] = config[i]
@@ -149,6 +160,11 @@ def _request(json_data: str) -> str:
                 responses = json.loads(r["body"].read().decode())["generations"][0][
                     "text"
                 ]
+            elif self._model_id in [
+                Models.ANTHROPIC_CLAUDE_INSTANT,
+                Models.ANTHROPIC_CLAUDE,
+            ]:
+                responses = json.loads(r["body"].read().decode())["completion"]
 
             return responses
 
@@ -166,7 +182,15 @@ def _request(json_data: str) -> str:
                 responses = _request(json.dumps({"prompt": prompt, **self._config}))
             elif self._model_id in [Models.COHERE_COMMAND]:
                 responses = _request(json.dumps({"prompt": prompt, **self._config}))
-
+            elif self._model_id in [
+                Models.ANTHROPIC_CLAUDE_INSTANT,
+                Models.ANTHROPIC_CLAUDE,
+            ]:
+                responses = _request(
+                    json.dumps(
+                        {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **self._config}
+                    )
+                )
             api_responses.append(responses)
 
         return api_responses
@@ -195,4 +219,6 @@ def get_model_names(self) -> Tuple[str, ...]:
             "ai21.j2-ultra-v1",
             "ai21.j2-mid-v1",
             "cohere.command-text-v14",
+            "anthropic.claude-v2",
+            "anthropic.claude-instant-v1",
         )
diff --git a/spacy_llm/models/rest/bedrock/registry.py b/spacy_llm/models/rest/bedrock/registry.py
index 9ca8f958..aeaa4534 100644
--- a/spacy_llm/models/rest/bedrock/registry.py
+++ b/spacy_llm/models/rest/bedrock/registry.py
@@ -9,10 +9,12 @@
 _DEFAULT_TEMPERATURE: float = 0.0
 _DEFAULT_MAX_TOKEN_COUNT: int = 512
 _DEFAULT_TOP_P: int = 1
+_DEFAULT_TOP_K: int = 250
 _DEFAULT_STOP_SEQUENCES: List[str] = []
 _DEFAULT_COUNT_PENALTY: Dict[str, Any] = {"scale": 0}
 _DEFAULT_PRESENCE_PENALTY: Dict[str, Any] = {"scale": 0}
 _DEFAULT_FREQUENCY_PENALTY: Dict[str, Any] = {"scale": 0}
+_DEFAULT_MAX_TOKEN_TO_SAMPLE: int = 300
 
 
 @registry.llm_models("spacy.Bedrock.v1")
@@ -33,6 +35,10 @@ def bedrock(
         stop_sequences=_DEFAULT_STOP_SEQUENCES,
         # Params for Cohere models
         max_tokens=_DEFAULT_MAX_TOKEN_COUNT,
+        # Params for Anthropic models
+        max_tokens_to_sample=_DEFAULT_MAX_TOKEN_TO_SAMPLE,
+        top_k=_DEFAULT_TOP_K,
+        top_p=_DEFAULT_TOP_P,
     ),
     max_tries: int = _DEFAULT_RETRIES,
 ) -> Callable[[Iterable[str]], Iterable[str]]:
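
Usage sketch (not part of the patch): a minimal example of selecting one of the Claude model IDs added above through the "spacy.Bedrock.v1" registry entry. The task choice (spacy.NER.v2 and its labels) and the "us-east-1" region are illustrative assumptions, not taken from this diff; the model IDs and the new Anthropic defaults (max_tokens_to_sample=300, top_k=250, top_p=1) are.

import spacy

# Assumes spacy-llm is installed so the "llm" factory is available,
# and that AWS credentials with Bedrock access are configured.
nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        # Example task; any spacy-llm task can sit on top of the Bedrock model.
        "task": {"@llm_tasks": "spacy.NER.v2", "labels": ["PERSON", "ORG"]},
        "model": {
            "@llm_models": "spacy.Bedrock.v1",
            # One of the model IDs added in this patch; "anthropic.claude-v2" also works.
            "model_id": "anthropic.claude-instant-v1",
            # Placeholder region; use the region where Bedrock is enabled for your account.
            "region": "us-east-1",
            # Generation params fall back to the new Anthropic defaults
            # (max_tokens_to_sample=300, top_k=250, top_p=1, temperature=0.0).
        },
    },
)

doc = nlp("Amazon Bedrock now exposes Anthropic's Claude models.")
print([(ent.text, ent.label_) for ent in doc.ents])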