Skip to content

Commit

Permalink
Samplers: Add frequency and presence penalty
Browse files Browse the repository at this point in the history
Un-alias repetition penalty from the frequency penalty parameter.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
kingbri1 committed Dec 25, 2023
1 parent 442bb59 commit a8fb3bd
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
16 changes: 4 additions & 12 deletions OAI/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ class CommonCompletionRequest(BaseModel):
# Default to 150 as 16 makes no sense as a default
max_tokens: Optional[int] = 150

# Aliased to repetition_penalty
frequency_penalty: Optional[float] = Field(
description="Aliased to Repetition Penalty", default=0.0
)

# Sampling params
token_healing: Optional[bool] = False
temperature: Optional[float] = 1.0
Expand All @@ -73,6 +68,8 @@ class CommonCompletionRequest(BaseModel):
typical: Optional[float] = 1.0
min_p: Optional[float] = 0.0
tfs: Optional[float] = 1.0
frequency_penalty: Optional[float] = 0.0
presence_penalty: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
repetition_decay: Optional[int] = 0
mirostat_mode: Optional[int] = 0
Expand All @@ -94,13 +91,6 @@ def to_gen_params(self):
if isinstance(self.stop, str):
self.stop = [self.stop]

# Set repetition_penalty to frequency_penalty if repetition_penalty
# isn't already defined
if (
self.repetition_penalty is None or self.repetition_penalty == 1.0
) and self.frequency_penalty:
self.repetition_penalty = self.frequency_penalty

return {
"stop": self.stop,
"max_tokens": self.max_tokens,
Expand All @@ -116,6 +106,8 @@ def to_gen_params(self):
"typical": self.typical,
"min_p": self.min_p,
"tfs": self.tfs,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"repetition_penalty": self.repetition_penalty,
"repetition_range": unwrap(self.repetition_range, -1),
"repetition_decay": self.repetition_decay,
Expand Down
22 changes: 22 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,22 @@ def generate_gen(self, prompt: str, **kwargs):
"installed ExLlamaV2 version."
)

if (unwrap(kwargs.get("frequency_penalty"), 0.0)) != 0.0 and not hasattr(
gen_settings, "token_frequency_penalty"
):
logger.warning(
"Frequency penalty is not supported by the currently "
"installed ExLlamaV2 version."
)

if (unwrap(kwargs.get("presence_penalty"), 0.0)) != 0.0 and not hasattr(
gen_settings, "token_presence_penalty"
):
logger.warning(
"Presence penalty is not supported by the currently "
"installed ExLlamaV2 version."
)

# Apply settings
gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
Expand All @@ -543,6 +559,12 @@ def generate_gen(self, prompt: str, **kwargs):
# Default tau and eta fallbacks don't matter if mirostat is off
gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
gen_settings.token_frequency_penalty = unwrap(
kwargs.get("frequency_penalty"), 0.0
)
gen_settings.token_presence_penalty = unwrap(
kwargs.get("presence_penalty"), 0.0
)
gen_settings.token_repetition_penalty = unwrap(
kwargs.get("repetition_penalty"), 1.0
)
Expand Down

0 comments on commit a8fb3bd

Please sign in to comment.